
Commit 26dd693

Merge OpenAI Triton commit ff77e98 (#3419)
This PR changes the Triton base from df66eb5 to ff77e98 (Feb 12). Pass rate: 98.09%. Please do not squash and merge this PR.
2 parents 02c26fd + 9029c5b commit 26dd693

File tree: 24 files changed (+753, -128 lines)


cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-ffe3129e9bdc146ee4d91e849173d1c64b1ae974
+1188b1ff7b956cb65d8ddda5f1e56c432f1a57c7

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 1 deletion
@@ -1055,7 +1055,10 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
    ```
  }];

-  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let arguments = (ins FlatSymbolRefAttr:$callee,
+                       Variadic<AnyType>:$operands,
+                       OptionalAttr<DictArrayAttr>:$arg_attrs,
+                       OptionalAttr<DictArrayAttr>:$res_attrs);
   let results = (outs Variadic<AnyType>);

   let builders = [

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 0 deletions
@@ -256,6 +256,11 @@ LinearLayout chooseStMatrixLayout(MLIRContext *ctx, RankedTensorType tensorTy,
 // tensor into shared memory using the `ldmatrix` instruction.
 LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
                                   bool needTrans, int32_t elemBitWidth);
+
+// The primary goal of this function is to efficiently load 2D tiles of a
+// tensor from shared memory using the `ds_read_tr` instruction for AMD GPUs.
+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth);
 } // namespace mlir::triton::gpu

 #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -922,7 +922,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
   if (argAttrs.empty())
     return;
   assert(type.getNumInputs() == argAttrs.size());
-  function_interface_impl::addArgAndResultAttrs(
+  call_interface_impl::addArgAndResultAttrs(
       builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 }

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 135 additions & 0 deletions
@@ -391,6 +391,135 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }

+LinearLayout chooseDotDsReadB64Tr16Layout(DotOperandEncodingAttr dotMfmaLayout,
+                                          ArrayRef<int64_t> shape,
+                                          int32_t elemBitWidth) {
+  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
+  assert(mfmaLayout.getMDim() == 16 || mfmaLayout.getNDim() == 32);
+  assert(elemBitWidth == 16);
+
+  auto rank = shape.size();
+  bool hasBatchDim = rank == 3;
+  int32_t kWidthDot = dotMfmaLayout.getKWidth();
+  // Number of bits loaded by an LDS read. ds_read_tr primarily supports 64-bit
+  // loads for most element sizes (16b, 8b, 4b).
+  const int32_t ldsReadWidth = 64;
+  int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
+  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
+
+  int32_t kSize = shape[kDim];
+  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+
+  MLIRContext *ctx = dotMfmaLayout.getContext();
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
+  StringAttr kRegister = S("register");
+  StringAttr kLane = S("lane");
+  StringAttr kWarp = S("warp");
+
+  // Register order:
+  // operand A: [1, 0] / [2, 1, 0]
+  // operand B: [0, 1] / [1, 2, 0]
+  // The regular dot mfma order for both cases is [k, nonk]/[k, nonk, batch].
+  // For the LDS transpose layout, swap the order to [nonk, k]/[nonk, k, batch].
+  SmallVector<unsigned> order = triton::gpu::getOrder(dotMfmaLayout);
+  std::swap(order[0], order[1]);
+
+  // In the LDS transpose logic, each thread accesses 64 bits (8 bytes) of data.
+  // The smallest unit for transposing is a 4x4 sub-tile of threads, where each
+  // thread reads 4 16-bit elements along the non-K dimension, resulting in a
+  // [non-K, K] = {16, 4} sub-tile of elements. Because of the transposing
+  // mechanism, each thread ends up with 4 16-bit elements along the K dim.
+  //
+  // The MFMA selection logic prioritizes double-rate MFMA instructions whenever
+  // possible. Specifically:
+  // - For MFMA operations with non-K = 16, when blockK > 16, mfma16x16x32 is
+  //   selected; otherwise (blockK ≤ 16), mfma16x16x16 remains the choice.
+  // - For MFMA operations with non-K = 32, when blockK > 8, mfma32x32x16 is
+  //   selected; otherwise (blockK ≤ 8), mfma32x32x8 is used.
+  //
+  // In double-rate MFMA instructions, each thread holds 8 elements along the K
+  // dimension.
+  // - The first 4 elements belong to the first sub-tile.
+  // - The next 4 elements belong to the second sub-tile.
+  //
+  // We then group these into larger tiles, each consisting of 8 of these 16x4
+  // sub-tiles. These tiles correspond to the data for one mfma instruction. The
+  // shapes of these tiles depend on the MFMA instruction used:
+  // 1. For mfma32x32x16, the tile shape is [non-K, K] = {32, 16}.
+  // 2. For mfma16x16x32, the tile shape is [non-K, K] = {16, 32}.
+  //
+  // For single-rate mfma instructions, each thread holds 4 elements along the K
+  // dimension. This means the larger tile (which corresponds to one mfma
+  // instruction) consists of 4 16x4 sub-tiles.
+  std::vector<std::vector<int32_t>> registerBase = {{1, 0},
+                                                    {2, 0}}; // first sub-tile
+  std::vector<std::vector<int32_t>> laneBase = {{kWidthTransRead, 0},
+                                                {2 * kWidthTransRead, 0},
+                                                {0, 1},
+                                                {0, 2}}; // first sub-tile
+
+  // Extend the register base for multiple tiles in the K dimension
+  // (corresponding to multiple mfma instructions across the K dim).
+  auto populateRegisterBase = [&](int kTileSize) {
+    const int regsPerTile = 8;
+    int numRegs = (kSize / kTileSize) * regsPerTile;
+    for (int reg = regsPerTile; reg < numRegs; reg *= 2) {
+      registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
+    }
+  };
+
+  const bool isMfma32 = (mfmaLayout.getMDim() == 32);
+  const bool isMfma16 = (mfmaLayout.getMDim() == 16);
+  const int kTileSize = isMfma32 ? 16 : 32;
+
+  if (kSize >= kTileSize) {
+    // Handles the mfma32x32x16 and mfma16x16x32 cases
+    assert(kWidthDot == 8);
+    registerBase.push_back({0, 4}); // second sub-tile
+    populateRegisterBase(kTileSize);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 8}}
+                           : std::vector<std::vector<int32_t>>{{0, 8}, {0, 16}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  } else {
+    // Handles the mfma32x32x8 and mfma16x16x16 cases
+    assert(kWidthDot == 4);
+    auto laneBaseExt = isMfma32
+                           ? std::vector<std::vector<int32_t>>{{16, 0}, {0, 4}}
+                           : std::vector<std::vector<int32_t>>{{0, 4}, {0, 8}};
+    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
+  }
+
+  // The base vectors above are defined in a fixed order [non-k-dim, k-dim].
+  // To assign them to actual matrix dimensions, the `order` array is used.
+  // For operand A: non-k-dim -> dim0, k-dim -> dim1
+  // For operand B: non-k-dim -> dim1, k-dim -> dim0
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
+
+  if (hasBatchDim) {
+    assert(order[2] == 0);
+    // Extend the base vector with one value to accommodate the batch
+    // dimension, which appears last.
+    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
+    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
+  }
+
+  // Warp order:
+  // common for both operand A and B: [0, 1] / [0, 1, 2]
+  // in both cases it is [M dim, N dim]/[batch, M dim, N dim]
+  SmallVector<unsigned> warpOrder = triton::gpu::getWarpOrder(dotMfmaLayout);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
+
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
+  auto finalLayout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+
+  return finalLayout;
+}
+
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
                                    ArrayRef<int64_t> shape) {

@@ -1204,4 +1333,10 @@ LinearLayout chooseLdMatrixLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotLdMatrixLayout(dot, shape, needTrans, elemBitWidth);
 }

+LinearLayout chooseDsReadB64Tr16Layout(Attribute enc, ArrayRef<int64_t> shape,
+                                       int32_t elemBitWidth) {
+  auto dot = cast<DotOperandEncodingAttr>(enc);
+  return chooseDotDsReadB64Tr16Layout(dot, shape, elemBitWidth);
+}
+
 } // namespace mlir::triton::gpu
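To make the tiling described in the comments above easier to check by hand, the following standalone Python sketch mirrors the register/lane basis construction of chooseDotDsReadB64Tr16Layout for 16-bit elements. It is illustrative only; the function name and printed values are not part of the Triton API.

# Illustrative sketch (not Triton API): mirrors the register/lane basis
# construction of chooseDotDsReadB64Tr16Layout above for 16-bit elements.
def ds_read_tr16_bases(m_dim, k_size, elem_bit_width=16):
    """Return (register_base, lane_base) as [non-K, K] offset vectors."""
    assert elem_bit_width == 16
    k_width_trans_read = 64 // elem_bit_width        # 4 elements per 64-bit LDS read
    is_mfma32 = m_dim == 32
    k_tile_size = 16 if is_mfma32 else 32            # K covered by one mfma tile

    register_base = [[1, 0], [2, 0]]                 # first 16x4 sub-tile
    lane_base = [[k_width_trans_read, 0], [2 * k_width_trans_read, 0], [0, 1], [0, 2]]

    if k_size >= k_tile_size:                        # double-rate mfma, kWidth == 8
        register_base.append([0, 4])                 # second sub-tile along K
        regs_per_tile = 8
        num_regs = (k_size // k_tile_size) * regs_per_tile
        reg = regs_per_tile
        while reg < num_regs:                        # one basis vector per further K tile
            register_base.append([0, (reg // regs_per_tile) * k_tile_size])
            reg *= 2
        lane_base += [[16, 0], [0, 8]] if is_mfma32 else [[0, 8], [0, 16]]
    else:                                            # single-rate mfma, kWidth == 4
        lane_base += [[16, 0], [0, 4]] if is_mfma32 else [[0, 4], [0, 8]]
    return register_base, lane_base

# mfma16x16x32 with K = 64: 4 register basis vectors (16 values per lane)
# and 6 lane basis vectors (64 lanes), matching the tile shapes described above.
regs, lanes = ds_read_tr16_bases(m_dim=16, k_size=64)
print(regs)   # [[1, 0], [2, 0], [0, 4], [0, 32]]
print(lanes)  # [[4, 0], [8, 0], [0, 1], [0, 2], [0, 8], [0, 16]]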

python/test/unit/language/test_core.py

Lines changed: 12 additions & 0 deletions
@@ -7134,3 +7134,15 @@ def _simple_add(
     _simple_add[grid](x, x.stride(0), x.stride(1))

     assert torch.allclose(x, torch.ones_like(x) * c_dim)
+
+
+@pytest.mark.interpreter
+def test_aliasing(device):
+
+    @triton.jit
+    def aliasing_kernel(buffer, buffer2):
+        triton.language.store(buffer, 1)
+
+    buffer = torch.zeros(1, device=device)
+    aliasing_kernel[(1, )](buffer, buffer)
+    assert buffer[0] == 1

python/triton/runtime/interpreter.py

Lines changed: 14 additions & 9 deletions
@@ -1157,16 +1157,22 @@ def __init__(self, fn, arg_names, grid):
         self.constexprs = [name for name in arg_names if __annotations__.get(name) == "constexpr"]

     def _init_args_hst(self, args_dev, kwargs):
+        storages = {}

         def _to_cpu(arg):
             if isinstance(arg, tuple):
                 return _tuple_create(arg, map(_to_cpu, arg))
             elif not hasattr(arg, "data_ptr"):
                 return arg
+
             unwrapped_arg = _unwrap_tensor(arg)
+            if unwrapped_arg.untyped_storage().data_ptr() not in storages:
+                storage = unwrapped_arg.untyped_storage()
+                storages[storage.data_ptr()] = storage.cpu()
+
+            storage = storages[unwrapped_arg.untyped_storage().data_ptr()]
             cpu_arg = unwrapped_arg.new_empty(0, device='cpu')
-            cpu_arg.set_(unwrapped_arg.untyped_storage().cpu(), unwrapped_arg.storage_offset(), unwrapped_arg.size(),
-                         unwrapped_arg.stride())
+            cpu_arg.set_(storage, unwrapped_arg.storage_offset(), unwrapped_arg.size(), unwrapped_arg.stride())
             cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
             return cpu_arg

@@ -1175,21 +1181,17 @@ def _to_cpu(arg):
         # Process keyword arguments
         kwargs_hst = {}
         for key, value in kwargs.items():
-            if hasattr(value, "data_ptr"):
-                kwargs_hst[key] = value.cpu()
-            elif isinstance(value, tuple):
-                return _tuple_create(value, map(_to_cpu, value))
-            else:
-                kwargs_hst[key] = value
+            kwargs_hst[key] = _to_cpu(value)
         return args_hst, kwargs_hst

     def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
+        storages = {}

         def _from_cpu(arg_dev, arg_hst):
             if hasattr(arg_dev, "data_ptr"):
                 # No need to rewrap because this just modifies internal
                 arg_dev, arg_hst = _unwrap_tensor(arg_dev), _unwrap_tensor(arg_hst)
-                arg_dev.untyped_storage().copy_(arg_hst.untyped_storage())
+                storages[arg_dev.untyped_storage().data_ptr()] = (arg_dev.untyped_storage(), arg_hst.untyped_storage())
             elif isinstance(arg_dev, tuple):
                 for (arg_dev, arg_hst) in zip(arg_dev, arg_hst):
                     _from_cpu(arg_dev, arg_hst)

@@ -1202,6 +1204,9 @@ def _from_cpu(arg_dev, arg_hst):
             kwarg_hst = kwargs_hst[key]
             _from_cpu(kwarg_dev, kwarg_hst)

+        for (arg_dev, arg_hst) in storages.values():
+            arg_dev.copy_(arg_hst)
+
     def __call__(self, *args_dev, **kwargs):
         if kwargs.pop("warmup", False):
             return
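The interpreter change above keys host copies by underlying storage so that aliased arguments (as exercised by test_aliasing) share a single CPU buffer. Below is a minimal PyTorch sketch of the failure mode it guards against, using only standard torch calls; the variable names are illustrative.

# Minimal sketch of the aliasing hazard the storage map above avoids: when two
# arguments share one untyped storage, making a separate host copy per tensor
# means the copy restored last silently overwrites writes made through the
# other alias.
import torch

x = torch.zeros(4)
y = x[::2]                                # y aliases x's storage
assert x.untyped_storage().data_ptr() == y.untyped_storage().data_ptr()

x_host, y_host = x.clone(), y.clone()     # naive per-argument host copies
x_host[0] = 1.0                           # the "kernel" writes through the first alias
x.copy_(x_host)                           # restoring the first copy applies the write...
y.copy_(y_host)                           # ...but the stale second copy clobbers it
assert x[0] == 0.0                        # the store is lost; deduplicating host
                                          # copies by storage data_ptr() prevents this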

test/Conversion/amd/async_ops_to_llvm.mlir

Lines changed: 4 additions & 4 deletions
@@ -83,16 +83,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
                        %arg1: i32 {tt.divisibility = 16 : i32},
                        %arg2: !ttg.memdesc<32x64xf16, #shared, #smem, mutable>) {
     // The waitcnt stores all counters in one i32 bits 15:14 and 3:0 store the vmcnt we have to wait on
-    // CHECK: rocdl.waitcnt -49168
+    // CHECK: rocdl.s.waitcnt -49168
     // CHECK: rocdl.barrier
     ttg.async_wait {num = 0 : i32}
-    // CHECK: rocdl.waitcnt -49167
+    // CHECK: rocdl.s.waitcnt -49167
     // CHECK: rocdl.barrier
     ttg.async_wait {num = 1 : i32}
-    // CHECK: rocdl.waitcnt -2
+    // CHECK: rocdl.s.waitcnt -2
     // CHECK: rocdl.barrier
     ttg.async_wait {num = 62 : i32}
-    // CHECK: rocdl.waitcnt -1
+    // CHECK: rocdl.s.waitcnt -1
     // CHECK: rocdl.barrier
     ttg.async_wait {num = 63 : i32}
     tt.return
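For reference, the CHECK constants above follow from packing vmcnt into the bit positions named in the test comment while leaving the other counter fields saturated (so they are not waited on). The short Python sketch below assumes that bit layout and reproduces the four values; it is a back-of-the-envelope aid, not backend code.

# Rough sketch, assuming the bit layout stated in the test comment: vmcnt is
# split across bits 3:0 (low) and 15:14 (high) of the waitcnt immediate, and
# every other counter field is left as all ones. Interpreted as a signed i32,
# this reproduces the CHECK constants.
def waitcnt_imm(vmcnt):
    value = 0xFFFFFFFF & ~0xF & ~0xC000              # saturate everything but vmcnt
    value |= (vmcnt & 0xF) | ((vmcnt >> 4) << 14)    # pack the 6-bit vmcnt
    return value - (1 << 32) if value & 0x80000000 else value

assert [waitcnt_imm(n) for n in (0, 1, 62, 63)] == [-49168, -49167, -2, -1]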
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s
+
+#mma16 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 16], isTransposed = true}>
+#mma32 = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: ds_transpose_n_t_fp16_mfma_16
+  tt.func @ds_transpose_n_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_fp16_mfma_16
+  tt.func @ds_transpose_t_t_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_fp16_mfma_16
+  tt.func @ds_transpose_n_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_fp16_mfma_16
+  tt.func @ds_transpose_t_n_fp16_mfma_16(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma16, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma16, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_t_fp16_mfma32
+  tt.func @ds_transpose_n_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-32: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_t_fp16_mfma32
+  tt.func @ds_transpose_t_t_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared1, #smem, mutable>) {
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared1, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_n_n_fp16_mfma32
+  tt.func @ds_transpose_n_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-COUNT-16: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    // CHECK-COUNT-8: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: ds_transpose_t_n_fp16_mfma32
+  tt.func @ds_transpose_t_n_fp16_mfma32(%arg0: !ttg.memdesc<128x64xf16, #shared1, #smem, mutable>, %arg1: !ttg.memdesc<64x128xf16, #shared, #smem, mutable>) {
+    // CHECK-NOT: rocdl.ds.read.tr16.b64 %{{.*}} : <3> -> vector<4xf16>
+    %1 = ttg.local_load %arg0 : !ttg.memdesc<128x64xf16, #shared1, #smem, mutable> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma32, kWidth = 8}>>
+    %2 = ttg.local_load %arg1 : !ttg.memdesc<64x128xf16, #shared, #smem, mutable> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma32, kWidth = 8}>>
+    tt.return
+  }
+}
