Skip to content

Commit 01536e8

Browse files
authored
Adding more unit tests for ChannelHolder class (#8668)
1 parent 12a3cea commit 01536e8

File tree

1 file changed

+338
-0
lines changed

1 file changed

+338
-0
lines changed

paddle/fluid/framework/channel_test.cc

Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,341 @@ TEST(ChannelHolder, ChannelHolderUnBufferedSendReceiveTest) {
542542
ChannelHolderSendReceive(ch);
543543
delete ch;
544544
}
545+
546+
TEST(ChannelHolder, ChannelUninitializedTest) {
547+
ChannelHolder *ch = new ChannelHolder();
548+
EXPECT_EQ(ch->IsInitialized(), false);
549+
int i = 10;
550+
EXPECT_EQ(ch->Send(&i), false);
551+
EXPECT_EQ(ch->Receive(&i), false);
552+
bool is_exception = false;
553+
try {
554+
ch->Type();
555+
} catch (paddle::platform::EnforceNotMet e) {
556+
is_exception = true;
557+
}
558+
EXPECT_EQ(is_exception, true);
559+
delete ch;
560+
}
561+
562+
TEST(ChannelHolder, ChannelInitializedTest) {
563+
ChannelHolder *ch = new ChannelHolder();
564+
ch->Reset<int>(2);
565+
EXPECT_EQ(ch->IsInitialized(), true);
566+
// Channel should remain intialized even after close
567+
ch->close();
568+
EXPECT_EQ(ch->IsInitialized(), true);
569+
delete ch;
570+
}
571+
572+
TEST(ChannelHolder, TypeMismatchSendTest) {
573+
// Test with unbuffered channel
574+
ChannelHolder *ch = new ChannelHolder();
575+
ch->Reset<int>(0);
576+
bool is_exception = false;
577+
bool boolean_data = true;
578+
try {
579+
ch->Send(&boolean_data);
580+
} catch (paddle::platform::EnforceNotMet e) {
581+
is_exception = true;
582+
}
583+
EXPECT_EQ(is_exception, true);
584+
delete ch;
585+
586+
// Test with Buffered Channel
587+
ch = new ChannelHolder();
588+
ch->Reset<float>(10);
589+
is_exception = false;
590+
int int_data = 23;
591+
try {
592+
ch->Send(&int_data);
593+
} catch (paddle::platform::EnforceNotMet e) {
594+
is_exception = true;
595+
}
596+
EXPECT_EQ(is_exception, true);
597+
delete ch;
598+
}
599+
600+
TEST(ChannelHolder, TypeMismatchReceiveTest) {
601+
// Test with unbuffered channel
602+
ChannelHolder *ch = new ChannelHolder();
603+
ch->Reset<int>(0);
604+
bool is_exception = false;
605+
bool float_data;
606+
try {
607+
ch->Receive(&float_data);
608+
} catch (paddle::platform::EnforceNotMet e) {
609+
is_exception = true;
610+
}
611+
EXPECT_EQ(is_exception, true);
612+
delete ch;
613+
614+
// Test with Buffered Channel
615+
ch = new ChannelHolder();
616+
ch->Reset<float>(10);
617+
is_exception = false;
618+
int int_data = 23;
619+
try {
620+
ch->Receive(&int_data);
621+
} catch (paddle::platform::EnforceNotMet e) {
622+
is_exception = true;
623+
}
624+
EXPECT_EQ(is_exception, true);
625+
delete ch;
626+
}
627+
628+
void ChannelHolderCloseUnblocksReceiversTest(ChannelHolder *ch) {
629+
size_t num_threads = 5;
630+
std::thread t[num_threads];
631+
bool thread_ended[num_threads];
632+
633+
// Launches threads that try to read and are blocked because of no writers
634+
for (size_t i = 0; i < num_threads; i++) {
635+
thread_ended[i] = false;
636+
t[i] = std::thread(
637+
[&](bool *p) {
638+
int data;
639+
EXPECT_EQ(ch->Receive(&data), false);
640+
*p = true;
641+
},
642+
&thread_ended[i]);
643+
}
644+
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
645+
646+
// Verify that all the threads are blocked
647+
for (size_t i = 0; i < num_threads; i++) {
648+
EXPECT_EQ(thread_ended[i], false);
649+
}
650+
651+
// Explicitly close the channel
652+
// This should unblock all receivers
653+
ch->close();
654+
655+
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
656+
657+
// Verify that all threads got unblocked
658+
for (size_t i = 0; i < num_threads; i++) {
659+
EXPECT_EQ(thread_ended[i], true);
660+
}
661+
662+
for (size_t i = 0; i < num_threads; i++) t[i].join();
663+
}
664+
665+
void ChannelHolderCloseUnblocksSendersTest(ChannelHolder *ch, bool isBuffered) {
666+
using paddle::framework::details::Buffered;
667+
using paddle::framework::details::UnBuffered;
668+
669+
size_t num_threads = 5;
670+
std::thread t[num_threads];
671+
bool thread_ended[num_threads];
672+
bool send_success[num_threads];
673+
674+
// Launches threads that try to write and are blocked because of no readers
675+
for (size_t i = 0; i < num_threads; i++) {
676+
thread_ended[i] = false;
677+
send_success[i] = false;
678+
t[i] = std::thread(
679+
[&](bool *ended, bool *success) {
680+
int data = 10;
681+
*success = ch->Send(&data);
682+
*ended = true;
683+
},
684+
&thread_ended[i], &send_success[i]);
685+
}
686+
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
687+
688+
if (isBuffered) {
689+
// If ch is Buffered, atleast 4 threads must be blocked.
690+
int ct = 0;
691+
for (size_t i = 0; i < num_threads; i++) {
692+
if (!thread_ended[i]) ct++;
693+
}
694+
EXPECT_GE(ct, 4);
695+
} else {
696+
// If ch is UnBuffered, all the threads should be blocked.
697+
for (size_t i = 0; i < num_threads; i++) {
698+
EXPECT_EQ(thread_ended[i], false);
699+
}
700+
}
701+
// Explicitly close the thread
702+
// This should unblock all senders
703+
ch->close();
704+
705+
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
706+
707+
// Verify that all threads got unblocked
708+
for (size_t i = 0; i < num_threads; i++) {
709+
EXPECT_EQ(thread_ended[i], true);
710+
}
711+
712+
if (isBuffered) {
713+
// Verify that only 1 send was successful
714+
int ct = 0;
715+
for (size_t i = 0; i < num_threads; i++) {
716+
if (send_success[i]) ct++;
717+
}
718+
// Only 1 send must be successful
719+
EXPECT_EQ(ct, 1);
720+
}
721+
722+
for (size_t i = 0; i < num_threads; i++) t[i].join();
723+
}
724+
725+
// This tests that closing a channelholder unblocks
726+
// any receivers waiting on the channel
727+
TEST(ChannelHolder, ChannelHolderCloseUnblocksReceiversTest) {
728+
// Check for buffered channel
729+
ChannelHolder *ch = new ChannelHolder();
730+
ch->Reset<int>(1);
731+
ChannelHolderCloseUnblocksReceiversTest(ch);
732+
delete ch;
733+
734+
// Check for unbuffered channel
735+
ch = new ChannelHolder();
736+
ch->Reset<int>(0);
737+
ChannelHolderCloseUnblocksReceiversTest(ch);
738+
delete ch;
739+
}
740+
741+
// This tests that closing a channelholder unblocks
742+
// any senders waiting for channel to have write space
743+
TEST(Channel, ChannelHolderCloseUnblocksSendersTest) {
744+
// Check for buffered channel
745+
ChannelHolder *ch = new ChannelHolder();
746+
ch->Reset<int>(1);
747+
ChannelHolderCloseUnblocksSendersTest(ch, true);
748+
delete ch;
749+
750+
// Check for unbuffered channel
751+
ch = new ChannelHolder();
752+
ch->Reset<int>(0);
753+
ChannelHolderCloseUnblocksSendersTest(ch, false);
754+
delete ch;
755+
}
756+
757+
// This tests that destroying a channelholder unblocks
758+
// any senders waiting for channel
759+
void ChannelHolderDestroyUnblockSenders(ChannelHolder *ch, bool isBuffered) {
760+
size_t num_threads = 5;
761+
std::thread t[num_threads];
762+
bool thread_ended[num_threads];
763+
bool send_success[num_threads];
764+
765+
// Launches threads that try to write and are blocked because of no readers
766+
for (size_t i = 0; i < num_threads; i++) {
767+
thread_ended[i] = false;
768+
send_success[i] = false;
769+
t[i] = std::thread(
770+
[&](bool *ended, bool *success) {
771+
int data = 10;
772+
*success = ch->Send(&data);
773+
*ended = true;
774+
},
775+
&thread_ended[i], &send_success[i]);
776+
}
777+
778+
std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
779+
if (isBuffered) {
780+
// If channel is buffered, verify that atleast 4 threads are blocked
781+
int ct = 0;
782+
for (size_t i = 0; i < num_threads; i++) {
783+
if (thread_ended[i] == false) ct++;
784+
}
785+
// Atleast 4 threads must be blocked
786+
EXPECT_GE(ct, 4);
787+
} else {
788+
// Verify that all the threads are blocked
789+
for (size_t i = 0; i < num_threads; i++) {
790+
EXPECT_EQ(thread_ended[i], false);
791+
}
792+
}
793+
// Explicitly destroy the channel
794+
delete ch;
795+
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
796+
797+
// Verify that all threads got unblocked
798+
for (size_t i = 0; i < num_threads; i++) {
799+
EXPECT_EQ(thread_ended[i], true);
800+
}
801+
802+
// Count number of successfuld sends
803+
int ct = 0;
804+
for (size_t i = 0; i < num_threads; i++) {
805+
if (send_success[i]) ct++;
806+
}
807+
808+
if (isBuffered) {
809+
// Only 1 send must be successful
810+
EXPECT_EQ(ct, 1);
811+
} else {
812+
// In unbuffered channel, no send should be successful
813+
EXPECT_EQ(ct, 0);
814+
}
815+
816+
// Join all threads
817+
for (size_t i = 0; i < num_threads; i++) t[i].join();
818+
}
819+
820+
// This tests that destroying a channelholder also unblocks
821+
// any receivers waiting on the channel
822+
void ChannelHolderDestroyUnblockReceivers(ChannelHolder *ch) {
823+
size_t num_threads = 5;
824+
std::thread t[num_threads];
825+
bool thread_ended[num_threads];
826+
827+
// Launches threads that try to read and are blocked because of no writers
828+
for (size_t i = 0; i < num_threads; i++) {
829+
thread_ended[i] = false;
830+
t[i] = std::thread(
831+
[&](bool *p) {
832+
int data;
833+
// All reads should return false
834+
EXPECT_EQ(ch->Receive(&data), false);
835+
*p = true;
836+
},
837+
&thread_ended[i]);
838+
}
839+
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
840+
841+
// Verify that all threads are blocked
842+
for (size_t i = 0; i < num_threads; i++) {
843+
EXPECT_EQ(thread_ended[i], false);
844+
}
845+
// delete the channel
846+
delete ch;
847+
std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
848+
// Verify that all threads got unblocked
849+
for (size_t i = 0; i < num_threads; i++) {
850+
EXPECT_EQ(thread_ended[i], true);
851+
}
852+
853+
for (size_t i = 0; i < num_threads; i++) t[i].join();
854+
}
855+
856+
TEST(ChannelHolder, ChannelHolderDestroyUnblocksReceiversTest) {
857+
// Check for Buffered Channel
858+
ChannelHolder *ch = new ChannelHolder();
859+
ch->Reset<int>(1);
860+
ChannelHolderDestroyUnblockReceivers(ch);
861+
// ch is already deleted already deleted in
862+
// ChannelHolderDestroyUnblockReceivers
863+
864+
// Check for Unbuffered channel
865+
ch = new ChannelHolder();
866+
ch->Reset<int>(0);
867+
ChannelHolderDestroyUnblockReceivers(ch);
868+
}
869+
870+
TEST(ChannelHolder, ChannelHolderDestroyUnblocksSendersTest) {
871+
// Check for Buffered Channel
872+
ChannelHolder *ch = new ChannelHolder();
873+
ch->Reset<int>(1);
874+
ChannelHolderDestroyUnblockSenders(ch, true);
875+
// ch is already deleted already deleted in
876+
// ChannelHolderDestroyUnblockReceivers
877+
878+
// Check for Unbuffered channel
879+
ch = new ChannelHolder();
880+
ch->Reset<int>(0);
881+
ChannelHolderDestroyUnblockSenders(ch, false);
882+
}

0 commit comments

Comments
 (0)