Skip to content

Commit 4b349f0

Browse files
authored
add xpu_arch_to_mem_type_multiplier
1 parent ae15596 commit 4b349f0

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

third_party/proton/proton/specs.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,12 @@
18 18
'gfx950': 8.0 * 1e12,
19 19
}
20 20

21+
xpu_arch_to_mem_type_multiplier = {
22+
"pvc": 2,
23+
"dg2": 8,
24+
"bmg": 8,
25+
}
26+
21 27
# FP8 Matrix Performance(FLOPS/clock/CU)
22 28
# For gfx90a we use the performance of INT8 since it doesn't support FP8 matrix operations.
23 29
amd_fp8_flops_by_arch = {'gfx90a': 1024, 'gfx942': 4096, 'gfx950': 8192}
@@ -67,11 +73,5 @@ def max_bps(device_type, arch, bus_width, memory_clock_rate):
67 73
elif device_type == "HIP":
68 74
return amd_bps_by_arch[arch]
69 75
else:
70-
assert device_type == "XPU"
71-
if arch == "Xe-HPC":
72-
multiplier = 2
73-
elif arch == "Xe2" or arch == "Xe-HPG":
74-
multiplier = 8
75-
else:
76-
raise ValueError(f"Unsupported architecture: {arch}")
77-
return multiplier * bus_width * memory_clock_rate * 1e3 / 8
76+
assert device_type == "XPU"
77+
return xpu_arch_to_mem_type_multiplier[arch] * bus_width * memory_clock_rate * 1e3 / 8

0 commit comments

Comments
 (0)