Skip to content

Commit d514243

Browse files
authored
[Proton][AMD] Fix peak TB/s and support gfx950 specs (#7175)
Using `2 * bus_width * memory_clock_rate * 1e3 / 8` as the formula cannot deduce the proper max TB/s on AMD devices; the method is more involved on AMD. For now we just hardcode the TB/s result to get correct result and unblock supporting of gfx950.
1 parent 4d791f0 commit d514243

File tree

5 files changed

+58
-15
lines changed

5 files changed

+58
-15
lines changed

python/triton_kernels/bench/bench_mlp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def opint(self):
7070

7171
@property
7272
def max_tbps(self):
73-
return proton.specs.max_bps(self.device_info["bus_width"], self.device_info["memory_clock_rate"]) * 1e-12
73+
return proton.specs.max_bps(self.device_type, self.device_info["arch"], self.device_info["bus_width"],
74+
self.device_info["memory_clock_rate"]) * 1e-12
7475

7576
@property
7677
def max_tflops(self):

third_party/proton/proton/specs.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@
99
(width / 8),
1010
"100":
1111
lambda width, num_sms, clock_rate, **kwargs: (num_sms * 16384 * (clock_rate / 1e3) * 1e6) / (width / 8),
12-
},
13-
"HIP": {
14-
"gfx90a": lambda width, **kwargs: 383e12 / (width / 8),
15-
"gfx942": lambda width, **kwargs: 2614.9e12 / (width / 8),
16-
},
12+
}
1713
}
1814

15+
amd_bps_by_arch = {
16+
'gfx90a': 3.2 * 1e12,
17+
'gfx942': 5.3 * 1e12,
18+
'gfx950': 8.0 * 1e12,
19+
}
20+
21+
# FP8 Matrix Performance(FLOPS/clock/CU)
22+
# For gfx90a we use the performance of INT8 since it doesn't support FP8 matrix operations.
23+
amd_fp8_flops_by_arch = {'gfx90a': 1024, 'gfx942': 4096, 'gfx950': 8192}
24+
1925

2026
def max_flops(device_type, arch, width, num_sms, clock_rate):
2127
"""
@@ -31,6 +37,9 @@ def max_flops(device_type, arch, width, num_sms, clock_rate):
3137
Returns:
3238
float: The maximum FLOPS for the given device type and width.
3339
"""
40+
if device_type == "HIP":
41+
return amd_fp8_flops_by_arch[arch] * num_sms * clock_rate * 1e3 / (width / 8)
42+
3443
if device_type not in flops_by_device:
3544
raise ValueError(f"Unsupported device type: {device_type}")
3645

@@ -42,7 +51,7 @@ def max_flops(device_type, arch, width, num_sms, clock_rate):
4251
return flops_func(width, num_sms=num_sms, clock_rate=clock_rate)
4352

4453

45-
def max_bps(bus_width, memory_clock_rate):
54+
def max_bps(device_type, arch, bus_width, memory_clock_rate):
4655
"""
4756
Calculate the maximum bytes per second for a given bus width and memory clock rate.
4857
@@ -53,4 +62,8 @@ def max_bps(bus_width, memory_clock_rate):
5362
Returns:
5463
float: The maximum bytes per second.
5564
"""
56-
return 2 * bus_width * memory_clock_rate * 1e3 / 8
65+
if device_type == "CUDA":
66+
return 2 * bus_width * memory_clock_rate * 1e3 / 8
67+
else:
68+
assert device_type == "HIP"
69+
return amd_bps_by_arch[arch]

third_party/proton/proton/viewer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ def get_min_time_bytes(df, device_info):
9696
for device_index in device_info[device_type]:
9797
idx = df["device_id"] == device_index
9898
device_frames = df[idx]
99-
memory_clock_rate = device_info[device_type][device_index]["memory_clock_rate"] # in khz
100-
bus_width = device_info[device_type][device_index]["bus_width"] # in bits
101-
peak_bandwidth = specs.max_bps(bus_width, memory_clock_rate)
99+
device = device_info[device_type][device_index]
100+
memory_clock_rate = device["memory_clock_rate"] # in khz
101+
bus_width = device["bus_width"] # in bits
102+
peak_bandwidth = specs.max_bps(device_type, device['arch'], bus_width, memory_clock_rate)
102103
min_time_bytes.loc[idx, "min_time"] += device_frames["bytes"] / peak_bandwidth
103104
return min_time_bytes
104105

third_party/proton/test/examples/hip.json

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@
3030
"flops8": 1e10,
3131
"bytes": 1e7
3232
}
33+
},
34+
{
35+
"children": [],
36+
"frame": {
37+
"name": "foo2",
38+
"type": "function"
39+
},
40+
"metrics": {
41+
"count": 1,
42+
"device_id": "2",
43+
"device_type": "HIP",
44+
"time (ns)": 204800,
45+
"flops8": 1e12,
46+
"bytes": 1e9
47+
}
3348
}
3449
],
3550
"frame": {
@@ -55,9 +70,16 @@
5570
"1": {
5671
"arch": "gfx942",
5772
"bus_width": 8192,
58-
"clock_rate": 5200000,
59-
"memory_clock_rate": 2525000,
73+
"clock_rate": 2100000,
74+
"memory_clock_rate": 1200000,
6075
"num_sms": 304
76+
},
77+
"2": {
78+
"arch": "gfx950",
79+
"bus_width": 8192,
80+
"clock_rate": 2200000,
81+
"memory_clock_rate": 1900000,
82+
"num_sms": 256
6183
}
6284
}
6385
}

third_party/proton/test/test_viewer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,13 @@ def test_min_time_flops():
101101
ret = get_min_time_flops(gf.dataframe, device_info)
102102
device0_idx = gf.dataframe["device_id"] == "0"
103103
device1_idx = gf.dataframe["device_id"] == "1"
104+
device2_idx = gf.dataframe["device_id"] == "2"
104105
# CDNA2
105-
np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000026]], atol=1e-5)
106+
np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[0.000055]], atol=1e-5)
106107
# CDNA3
107108
np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[0.000038]], atol=1e-5)
109+
# CDNA4
110+
np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000217]], atol=1e-5)
108111

109112

110113
def test_min_time_bytes():
@@ -120,10 +123,13 @@ def test_min_time_bytes():
120123
ret = get_min_time_bytes(gf.dataframe, device_info)
121124
device0_idx = gf.dataframe["device_id"] == "0"
122125
device1_idx = gf.dataframe["device_id"] == "1"
126+
device2_idx = gf.dataframe["device_id"] == "2"
123127
# CDNA2
124-
np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[6.10351e-06]], atol=1e-6)
128+
np.testing.assert_allclose(ret[device0_idx].to_numpy(), [[3.125e-06]], atol=1e-6)
125129
# CDNA3
126130
np.testing.assert_allclose(ret[device1_idx].to_numpy(), [[1.93378e-05]], atol=1e-6)
131+
# CDNA4
132+
np.testing.assert_allclose(ret[device2_idx].to_numpy(), [[0.000125]], atol=1e-6)
127133

128134

129135
def test_percentage():

0 commit comments

Comments
 (0)