
Commit 30f0d5d

Merge commit '78c8054298a81f578dcd8c79b519981c57dfb665' into amyachev/merge0

2 parents 0df7d80 + 78c8054

File tree: 12 files changed, +287 -52 lines

include/triton/Analysis/Utility.h

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool atomicNeedsSharedMemory(Value result);
 
-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstT);
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 
 bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 

include/triton/Tools/LinearLayout.h

Lines changed: 7 additions & 0 deletions
@@ -679,6 +679,13 @@ class LinearLayout {
   // (i.e. every input bit affects the output).
   llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks() const;
 
+  // Increase an input dimension without affecting the output dimension. The
+  // added free variables are mapped to 0, ensuring that the new input
+  // dimensions correspond directly to the existing output space. The function
+  // errors out if `newInDimSize` is less than the current size or the new size
+  // is not a power of 2.
+  LinearLayout resize(StringAttr inDim, int32_t newInDimSize) const;
+
   std::string toString() const;
 
   friend bool operator==(LinearLayout lhs, LinearLayout rhs);
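As a rough sketch of what the new resize contract means in practice (a toy Python model, not the Triton class; the list-of-bases representation below is invented for illustration): because the added free variables map to 0, input indices that differ only in the newly added bits produce identical outputs.

# Toy model of a linear layout: bases[i] is the output contributed by input
# bit i; applying the layout XORs the bases of all set input bits.
def apply_layout(bases, idx):
    out = 0
    for bit, basis in enumerate(bases):
        if (idx >> bit) & 1:
            out ^= basis
    return out

bases = [1, 2]            # in-dim size 4 (two input bits)
resized = bases + [0, 0]  # "resize" to in-dim size 16: two zero bases added

# The new high bits are free variables: they never change the output.
for idx in range(4):
    for hi in range(4):
        assert apply_layout(resized, idx | (hi << 2)) == apply_layout(bases, idx)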

lib/Analysis/Utility.cpp

Lines changed: 40 additions & 2 deletions
@@ -543,7 +543,7 @@ bool supportMMA(Value value, int version) {
          (elemTy.isInteger(8) && version >= 2);
 }
 
-bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
   auto blockedLayout = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
   auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
   if (blockedLayout == nullptr || dotOperandLayout == nullptr)
@@ -662,8 +662,46 @@ std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
       toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
   if (!(srcLayout.has_value() && dstLayout.has_value()))
     return std::nullopt;
+  StringAttr kRegister = StringAttr::get(ctx, "register");
+  StringAttr kLane = StringAttr::get(ctx, "lane");
+  StringAttr kWarp = StringAttr::get(ctx, "warp");
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+  auto numSrcRegs = srcLayout->getInDimSize(kRegister);
+  auto numDstRegs = dstLayout->getInDimSize(kRegister);
+  // The `invertAndCompose` function will generate a layout that is injective
+  // by assigning new output dimensions to free variables. For instance,
+  // consider a scenario where `srcLayout` has a free variable in the lane
+  // dimension, while `dstLayout` has two free variables in the lane
+  // dimension and also a larger number of registers.
+  // The injective form of `srcLayout` will add only a single additional row
+  // to the transformation matrix, whereas the injective form of `dstLayout`
+  // will add two additional rows. This discrepancy causes misleading results
+  // because the matrices end up with a different number of rows.
+  //
+  // Take `dstLayout ⋅ srcLayout^-1` as an example:
+  //
+  //   - `injective(dstLayout)`: [n, m] → [n + 2, m]
+  //   - `injective(srcLayout)`: [n, m] → [n + 1, m]
+  //   - `injective(srcLayout)^-1`: [n + 1, m] → [m, n + 1]
+  //   - `injective(dstLayout) ⋅ injective(srcLayout)^-1`: [n + 2, m] ⋅ [m, n +
+  //     1] → [n + 2, n + 1]
+  //
+  // Here, the `(n + 1)`-th row added by `dstLayout` represents the free
+  // variable in registers, and the `(n + 2)`-th row represents the free
+  // variable in lanes. However, the `(n + 1)`-th row added by `srcLayout`
+  // represents the free variable in lanes. As a result, the `(n + 1)`-th rows
+  // in the two layouts do not correspond to the same free variable.
+  //
+  // To address this issue, we pad the free variables in `srcLayout` and
+  // `dstLayout` to ensure they have the same number of registers. This
+  // guarantees that the resulting matrices have the same number of rows,
+  // ensuring consistency in the composition process.
+  auto numRegs = std::max(numSrcRegs, numDstRegs);
+  auto srcLayoutWithFreeRegs = srcLayout->resize(kRegister, numRegs);
+  auto dstLayoutWithFreeRegs = dstLayout->resize(kRegister, numRegs);
   // comp describes the layout function to create dst from src.
-  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  LinearLayout comp =
+      dstLayoutWithFreeRegs.invertAndCompose(srcLayoutWithFreeRegs);
   // We try to quotient by the largest subspace first
   auto dims = SmallVector<StringRef>{"block", "warp", "lane", "register"};
   for (auto dim : dims) {
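To make the motivation concrete, here is a minimal sketch in the same toy representation used above (plain Python, not the Triton API; the basis lists are invented): padding the smaller layout's register dimension with zero bases gives both layouts the same number of register bits, so the matrices fed to the composition have matching heights, without changing what any existing register maps to.

# Register bases for two hypothetical layouts (toy model, names invented):
src_reg_bases = [1, 2]          # 4 source registers, no free variable
dst_reg_bases = [1, 2, 4, 0]    # 16 destination registers, one free variable

num_regs = max(1 << len(src_reg_bases), 1 << len(dst_reg_bases))  # 16

def resize(bases, new_in_dim_size):
    # Mirrors the idea of LinearLayout::resize: append zero bases (free variables).
    return bases + [0] * (new_in_dim_size.bit_length() - 1 - len(bases))

src_padded = resize(src_reg_bases, num_regs)
dst_padded = resize(dst_reg_bases, num_regs)
assert len(src_padded) == len(dst_padded) == 4   # same register bit-count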

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 32 additions & 21 deletions
@@ -288,60 +288,71 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return rewriter.notifyMatchFailure(
           op, "NYI. srcTy and/or dstTy don't implement LLs yet");
     }
+    LinearLayout srcLayout =
+        *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    LinearLayout dstLayout =
+        *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+
+    StringAttr kBlock = str_attr("block");
+    StringAttr kWarp = str_attr("warp");
+    StringAttr kLane = str_attr("lane");
+    StringAttr kRegister = str_attr("register");
 
     assert(to_vector(conversion->getInDimNames()) ==
            to_vector(conversion->getOutDimNames()));
     auto dims = conversion->getInDimNames();
-    if (llvm::is_contained(dims, str_attr("block"))) {
+    if (llvm::is_contained(dims, kBlock)) {
      // Case 1: Transfer between values in different CTAs.
      // This requires moving values through distributed shared memory.
       return rewriter.notifyMatchFailure(
           op, "NYI: Transfer between different CTAs");
-    } else if (llvm::is_contained(dims, str_attr("warp"))) {
+    } else if (llvm::is_contained(dims, kWarp)) {
      // Case 2: Transfer between values in the same CTA, in which case we move
      // values through shared memory.
-      LinearLayout srcLayout =
-          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-      LinearLayout dstLayout =
-          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, str_attr("lane"))) {
+    } else if (llvm::is_contained(dims, kLane)) {
      // Case 3. Transfer between values in the same warp, in which case we try
      // to move values using warp shuffles, though if the pattern is
      // complicated enough we may fall back to using shared memory
      // TODO(Keren): implement warp shuffle instead of using the general
      // approach that uses shared memory
-      LinearLayout srcLayout =
-          *toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
-      LinearLayout dstLayout =
-          *toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
       return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
-    } else if (llvm::is_contained(dims, str_attr("register"))) {
+    } else if (llvm::is_contained(dims, kRegister) ||
+               dstLayout.getInDimSize(kRegister) !=
+                   srcLayout.getInDimSize(kRegister)) {
      // Case 4. Transfer between values in the same thread, in which case we
      // simply reorder the elements of adaptor.getSrc().
-      return transferWithinThread(op, *conversion, adaptor, rewriter);
+      return transferWithinThread(
+          op, dstLayout.getFreeVariableMasks()[kRegister],
+          dstLayout.getInDimSize(kRegister), *conversion, adaptor, rewriter);
     } else {
-      // The two layouts are equivalent. We should probably remove these in
-      // RemoveLayoutConversion.
+      // Case 5. The two layouts are equivalent. We should probably remove
+      // these in RemoveLayoutConversion.
       rewriter.replaceOp(op, adaptor.getSrc());
       return success();
     }
   }
 
   LogicalResult
-  transferWithinThread(ConvertLayoutOp op, const LinearLayout &conversion,
-                       OpAdaptor adaptor,
+  transferWithinThread(ConvertLayoutOp op, int32_t regMasks, int32_t numRegs,
+                       const LinearLayout &conversion, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
     MLIRContext *ctx = op.getContext();
     auto loc = op.getLoc();
     StringAttr kRegister = str_attr("register");
     assert(!cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
 
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    SmallVector<Value> outVals;
-    outVals.resize(conversion.getInDimSize(kRegister));
-    for (int i = 0; i < conversion.getInDimSize(kRegister); i++) {
-      auto srcIdx = conversion.apply({{kRegister, i}}).begin()->second;
+    SmallVector<Value> outVals(numRegs);
+    for (int i = 0; i < outVals.size(); i++) {
+      // Remove free masks from the register index.
+      // For example, if idx = 0b00111, and masks = 0b00100, then we get
+      // 0b00011. It means that register 7 (0b111) has the same value as
+      // register 3 (0b011).
+      auto idx = i & (~regMasks);
+      auto srcIdx = conversion.hasInDim(kRegister)
+                        ? conversion.apply({{kRegister, idx}}).begin()->second
+                        : idx;
       outVals[i] = inVals[srcIdx];
     }
     Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter,
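The comment's bitmask example can be checked directly with a small standalone snippet (plain Python, mirroring only the index arithmetic above, not the actual lowering): clearing the free-variable bits collapses duplicate registers onto the register that actually holds the value.

num_regs = 8          # destination registers
reg_masks = 0b00100   # free-variable mask for the register dimension

for i in range(num_regs):
    idx = i & ~reg_masks   # drop the free bits
    # e.g. register 7 (0b111) reads the same source element as register 3 (0b011)
    print(f"out register {i} <- source index {idx}")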

lib/Tools/LinearLayout.cpp

Lines changed: 15 additions & 0 deletions
@@ -1016,6 +1016,21 @@ bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
   return true;
 }
 
+LinearLayout LinearLayout::resize(StringAttr inDim,
+                                  int32_t newInDimSize) const {
+  BasesT bases = getBases();
+  assert(bases.contains(inDim) && "inDim not in layout");
+  assert(llvm::isPowerOf2_32(newInDimSize) &&
+         "newInDimSize must be a power of 2");
+  assert(newInDimSize >= getInDimSize(inDim) &&
+         "newInDimSize must be >= old size");
+  auto numFreeVariables = llvm::Log2_32(newInDimSize) - getInDimSizeLog2(inDim);
+  for (int i = 0; i < numFreeVariables; i++) {
+    bases[inDim].push_back(std::vector<int32_t>(getNumOutDims(), 0));
+  }
+  return LinearLayout(std::move(bases), llvm::to_vector(getOutDimNames()));
+}
+
 std::string LinearLayout::toString() const {
   // Start with a newline because we print out a bulleted list; it doesn't
   // make sense for the first line of this list to be on the same line as

python/test/unit/language/test_core.py

Lines changed: 26 additions & 4 deletions
@@ -1490,18 +1490,30 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_interpreter() and dtype_x_str == 'float16':
+        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+
+        if DTYPE == tl.float16:
+            # sum can have bad numerics when accumulating in float16.
+            # if we're dealing with float16, do the sum in float32.
+            x = x.to(tl.float32)
+
         z = tl.sum(x, axis=AXIS)
+
+        if DTYPE == tl.float16:
+            z = z.to(DTYPE)
+
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    if x.dtype == np.float16:
+        # do the sum in float32 to reduce numerical variation
+        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
+    else:
+        z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
+
+    def torch_to_triton_dtype(t):
+        if t == torch.float16:
+            return tl.float16
+        return None
+
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
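As background for why both the kernel and the NumPy reference accumulate in float32 when the inputs are float16 (a standalone NumPy illustration, not part of the test): a float16 running sum loses the small addends once the total grows, while accumulating in float32 and casting back stays near the true value.

import numpy as np

x = np.full(4096, 0.1, dtype=np.float16)

# Worst-case sequential accumulation in float16: once the running total
# reaches 256, adding 0.1 rounds back to 256 and the sum stops growing.
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

acc32 = x.astype(np.float32).sum()   # accumulate in float32

print(acc16, acc32)   # roughly 256 vs roughly 409.5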

python/triton/compiler/compiler.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
 import re
 import functools
 import os
+import sysconfig
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 #   and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():
 
     # backend
     libtriton_hash = hashlib.sha256()
-    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
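For reference, a quick standalone check of what the new lookup produces (illustrative snippet, not part of the commit): sysconfig.get_config_var("EXT_SUFFIX") is the interpreter's extension-module suffix, so its last dot-separated component is "so" on Linux and macOS and "pyd" on Windows, which is why the hashed filename is no longer hard-coded to libtriton.so.

import sysconfig

# EXT_SUFFIX looks like ".cpython-311-x86_64-linux-gnu.so" on Linux or
# ".cp311-win_amd64.pyd" on Windows; exact values vary by interpreter.
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
print(f"_C/libtriton.{ext}")   # e.g. _C/libtriton.so or _C/libtriton.pyd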

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 0 deletions
@@ -62,3 +62,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f16
+  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_bf16
+  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
+    tt.return
+  }
+}

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 74 additions & 0 deletions
@@ -847,6 +847,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 // -----
 
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot1 = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_layout_mmav2_dot_reg
+  tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_0
+  tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_1
+  tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_2
+  tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
+#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
+#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: convert_layout_mmav3_mmav3_3
+  tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
+    // CHECK-NOT: st.shared
+    // CHECK-NOT: llvm.load
+    %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
+    tt.return
+  }
+}
+
+// -----
+
 #blocked = #triton_gpu.blocked<{sizePerThread = [16, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
 #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} {