Skip to content

Commit a70d585

Browse files
authored
[AMD] Turn on stream pipeline v2 as the default (#4665)
This commit turns on the v2 pipeliner as the default. We still keep v1 for an extended time to make performance debugging easier, but we expect to remove it soon.
1 parent 4348109 commit a70d585

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

python/test/unit/language/test_core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5266,11 +5266,13 @@ def matmul_kernel( #
52665266
@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15'])
52675267
@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128])
52685268
def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device):
5269+
num_stages = 3
52695270
if is_cuda():
52705271
cc = torch.cuda.get_device_capability()
52715272
if cc[0] >= 9 and in_type_str == "float8e4b15":
52725273
pytest.skip("Dot op does not support fp8e4b15 on CUDA arch >= 90")
52735274
elif is_hip():
5275+
num_stages = 2
52745276
if in_type_str != 'float8e5':
52755277
pytest.skip('test_fp8_dot_acc for HIP currently broken in upstream.')
52765278

@@ -5284,7 +5286,8 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
52845286
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
52855287
max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
52865288
h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
5287-
C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps)
5289+
C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
5290+
num_pipeline_stages=num_stages)
52885291
torch_a = torch.from_numpy(A).to(device=device)
52895292
th_a = f8_to_f16(torch_a, in_type_str)
52905293
torch_b = torch.from_numpy(B).to(device=device)

python/tutorials/03-matrix-multiplication.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,19 +206,19 @@ def get_hip_autotune_config():
206206
return [
207207
triton.Config(
208208
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
209-
num_warps=4, num_stages=0),
209+
num_warps=4, num_stages=2),
210210
triton.Config(
211211
{'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
212-
num_warps=8, num_stages=0),
212+
num_warps=8, num_stages=2),
213213
triton.Config(
214214
{'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
215-
num_warps=8, num_stages=0),
215+
num_warps=8, num_stages=2),
216216
triton.Config(
217217
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
218-
num_warps=4, num_stages=0),
218+
num_warps=4, num_stages=2),
219219
triton.Config(
220220
{'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
221-
num_warps=4, num_stages=0),
221+
num_warps=4, num_stages=2),
222222
]
223223

224224

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def make_ttgir(mod, metadata, options):
169169
passes.ttgpuir.add_remove_layout_conversions(pm)
170170
amd.passes.ttgpuir.add_optimize_epilogue(pm)
171171
passes.ttgpuir.add_optimize_dot_operands(pm, True)
172-
use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "0") == "1"
172+
use_new_pipeliner = os.getenv("TRITON_HIP_USE_NEW_STREAM_PIPELINE", "1") == "1"
173173
if amd.has_matrix_core_feature(options.arch):
174174
if use_new_pipeliner:
175175
# In the old pipeliner we only support num_stages = 0/1, which means something

0 commit comments

Comments (0)