@@ -39,7 +39,35 @@ rocshmem_team_t team_world_dup[NUM_TEAMS];
3939/* *****************************************************************************
4040 * DEVICE TEST KERNEL
4141 *****************************************************************************/
42- __global__ void TeamCtxInfraTest (ShmemContextType ctx_type,
42+ __global__ void TeamCtxInfraSimpleTest (ShmemContextType ctx_type,
43+ rocshmem_team_t team,
44+ int expected_pe, int expected_n_pes) {
45+ __shared__ rocshmem_ctx_t ctx;
46+
47+ rocshmem_wg_init ();
48+ rocshmem_wg_team_create_ctx (team, ctx_type, &ctx);
49+
50+ int num_pes = rocshmem_ctx_n_pes (ctx);
51+ int my_pe = rocshmem_ctx_my_pe (ctx);
52+
53+ if (my_pe != expected_pe) {
54+ printf (" PE doesn't match. Expected %d got %d\n " , expected_pe, my_pe);
55+ abort ();
56+ }
57+
58+ if (num_pes != expected_n_pes) {
59+ printf (" Team size doesn't match. Expected %d got %d\n " , expected_n_pes, num_pes);
60+ abort ();
61+ }
62+
63+ __syncthreads ();
64+
65+ rocshmem_ctx_quiet (ctx);
66+ rocshmem_wg_ctx_destroy (&ctx);
67+ rocshmem_wg_finalize ();
68+ }
69+
70+ __global__ void TeamCtxInfraTest (ShmemContextType ctx_type,
4371 rocshmem_team_t *team) {
4472 __shared__ rocshmem_ctx_t ctx1, ctx2, ctx3;
4573 __shared__ rocshmem_ctx_t ctx[NUM_TEAMS];
@@ -109,42 +137,105 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type,
109137/* *****************************************************************************
110138 * HOST TESTER CLASS METHODS
111139 *****************************************************************************/
112- TeamCtxInfraTester::TeamCtxInfraTester (TesterArguments args) : Tester(args) {}
140+ TeamCtxInfraTester::TeamCtxInfraTester (TesterArguments args) : Tester(args) {
141+ _splitType = args.team_type ;
142+ }
113143
114144TeamCtxInfraTester::~TeamCtxInfraTester () {}
115145
116146void TeamCtxInfraTester::resetBuffers (size_t size) {}
117147
118148void TeamCtxInfraTester::preLaunchKernel () {
119- int n_pes = rocshmem_team_n_pes (ROCSHMEM_TEAM_WORLD);
120-
121- // validate we can run the test
122- if (auto maximum_num_contexts_str = getenv (" ROCSHMEM_MAX_NUM_CONTEXTS" )) {
123- int max_ctx = atoi (maximum_num_contexts_str);
124- if (max_ctx <= NUM_TEAMS) {
125- printf (" ROCSHMEM_MAX_NUM_CONTEXTS=%d is smaller than NUM_TEAMS %d, invalid test setup!\n " , max_ctx, NUM_TEAMS);
126- assert (max_ctx > NUM_TEAMS);
149+ int n_pes = rocshmem_team_n_pes (_parentTeam);
150+ int my_pe = rocshmem_team_my_pe (_parentTeam);
151+
152+ if (_splitType == ROCSHMEM_TEST_TEAM_DUP) {
153+ // validate we can run the test
154+ if (auto maximum_num_contexts_str = getenv (" ROCSHMEM_MAX_NUM_CONTEXTS" )) {
155+ int max_ctx = atoi (maximum_num_contexts_str);
156+ if (max_ctx <= NUM_TEAMS) {
157+ printf (" ROCSHMEM_MAX_NUM_CONTEXTS=%d is smaller than NUM_TEAMS %d, invalid test setup!\n " , max_ctx, NUM_TEAMS);
158+ assert (max_ctx > NUM_TEAMS);
159+ abort ();
160+ }
161+ }
162+
163+ for (int team_i = 0 ; team_i < NUM_TEAMS; team_i++) {
164+ team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
165+ rocshmem_team_split_strided (_parentTeam, 0 , 1 , n_pes, nullptr , 0 ,
166+ &team_world_dup[team_i]);
167+ if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
168+ printf (" Created team %d is invalid!\n " , team_i);
169+ abort ();
170+ }
171+ }
172+
173+ /* Assert the failure of a new team creation. */
174+ rocshmem_team_t new_team = ROCSHMEM_TEAM_INVALID;
175+ rocshmem_team_split_strided (_parentTeam, 0 , 1 , n_pes, nullptr , 0 ,
176+ &new_team);
177+ if (new_team != ROCSHMEM_TEAM_INVALID) {
178+ printf (" Created new team should have been invalid!\n " );
127179 abort ();
128180 }
129181 }
182+ else if (_splitType == ROCSHMEM_TEST_TEAM_SINGLE) {
183+ rocshmem_team_split_strided (_parentTeam, my_pe, 1 , 1 , nullptr , 0 ,
184+ &team_world_dup[0 ]);
185+ _expected_pe = rocshmem_team_my_pe (team_world_dup[0 ]);
186+ _expected_n_pes = rocshmem_team_n_pes (team_world_dup[0 ]);
187+
188+ if (_expected_n_pes != 1 ) {
189+ printf (" ROCSHMEM_TEST_TEAM_SINGLE: n_pes %d expected: 1\n " , _expected_n_pes);
190+ abort ();
191+ }
130192
131- for (int team_i = 0 ; team_i < NUM_TEAMS; team_i++) {
132- team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID;
133- rocshmem_team_split_strided (ROCSHMEM_TEAM_WORLD, 0 , 1 , n_pes, nullptr , 0 ,
134- &team_world_dup[team_i]);
135- if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) {
136- printf (" Created team %d is invalid!\n " , team_i);
193+ if (_expected_pe != 0 ) {
194+ printf (" ROCSHMEM_TEST_TEAM_SINGLE: my_pe %d expected: 0\n " , _expected_pe);
195+ abort ();
196+ }
197+ } else if (_splitType == ROCSHMEM_TEST_TEAM_BLOCK) {
198+ int mid_pe = n_pes / 2 ; // integer division
199+ int start_pe = my_pe < mid_pe ? 0 : mid_pe;
200+ int end_pe = my_pe < mid_pe ? (mid_pe - 1 ) : (n_pes - 1 );
201+ int num_pes = end_pe - start_pe + 1 ;
202+ int new_pe = my_pe < mid_pe ? my_pe : (my_pe - start_pe);
203+
204+ rocshmem_team_split_strided (_parentTeam, start_pe, 1 , num_pes, nullptr , 0 ,
205+ &team_world_dup[0 ]);
206+ _expected_pe = rocshmem_team_my_pe (team_world_dup[0 ]);
207+ _expected_n_pes = rocshmem_team_n_pes (team_world_dup[0 ]);
208+
209+ if (_expected_n_pes != num_pes) {
210+ printf (" ROCSHMEM_TEST_TEAM_BLOCK: n_pes %d expected: %d\n " , _expected_n_pes, num_pes);
137211 abort ();
138212 }
139- }
140213
141- /* Assert the failure of a new team creation. */
142- rocshmem_team_t new_team = ROCSHMEM_TEAM_INVALID;
143- rocshmem_team_split_strided (ROCSHMEM_TEAM_WORLD, 0 , 1 , n_pes, nullptr , 0 ,
144- &new_team);
145- if (new_team != ROCSHMEM_TEAM_INVALID) {
146- printf (" Created new team should have been invalid!\n " );
147- abort ();
214+ if (_expected_pe != new_pe) {
215+ printf (" ROCSHMEM_TEST_TEAM_BLOCK: my_pe %d expected: %d\n " , _expected_pe, new_pe);
216+ abort ();
217+ }
218+ } else if (_splitType == ROCSHMEM_TEST_TEAM_ODDEVEN) {
219+ int start_pe = (my_pe % 2 ) == 0 ? 0 : 1 ;
220+ int num_pes = n_pes / 2 ;
221+ if (((n_pes % 2 ) != 0 ) && ((my_pe % 2 ) == 0 ))
222+ num_pes++;
223+ int new_pe = (my_pe / 2 );
224+
225+ rocshmem_team_split_strided (_parentTeam, start_pe, 2 , num_pes, nullptr , 0 ,
226+ &team_world_dup[0 ]);
227+ _expected_pe = rocshmem_team_my_pe (team_world_dup[0 ]);
228+ _expected_n_pes = rocshmem_team_n_pes (team_world_dup[0 ]);
229+
230+ if (_expected_n_pes != num_pes) {
231+ printf (" ROCSHMEM_TEST_TEAM_ODDEVEN: n_pes %d expected: %d\n " , _expected_n_pes, num_pes);
232+ abort ();
233+ }
234+
235+ if (_expected_pe != new_pe) {
236+ printf (" ROCSHMEM_TEST_TEAM_ODDEVEN: my_pe %d expected: %d\n " , _expected_pe, new_pe);
237+ abort ();
238+ }
148239 }
149240}
150241
@@ -154,18 +245,31 @@ void TeamCtxInfraTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop,
154245
155246 /* Copy array of teams to device */
156247 rocshmem_team_t *teams_on_device;
157- CHECK_HIP (hipMalloc (&teams_on_device, sizeof (rocshmem_team_t ) * NUM_TEAMS));
158- CHECK_HIP (hipMemcpy (teams_on_device, team_world_dup,
159- sizeof (rocshmem_team_t ) * NUM_TEAMS, hipMemcpyHostToDevice));
160248
161- hipLaunchKernelGGL (TeamCtxInfraTest, gridSize, blockSize, shared_bytes,
162- stream, _shmem_context, teams_on_device);
249+ if (_splitType == ROCSHMEM_TEST_TEAM_DUP) {
250+ CHECK_HIP (hipMalloc (&teams_on_device, sizeof (rocshmem_team_t ) * NUM_TEAMS));
251+ CHECK_HIP (hipMemcpy (teams_on_device, team_world_dup,
252+ sizeof (rocshmem_team_t ) * NUM_TEAMS, hipMemcpyHostToDevice));
253+
254+ hipLaunchKernelGGL (TeamCtxInfraTest, gridSize, blockSize, shared_bytes,
255+ stream, _shmem_context, teams_on_device);
256+ } else if (_splitType == ROCSHMEM_TEST_TEAM_SINGLE ||
257+ _splitType == ROCSHMEM_TEST_TEAM_BLOCK ||
258+ _splitType == ROCSHMEM_TEST_TEAM_ODDEVEN ) {
259+ CHECK_HIP (hipMalloc (&teams_on_device, sizeof (rocshmem_team_t )));
260+ CHECK_HIP (hipMemcpy (teams_on_device, team_world_dup,
261+ sizeof (rocshmem_team_t ), hipMemcpyHostToDevice));
262+
263+ hipLaunchKernelGGL (TeamCtxInfraSimpleTest, gridSize, blockSize, shared_bytes,
264+ stream, _shmem_context, teams_on_device[0 ], _expected_pe, _expected_n_pes);
265+ }
163266
164267 CHECK_HIP (hipFree (teams_on_device));
165268}
166269
167270void TeamCtxInfraTester::postLaunchKernel () {
168- for (int team_i = 0 ; team_i < NUM_TEAMS; team_i++) {
271+ int num_teams = _splitType == ROCSHMEM_TEST_TEAM_DUP ? NUM_TEAMS : 1 ;
272+ for (int team_i = 0 ; team_i < num_teams; team_i++) {
169273 rocshmem_team_destroy (team_world_dup[team_i]);
170274 }
171275}
0 commit comments