
Commit c943eb7

Merge commit 'b1301d66d19ad2244d24861505e176ef58bf8609'
2 parents: 044b64c + b1301d6

File tree: 5 files changed, +100 -18 lines


lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 48 additions & 6 deletions

@@ -2009,9 +2009,30 @@ AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
 
 SmallVector<unsigned>
 AMDMfmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
-  llvm::report_fatal_error(
-      "getThreadsPerWarpForOperand not implemented for AMDMfmaEncodingAttr");
-  return {};
+  auto rank = ::getOrder(*this).size();
+  SmallVector<unsigned> threads(rank, 1);
+  unsigned kThreads;
+  unsigned nonKThreads;
+  switch (getMDim()) {
+  case 32:
+    assert(getNDim() == 32);
+    kThreads = 2;
+    nonKThreads = 32;
+    break;
+  case 16:
+    assert(getNDim() == 16);
+    kThreads = 4;
+    nonKThreads = 16;
+    break;
+  default:
+    llvm::report_fatal_error(
+        "unexpected mfma shape encountered in getThreadsPerWarpForOperand");
+  }
+  int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2;
+  int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1;
+  threads[kDimIdx] = kThreads;
+  threads[nonKDimIdx] = nonKThreads;
+  return threads;
 }
 
 SmallVector<int64_t>
@@ -2077,9 +2098,30 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
 
 SmallVector<unsigned>
 AMDWmmaEncodingAttr::getThreadsPerWarpForOperand(int opIdx) const {
-  llvm::report_fatal_error("getThreadsPerWarpForOperand not implemented for "
-                           "AMDWmmaEncodingAttr");
-  return {};
+  auto rank = ::getOrder(*this).size();
+  SmallVector<unsigned> threads(rank, 1);
+  unsigned kThreads;
+  unsigned nonKThreads;
+  switch (getVersion()) {
+  case 1:
+    // kThreads * nonKThreads != 32,
+    // because values in lanes (n, n + 16) are duplicated
+    kThreads = 1;
+    nonKThreads = 16;
+    break;
+  case 2:
+    kThreads = 2;
+    nonKThreads = 16;
+    break;
+  default:
+    llvm::report_fatal_error(
+        "unsupported WMMA version in getThreadsPerWarpForOperand");
+  }
+  int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2;
+  int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1;
+  threads[kDimIdx] = kThreads;
+  threads[nonKDimIdx] = nonKThreads;
+  return threads;
 }
 
 SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
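The layout rule here: for operand A (opIdx 0) the K dimension is the innermost dimension, for operand B (opIdx 1) it is the next-to-innermost, and kThreads * nonKThreads spans the wave (2*32 and 4*16 = 64 lanes for MFMA, 2*16 = 32 lanes for WMMA v2, and only 16 for WMMA v1 because lanes n and n + 16 hold duplicated values). A minimal Python sketch of the same computation; the helper name is illustrative only and not part of the change:

def threads_per_warp_for_operand(m_dim, n_dim, op_idx, rank, wmma_version=None):
    """Illustrative mirror of the per-operand thread layout computed above."""
    threads = [1] * rank
    if wmma_version is None:  # MFMA path, keyed on the instruction shape
        if (m_dim, n_dim) == (32, 32):
            k_threads, non_k_threads = 2, 32
        elif (m_dim, n_dim) == (16, 16):
            k_threads, non_k_threads = 4, 16
        else:
            raise ValueError("unexpected mfma shape")
    else:  # WMMA path, keyed on the instruction version
        k_threads, non_k_threads = (1, 16) if wmma_version == 1 else (2, 16)
    k_dim = rank - 1 if op_idx == 0 else rank - 2
    non_k_dim = rank - 2 if op_idx == 0 else rank - 1
    threads[k_dim] = k_threads
    threads[non_k_dim] = non_k_threads
    return threads

# Matches the unit-test expectations added below, e.g.:
assert threads_per_warp_for_operand(32, 32, 0, 2) == [32, 2]
assert threads_per_warp_for_operand(16, 16, 1, 3) == [1, 4, 16]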

python/test/unit/language/test_core.py

Lines changed: 0 additions & 1 deletion

@@ -1773,7 +1773,6 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr):
 
 
 @pytest.mark.interpreter
-@pytest.mark.skipif(is_hip(), reason="Not implemented for AMD At this moment")
 def test_load_scope_sem_coop_grid_cta_one(device):
 
     @triton.jit

third_party/amd/backend/compiler.py

Lines changed: 0 additions & 2 deletions

@@ -52,8 +52,6 @@ class HIPOptions:
     default_dot_input_precision: str = "ieee"
     allowed_dot_input_precisions: Tuple[str] = ("ieee", )
     enable_fp_fusion: bool = True
-    # TODO: Implement cooperative grid launch for AMD:
-    # See: https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
     launch_cooperative_grid: bool = False
     matrix_instr_nonkdim: int = 0
     kpack: int = 1
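With the TODO gone, launch_cooperative_grid behaves like any other HIP option. A hedged usage sketch, assuming the flag is accepted as a launch-time keyword argument and forwarded to the backend options; the kernel and tensors are illustrative only:

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs))

src = torch.arange(128, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
# Assumption: launch_cooperative_grid=True requests a cooperative grid launch
# on the AMD backend, now that the driver implements it.
copy_kernel[(1, )](src, dst, BLOCK=128, launch_cooperative_grid=True)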

third_party/amd/backend/driver.py

Lines changed: 20 additions & 7 deletions

@@ -225,7 +225,7 @@ def format_of(ty):
        }[ty_to_cpp(ty)]
 
     args_format = ''.join([format_of(ty) for ty in signature.values()])
-    format = "iiiKKOOOO" + args_format
+    format = "piiiKKOOOO" + args_format
     signature = ','.join(map(_serialize_signature, signature.values()))
     signature = list(filter(bool, signature.split(',')))
     signature = {i: s for i, s in enumerate(signature)}
@@ -267,6 +267,12 @@ def format_of(ty):
                   unsigned int blockDimY, unsigned int blockDimZ, \\
                   unsigned int sharedMemBytes, hipStream_t stream, \\
                   void **kernelParams, void **extra) \\
+  FOR_EACH_ERR_FN(hipModuleLaunchCooperativeKernel, hipFunction_t f, \\
+                  unsigned int gridDimX, unsigned int gridDimY, \\
+                  unsigned int gridDimZ, unsigned int blockDimX, \\
+                  unsigned int blockDimY, unsigned int blockDimZ, \\
+                  unsigned int sharedMemBytes, hipStream_t stream, \\
+                  void **kernelParams, void **extra) \\
   FOR_EACH_ERR_FN(hipPointerGetAttribute, void *data, \\
                   hipPointer_attribute attribute, hipDeviceptr_t ptr)
 
@@ -338,14 +344,18 @@ def format_of(ty):
 
 #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
 
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   // printf("_launch hip kernel\\n");
   hipDeviceptr_t global_scratch = 0;
   void *params[] = {{ {', '.join(params)} }};
+  if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
+    HIP_CHECK(hipSymbolTable.hipModuleLaunchCooperativeKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
+    return;
+  }}
   if (gridX*gridY*gridZ > 0) {{
-    HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
-  }}
+    HIP_CHECK(hipSymbolTable.hipModuleLaunchKernel(function, gridX, gridY, gridZ, {warp_size}*num_warps, 1, 1, shared_memory, stream, params, 0));
   }}
+}}
 
 typedef struct _DevicePtrInfo {{
   hipDeviceptr_t dev_ptr;
@@ -398,12 +408,14 @@ def format_of(ty):
   int gridX, gridY, gridZ;
   uint64_t _stream;
   uint64_t _function;
+  int launch_cooperative_grid;
  PyObject *launch_enter_hook = NULL;
  PyObject *launch_exit_hook = NULL;
  PyObject *kernel_metadata = NULL;
  PyObject *launch_metadata = NULL;
  {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
-  if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &_stream, &_function,
+  if(!PyArg_ParseTuple(args, \"{format}\", &launch_cooperative_grid,
+                       &gridX, &gridY, &gridZ, &_stream, &_function,
                        &kernel_metadata, &launch_metadata,
                        &launch_enter_hook, &launch_exit_hook {args_list})) {{
     return NULL;
@@ -426,7 +438,7 @@ def format_of(ty):
 
   // raise exception asap
   {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
-  _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
+  _launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
 
   if(launch_exit_hook != Py_None){{
     PyObject* args = Py_BuildValue("(O)", launch_metadata);
@@ -482,9 +494,10 @@ def __init__(self, src, metadata):
         src = make_launcher(constants, signature, metadata.warp_size)
         mod = compile_module_from_src(src, "__triton_launcher")
         self.launch = mod.launch
+        self.launch_cooperative_grid = metadata.launch_cooperative_grid
 
     def __call__(self, *args):
-        self.launch(*args)
+        self.launch(self.launch_cooperative_grid, *args)
 
 
 class HIPDriver(GPUDriver):
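Taken together, the launcher now receives the cooperative flag as its first argument (the extra leading "p" in the parse format) and dispatches on it. A rough Python sketch of that control flow, using stand-in functions rather than the real HIP symbols:

def _cooperative_launch(*args):  # stand-in for hipModuleLaunchCooperativeKernel
    print("cooperative launch", args)

def _regular_launch(*args):  # stand-in for hipModuleLaunchKernel
    print("regular launch", args)

def launch(launch_cooperative_grid, grid_x, grid_y, grid_z, *rest):
    """Mirror of the generated C launcher's dispatch logic."""
    if grid_x * grid_y * grid_z <= 0:
        return  # nothing to launch
    if launch_cooperative_grid:
        _cooperative_launch(grid_x, grid_y, grid_z, *rest)
    else:
        _regular_launch(grid_x, grid_y, grid_z, *rest)

# The Python wrapper simply prepends the flag it stored at build time:
class LauncherSketch:
    def __init__(self, launch_cooperative_grid):
        self.launch_cooperative_grid = launch_cooperative_grid
        self.launch = launch

    def __call__(self, *args):
        self.launch(self.launch_cooperative_grid, *args)

LauncherSketch(True)(4, 1, 1)  # takes the cooperative path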

unittest/Dialect/TritonGPU/DialectTest.cpp

Lines changed: 32 additions & 2 deletions

@@ -368,6 +368,8 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
   auto dot2dOp1 = createDotOperand(1, mfma2d, 4);
   ASSERT_THAT(dot2dOp0.getWarpOrder(), mfma2d.getWarpOrder());
   ASSERT_THAT(dot2dOp1.getWarpOrder(), mfma2d.getWarpOrder());
+  ASSERT_THAT(dot2dOp0.getThreadsPerWarp(), testing::ElementsAre(32u, 2u));
+  ASSERT_THAT(dot2dOp1.getThreadsPerWarp(), testing::ElementsAre(2u, 32u));
 
   auto tmfma2d = createTransposedMFMA(32, 32, {2, 4});
   auto tdot2dOp0 = createDotOperand(0, tmfma2d, 4);
@@ -380,12 +382,28 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) {
   auto dot3dOp1 = createDotOperand(1, mfma3d, 4);
   ASSERT_THAT(dot3dOp0.getWarpOrder(), mfma3d.getWarpOrder());
   ASSERT_THAT(dot3dOp1.getWarpOrder(), mfma3d.getWarpOrder());
+  ASSERT_THAT(dot3dOp0.getThreadsPerWarp(), testing::ElementsAre(1u, 32u, 2u));
+  ASSERT_THAT(dot3dOp1.getThreadsPerWarp(), testing::ElementsAre(1u, 2u, 32u));
 
   auto tmfma3d = createTransposedMFMA(32, 32, {2, 4, 1});
   auto tdot3dOp0 = createDotOperand(0, tmfma3d, 4);
   auto tdot3dOp1 = createDotOperand(1, tmfma3d, 4);
   ASSERT_THAT(tdot3dOp0.getWarpOrder(), tmfma3d.getWarpOrder());
   ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder());
+
+  auto mfma16_2d = createMFMA(16, 16, {2, 4});
+  auto dot16_2dOp0 = createDotOperand(0, mfma16_2d, 4);
+  auto dot16_2dOp1 = createDotOperand(1, mfma16_2d, 4);
+  ASSERT_THAT(dot16_2dOp0.getThreadsPerWarp(), testing::ElementsAre(16u, 4u));
+  ASSERT_THAT(dot16_2dOp1.getThreadsPerWarp(), testing::ElementsAre(4u, 16u));
+
+  auto mfma16_3d = createMFMA(16, 16, {2, 4, 1});
+  auto dot16_3dOp0 = createDotOperand(0, mfma16_3d, 4);
+  auto dot16_3dOp1 = createDotOperand(1, mfma16_3d, 4);
+  ASSERT_THAT(dot16_3dOp0.getThreadsPerWarp(),
+              testing::ElementsAre(1u, 16u, 4u));
+  ASSERT_THAT(dot16_3dOp1.getThreadsPerWarp(),
+              testing::ElementsAre(1u, 4u, 16u));
 }
 
 TEST_F(AMDWmmaLayoutTest, wmmaV1) {
@@ -434,24 +452,36 @@ TEST_F(AMDWmmaLayoutTest, wmma_dot_op) {
   auto dot2dVer1Op1 = createDotOperand(1, wmma2dVer1, 16);
   ASSERT_THAT(dot2dVer1Op0.getWarpOrder(), wmma2dVer1.getWarpOrder());
   ASSERT_THAT(dot2dVer1Op1.getWarpOrder(), wmma2dVer1.getWarpOrder());
+  ASSERT_THAT(dot2dVer1Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 1u));
+  ASSERT_THAT(dot2dVer1Op1.getThreadsPerWarp(), testing::ElementsAre(1u, 16u));
 
-  auto wmma3dVer1 = createWMMAv1({2, 4});
+  auto wmma3dVer1 = createWMMAv1({2, 4, 1});
   auto dot3dVer1Op0 = createDotOperand(0, wmma3dVer1, 16);
   auto dot3dVer1Op1 = createDotOperand(1, wmma3dVer1, 16);
   ASSERT_THAT(dot3dVer1Op0.getWarpOrder(), wmma3dVer1.getWarpOrder());
   ASSERT_THAT(dot3dVer1Op1.getWarpOrder(), wmma3dVer1.getWarpOrder());
+  ASSERT_THAT(dot3dVer1Op0.getThreadsPerWarp(),
+              testing::ElementsAre(1, 16u, 1u));
+  ASSERT_THAT(dot3dVer1Op1.getThreadsPerWarp(),
+              testing::ElementsAre(1, 1u, 16u));
 
   auto wmma2dVer2 = createWMMAv2(false, {2, 4});
   auto dot2dVer2Op0 = createDotOperand(0, wmma2dVer2, 16);
   auto dot2dVer2Op1 = createDotOperand(1, wmma2dVer2, 16);
   ASSERT_THAT(dot2dVer2Op0.getWarpOrder(), wmma2dVer2.getWarpOrder());
   ASSERT_THAT(dot2dVer2Op1.getWarpOrder(), wmma2dVer2.getWarpOrder());
+  ASSERT_THAT(dot2dVer2Op0.getThreadsPerWarp(), testing::ElementsAre(16u, 2u));
+  ASSERT_THAT(dot2dVer2Op1.getThreadsPerWarp(), testing::ElementsAre(2u, 16u));
 
-  auto wmma3dVer2 = createWMMAv2(false, {2, 4});
+  auto wmma3dVer2 = createWMMAv2(false, {2, 4, 1});
   auto dot3dVer2Op0 = createDotOperand(0, wmma3dVer2, 16);
   auto dot3dVer2Op1 = createDotOperand(1, wmma3dVer2, 16);
   ASSERT_THAT(dot3dVer2Op0.getWarpOrder(), wmma3dVer2.getWarpOrder());
   ASSERT_THAT(dot3dVer2Op1.getWarpOrder(), wmma3dVer2.getWarpOrder());
+  ASSERT_THAT(dot3dVer2Op0.getThreadsPerWarp(),
+              testing::ElementsAre(1, 16u, 2u));
+  ASSERT_THAT(dot3dVer2Op1.getThreadsPerWarp(),
+              testing::ElementsAre(1, 2u, 16u));
 }
 
 class LinearEncodingTest : public ::testing::Test {
