Skip to content

Commit 95ebc58

Browse files
authored
Merge pull request #10729 from edgargabriel/topic/gfx-arch-cleanup
UCT/ROCM: cleanup GPU detection code
2 parents 3bd5e22 + a0ebcd2 commit 95ebc58

File tree

3 files changed

+18
-37
lines changed

3 files changed

+18
-37
lines changed

src/uct/rocm/base/rocm_base.c

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -511,46 +511,28 @@ ucs_status_t uct_rocm_base_get_link_type(hsa_amd_link_info_type_t *link_type)
511511
uct_rocm_amd_gpu_product_t uct_rocm_base_get_gpu_product(void)
512512
{
513513
uct_rocm_amd_gpu_product_t gpu_product = UCT_ROCM_AMD_GPU_MI200;
514-
char product_name[64];
515514
char gfx_name[64];
516515
hsa_status_t status;
517516

518-
/* fetching data from GPU 0, assuming all GPUs on a node are
519-
identical */
517+
/* Query the gfx architecture name */
520518
status = hsa_agent_get_info(uct_rocm_base_agents.gpu_agents[0],
521-
(hsa_agent_info_t)
522-
HSA_AMD_AGENT_INFO_PRODUCT_NAME,
523-
(void*)product_name);
519+
(hsa_agent_info_t)HSA_AGENT_INFO_NAME,
520+
gfx_name);
524521
if (status != HSA_STATUS_SUCCESS) {
525-
ucs_debug("Error in hsa_agent_info %d", status);
522+
ucs_debug("hsa_agent_get_info failed: %d", status);
526523
return gpu_product;
527524
}
528525

529-
if (NULL != strstr(product_name, "MI300A")) {
530-
gpu_product = UCT_ROCM_AMD_GPU_MI300A;
531-
ucs_trace("found MI300A GPU");
532-
} else if (NULL != strstr(product_name, "MI300X")) {
533-
gpu_product = UCT_ROCM_AMD_GPU_MI300X;
534-
ucs_trace("found MI300X GPU");
526+
if (NULL != strstr(gfx_name, "gfx94")) {
527+
/* This is an MI300/325 GPU */
528+
gpu_product = UCT_ROCM_AMD_GPU_MI300;
529+
ucs_trace("found gfx94* GPU");
530+
} else if (NULL != strstr(gfx_name, "gfx95")) {
531+
/* This is an MI35x GPU */
532+
gpu_product = UCT_ROCM_AMD_GPU_MI350;
533+
ucs_trace("found gfx950 GPU");
535534
} else {
536-
/* In case product_name is not set correctly, query the gfx
537-
architecture name */
538-
status = hsa_agent_get_info(uct_rocm_base_agents.gpu_agents[0],
539-
(hsa_agent_info_t)HSA_AGENT_INFO_NAME,
540-
(void*)gfx_name);
541-
if (status != HSA_STATUS_SUCCESS) {
542-
ucs_debug("error in hsa_agent_info %d", status);
543-
return gpu_product;
544-
}
545-
546-
if (NULL != strstr(gfx_name, "gfx94")) {
547-
/* This is an MI300 GPU, but cannot say whether its the A or X
548-
variant. Assuming A variant for now*/
549-
gpu_product = UCT_ROCM_AMD_GPU_MI300A;
550-
ucs_trace("found gfx94* GPU, assuming MI300A");
551-
} else {
552-
ucs_trace("assuming MI100/MI200 GPU");
553-
}
535+
ucs_trace("assuming MI100/MI200 GPU");
554536
}
555537

556538
return gpu_product;

src/uct/rocm/base/rocm_base.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
typedef enum uct_rocm_amd_gpu_product {
1717
UCT_ROCM_AMD_GPU_UNDEFINED = -1,
1818
UCT_ROCM_AMD_GPU_MI200,
19-
UCT_ROCM_AMD_GPU_MI300A,
20-
UCT_ROCM_AMD_GPU_MI300X
19+
UCT_ROCM_AMD_GPU_MI300,
20+
UCT_ROCM_AMD_GPU_MI350,
2121
} uct_rocm_amd_gpu_product_t;
2222

2323
hsa_status_t uct_rocm_base_init(void);

src/uct/rocm/copy/rocm_copy_iface.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ uct_rocm_copy_get_default_bandwidth(const uct_perf_attr_t *perf_attr)
212212
}
213213

214214
static UCS_F_ALWAYS_INLINE double
215-
uct_rocm_copy_get_mi300a_bandwidth(const uct_perf_attr_t *perf_attr)
215+
uct_rocm_copy_get_mi300_bandwidth(const uct_perf_attr_t *perf_attr)
216216
{
217217
switch (perf_attr->operation) {
218218
case UCT_EP_OP_GET_SHORT:
@@ -250,9 +250,8 @@ uct_rocm_copy_estimate_perf(uct_iface_h tl_iface, uct_perf_attr_t *perf_attr)
250250
if (!(perf_attr->field_mask & UCT_PERF_ATTR_FIELD_OPERATION)) {
251251
bandwidth.shared = 0;
252252
} else {
253-
if (gpu_product == UCT_ROCM_AMD_GPU_MI300A ||
254-
gpu_product == UCT_ROCM_AMD_GPU_MI300X) {
255-
bandwidth.shared = uct_rocm_copy_get_mi300a_bandwidth(perf_attr);
253+
if (gpu_product == UCT_ROCM_AMD_GPU_MI300) {
254+
bandwidth.shared = uct_rocm_copy_get_mi300_bandwidth(perf_attr);
256255
} else {
257256
bandwidth.shared = uct_rocm_copy_get_default_bandwidth(perf_attr);
258257
}

0 commit comments

Comments
 (0)