@@ -114,7 +114,7 @@ void *testAllReduceThread(void *arg) {
114114 TEST_INFINI_THREAD (infinirtMalloc (&buf, args->count * infiniSizeOf (args->dtype )));
115115 TEST_INFINI_THREAD (infinirtMemcpy (buf, args->data , args->count * infiniSizeOf (args->dtype ), INFINIRT_MEMCPY_H2D));
116116 TEST_INFINI_THREAD (infinicclAllReduce (buf, buf, args->count , args->dtype , INFINICCL_SUM, args->comm , stream));
117- TEST_INFINI_THREAD (infinirtDeviceSynchronize ( ));
117+ TEST_INFINI_THREAD (infinirtStreamSynchronize (stream ));
118118 TEST_INFINI_THREAD (infinirtMemcpy (output, buf, args->count * infiniSizeOf (args->dtype ), INFINIRT_MEMCPY_D2H));
119119
120120 if (checkData (output, args->ans , args->dtype , args->count ) != 0 ) {
@@ -126,14 +126,14 @@ void *testAllReduceThread(void *arg) {
126126 for (size_t i = 0 ; i < WARM_UPS; i++) {
127127 TEST_INFINI_THREAD (infinicclAllReduce (buf, buf, args->count , args->dtype , INFINICCL_SUM, args->comm , stream));
128128 }
129- TEST_INFINI_THREAD (infinirtDeviceSynchronize ( ));
129+ TEST_INFINI_THREAD (infinirtStreamSynchronize (stream ));
130130
131131 // measure time
132132 auto start = std::chrono::high_resolution_clock::now ();
133133 for (size_t i = 0 ; i < ITERATIONS; i++) {
134134 TEST_INFINI_THREAD (infinicclAllReduce (buf, buf, args->count , args->dtype , INFINICCL_SUM, args->comm , stream));
135135 }
136- TEST_INFINI_THREAD (infinirtDeviceSynchronize ( ));
136+ TEST_INFINI_THREAD (infinirtStreamSynchronize (stream ));
137137 auto end = std::chrono::high_resolution_clock::now ();
138138 double elapsed_ms = std::chrono::duration<double , std::milli>(end - start).count ();
139139 *args->time = elapsed_ms / ITERATIONS;
@@ -159,12 +159,12 @@ int testAllReduce(infiniDevice_t device_type, int ndevice) {
159159 for (int i = 0 ; i < ndevice; i++) {
160160 device_ids[i] = i;
161161 }
162- TEST_INFINI (infinicclCommInitAll (device_type, comms.data (), ndevice, device_ids.data ()));
163162
164163 for (infiniDtype_t dtype : TEST_DTYPES) {
165164 setData (dtype, data, MAX_COUNT, 1.0f );
166165 setData (dtype, ans, MAX_COUNT, 1.0f * ndevice);
167166 for (size_t count : TEST_COUNTS) {
167+ TEST_INFINI (infinicclCommInitAll (device_type, comms.data (), ndevice, device_ids.data ()));
168168 std::cout << " Testing AllReduce with " << count << " elements of " << infiniDtypeToString (dtype) << std::endl;
169169 for (int rank = 0 ; rank < ndevice; rank++) {
170170 thread_args[rank] = {rank, device_ids[rank], comms[rank], device_type, dtype, count, data, ans, &results[rank], &times[rank]};
0 commit comments