
Commit 68a24ff

lezcano and Mogball authored
[LAYOUTS] Enable generic swizzling on AMD (#7225)
We also fix a test that was creating an invalid linear layout. In passing, we tighten the invariants of our layouts in the IR.

---------

Co-authored-by: Mogball <[email protected]>
1 parent 34758e4 commit 68a24ff

34 files changed, +166 −147 lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 69 additions & 54 deletions
@@ -140,8 +140,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
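For readers following along, here is a rough Python rendering of what getMatrixOrder returns after this change; the function name and types are illustrative and only the `rank < 2` early return mirrors the edit above.

```python
# Illustrative sketch (not Triton code): the order for a batch of matrices is
# [rank-1, ..., 1, 0] in row-major mode, with the two fastest dimensions
# swapped for column-major. For rank < 2 we now return the default order
# instead of asserting.
def get_matrix_order(rank: int, row_major: bool) -> list[int]:
    order = [0] * rank
    if rank < 2:
        return order
    order = list(range(rank - 1, -1, -1))
    if not row_major:
        order[0], order[1] = order[1], order[0]
    return order

assert get_matrix_order(3, True) == [2, 1, 0]
assert get_matrix_order(3, False) == [1, 2, 0]
assert get_matrix_order(1, True) == [0]
```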
@@ -394,6 +396,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }

   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
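The new invariant is easy to state outside the verifier as well. Below is a minimal Python sketch of the power-of-two check; the helper names are hypothetical and only illustrate the rule the C++ above enforces.

```python
# Sketch of the new BlockedEncodingAttr invariant: every per-dimension count
# must be a power of two. Names here are illustrative, not Triton APIs.
def is_power_of_two(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

def check_blocked(size_per_thread, threads_per_warp, warps_per_cta):
    for name, dims in (("sizePerThread", size_per_thread),
                       ("threadsPerWarp", threads_per_warp),
                       ("warpsPerCTA", warps_per_cta)):
        if any(not is_power_of_two(d) for d in dims):
            raise ValueError(f"Every element in {name} must be a power of two.")

check_blocked([4, 1], [4, 16], [4, 1])      # fine
# check_blocked([3, 1], [4, 16], [4, 1])    # would be rejected
```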
@@ -1963,6 +1980,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2660,60 +2679,56 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy.getShape(), layout);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
+    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
     }
-
     return success();
   }
 };
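The net effect of the rewritten verifier is that any distributed layout, not just blocked ones, is lowered to a linear layout whose "lane", "warp", and "block" input dimensions are checked against the module. A hedged Python sketch of that logic, with illustrative names only (nothing here is a Triton API):

```python
# Sketch (not Triton's C++ verifier): compare a linear layout's in-dimension
# sizes against the module-level settings. `in_dim_sizes` stands in for
# LinearLayout::getInDimSize over the "lane"/"warp"/"block" dimensions.
def verify_tensor_layout(in_dim_sizes: dict[str, int], *, threads_per_warp: int,
                         warps_per_cta: int, ctas_per_cga: int) -> list[str]:
    errors = []
    if in_dim_sizes["lane"] != threads_per_warp:
        errors.append(f'layout has {in_dim_sizes["lane"]} threads per warp, '
                      f'but the module specifies {threads_per_warp}')
    if in_dim_sizes["warp"] != warps_per_cta:
        errors.append(f'layout has {in_dim_sizes["warp"]} warps per CTA, '
                      f'but the context requires {warps_per_cta}')
    if in_dim_sizes["block"] != ctas_per_cga:
        errors.append(f'layout has {in_dim_sizes["block"]} CTAs per CGA, '
                      f'but the module specifies {ctas_per_cga}')
    return errors

# A layout built for 32 lanes fails verification in a 64-thread-per-warp module:
print(verify_tensor_layout({"lane": 32, "warp": 4, "block": 1},
                           threads_per_warp=64, warps_per_cta=4, ctas_per_cga=1))
```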

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
@@ -356,14 +356,14 @@ def test_warp_specialize():
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
     e: ttgl.constexpr = 42
-    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
+    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
                                 [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)

     # CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
     # CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
-    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
+    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])


 @gluon.jit
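The test change reflects the new gluon warp_specialize signature, which now takes the default partition's arguments and the workers' arguments as separate tuples. A self-contained sketch of the calling convention follows; these are plain Python stand-ins, not gluon itself.

```python
# Sketch of the split-argument convention: the default partition and the
# worker partitions may receive different argument tuples.
from typing import Callable, Sequence

def warp_specialize_sketch(default_args: tuple, default_partition: Callable,
                           worker_args: tuple, worker_partitions: Sequence[Callable],
                           worker_num_warps: Sequence[int],
                           worker_num_regs: Sequence[int]):
    assert len(worker_partitions) == len(worker_num_warps) == len(worker_num_regs)
    results = default_partition(*default_args)   # produces the op's results
    for worker in worker_partitions:             # workers only see worker_args;
        worker(*worker_args)                     # warps/regs affect codegen only
    return results

# Here the default and worker arguments happen to coincide, as in the test above.
out = warp_specialize_sketch((1, 2), lambda x, y: x + y,
                             (1, 2), [lambda x, y: None], [4], [48])
assert out == 3
```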

python/test/unit/language/test_core.py

Lines changed: 2 additions & 0 deletions
@@ -3080,6 +3080,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.skip("Skipping sum reduction on float16 due to accuracy issues")
+    if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
+        pytest.skip(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")

     if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
         src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8
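The skip condition relies on the fact that a linear layout encodes one lane bit per "lane" basis vector, so it implicitly assumes 2**len(lane) threads per warp. A small illustration, with made-up lane bases:

```python
# Illustration only: each "lane" basis vector of a linear layout contributes
# one bit of the lane index, so the layout assumes 2**len(lane) threads/warp.
lane_bases = [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16]]  # 5 hypothetical bases
assumed_threads_per_warp = 1 << len(lane_bases)          # 32
THREADS_PER_WARP = 64                                    # e.g. an AMD wavefront
if THREADS_PER_WARP != assumed_threads_per_warp:
    print("skip: layout assumes", assumed_threads_per_warp, "threads per warp")
```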

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 2 deletions
@@ -302,11 +302,11 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None


 @builtin
-def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, #
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 5 deletions
@@ -239,8 +239,8 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
             self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
             for i in range(len(inputs)))

-    def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
-                        worker_num_regs: Sequence[int], generator):
+    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
+                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
         assert num_partitions == len(
             worker_num_warps
@@ -255,7 +255,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
         # Emit the default partition to get the result types.
         default_block = builder.new_block()
         builder.set_insertion_point_to_start(default_block)
-        default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
         mlir_results = []
         if default_results is not None:
             mlir_results = flatten_values_to_ir(default_results)
@@ -264,7 +264,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num

         # Create the warp specialize op.
         builder.restore_insertion_point(insert_pt)
-        mlir_args = flatten_values_to_ir(args)
+        mlir_args = flatten_values_to_ir(worker_args)
         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
         ws_op.get_default_region().push_back(default_block)
         ws_op.set_requested_registers(worker_num_regs)
@@ -276,7 +276,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
         for i in range(num_partitions):
             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
-            block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+            block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
             generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
             builder.create_warp_return()

python/tutorials/gluon/01-attention-forward.py

Lines changed: 10 additions & 6 deletions
@@ -589,7 +589,7 @@ def consume_result(self, tile):

 @gluon.jit
 def _attn_fwd_load(config, #
-                   m_is, infos, k_load_ctx, v_load_ctx, #
+                   infos, k_load_ctx, v_load_ctx, #
                    STAGE: gl.constexpr):
     prog = config.get_program()
     lo, hi = prog.get_loop_bounds(STAGE)
@@ -609,7 +609,7 @@ def _attn_fwd_load(config, #

 @gluon.jit
 def _attn_fwd_mma(config, #
-                  m_is, infos, k_load_ctx, v_load_ctx, #
+                  infos, k_load_ctx, v_load_ctx, #
                   STAGE: gl.constexpr):
     prog = config.get_program()
     lo, hi = prog.get_loop_bounds(STAGE)
@@ -684,7 +684,7 @@ def _attn_fwd_correction_compute(config, mi_consumer, o_consumer, m_i):

 @gluon.jit
 def _attn_fwd_correction(config, #
-                         m_is, infos, k_load_ctx, v_load_ctx, #
+                         m_is, infos, #
                          STAGE: gl.constexpr):
     prog = config.get_program()
     lo, hi = prog.get_loop_bounds(STAGE)
@@ -757,14 +757,14 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):

 @gluon.jit
 def _attn_fwd_softmax0(config, #
-                       m_is, infos, k_load_ctx, v_load_ctx, #
+                       infos, k_load_ctx, v_load_ctx, #
                        STAGE: gl.constexpr):
     _softmax_tile(0, config, infos[0], STAGE)


 @gluon.jit
 def _attn_fwd_softmax1(config, #
-                       m_is, infos, k_load_ctx, v_load_ctx, #
+                       infos, k_load_ctx, v_load_ctx, #
                        STAGE: gl.constexpr):
     _softmax_tile(1, config, infos[1], STAGE)

@@ -781,10 +781,14 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
         config,
         (m_i0, m_i1),
         (info0, info1),
+        STAGE,
+    ), _attn_fwd_correction, (
+        config,
+        (info0, info1),
         k_load_ctx,
         v_load_ctx,
         STAGE,
-    ), _attn_fwd_correction, [
+    ), [
         _attn_fwd_softmax0,
         _attn_fwd_softmax1,
         _attn_fwd_mma,

test/Analysis/test-alias.mlir

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
 #B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
-#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 #A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
 #B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_token_from_async_wait
   tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_without_token_from_async_wait
   tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_loop_carried_token
   tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

-#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
     %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
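These attribute edits exist to satisfy the stricter verifier above: the products of the layout's per-dimension counts must line up with the module's "ttg.threads-per-warp" and "ttg.num-warps". The arithmetic, spelled out with the values from the updated test:

```python
# Products of a layout's per-dimension counts must match the module settings.
from math import prod

threads_per_warp = [4, 16]    # updated #blocked threadsPerWarp
warps_per_cta_blocked = [4, 1]
warps_per_cta_mfma = [1, 4]   # updated #mma warpsPerCTA

assert prod(threads_per_warp) == 64      # module "ttg.threads-per-warp" = 64
assert prod(warps_per_cta_blocked) == 4  # module "ttg.num-warps" = 4
assert prod(warps_per_cta_mfma) == 4
```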

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -2167,7 +2167,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x

 // -----

-#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
 #dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
