Skip to content

Commit dc3c13d

Browse files
Revert "[LAYOUTS] Enable generic swizzling on AMD (#7225)"
This reverts commit 68a24ff.
1 parent 4eb5d61 commit dc3c13d

34 files changed

+147
-166
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 54 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,8 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
143143
// Return the order that represents that the batch is in row-major or
144144
// column-major order for a batch of matrices of shape [*, m, n] with
145145
// len(shape) == rank.
146+
assert(rank >= 2);
146147
SmallVector<unsigned> order(rank);
147-
if (rank < 2) {
148-
return order;
149-
}
150148
std::iota(order.rbegin(), order.rend(), 0);
151149
if (!rowMajor) {
152150
std::swap(order[0], order[1]);
@@ -399,21 +397,6 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
399397
return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
400398
"order must all have the same rank.";
401399
}
402-
if (llvm::any_of(sizePerThread,
403-
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
404-
return emitError()
405-
<< "Every element in sizePerThread must be a power of two.";
406-
}
407-
if (llvm::any_of(threadsPerWarp,
408-
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
409-
return emitError()
410-
<< "Every element in threadsPerWarp must be a power of two.";
411-
}
412-
if (llvm::any_of(warpsPerCTA,
413-
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
414-
return emitError()
415-
<< "Every element in warpsPerCTA must be a power of two.";
416-
}
417400

418401
// Empty CTALayout is allowed, but if it's present its rank must match the
419402
// BlockedEncodingAttr's rank.
@@ -2013,8 +1996,6 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
20131996
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
20141997
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
20151998
return mma.getRepOrderForOperand(getOpIdx());
2016-
} else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
2017-
return to_vector(blocked.getOrder());
20181999
}
20192000
llvm::report_fatal_error(
20202001
"getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2715,56 +2696,60 @@ struct TritonGPUVerifyTensorLayoutInterface
27152696
LogicalResult verifyTensorLayout(
27162697
Attribute layout, RankedTensorType rankedTy, Operation *op,
27172698
function_ref<InFlightDiagnostic()> makeErr) const override {
2718-
auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
2719-
if (!distr)
2720-
return makeErr()
2721-
<< "Non-distributed layout is not allowed in tensor type.";
2722-
auto rank = distr.getRepOrder().size();
2723-
if (rank != rankedTy.getRank())
2724-
return makeErr() << "Layout has rank " << rank
2725-
<< ", but the tensor it's attached to has rank "
2726-
<< rankedTy.getRank() << ".";
2727-
if (llvm::any_of(rankedTy.getShape(),
2728-
[](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
2729-
return makeErr() << "Layout has shape " << rankedTy.getShape()
2730-
<< ", but the tensor it's attached to has shape "
2731-
<< rankedTy.getShape()
2732-
<< " which is not a power of two.";
2733-
}
2734-
auto ll = toLinearLayout(rankedTy.getShape(), layout);
2735-
ModuleOp module = op->getParentOfType<ModuleOp>();
2736-
2737-
// Number of threads per warp.
2738-
auto kLane = StringAttr::get(module.getContext(), "lane");
2739-
int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
2740-
if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
2741-
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
2742-
<< " threads per warp, but the module specifies "
2743-
<< moduleThreadsPerWarp << " threads per warp.";
2744-
}
2745-
2746-
// Number of warps per CTA.
2747-
std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
2748-
if (!moduleWarpsPerCTA) {
2749-
return makeErr()
2750-
<< "Could not determine the number of warps per CTA. Operation "
2751-
"is not in a context with `ttg.num-warps`.";
2752-
}
2753-
auto kWarp = StringAttr::get(module.getContext(), "warp");
2754-
if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
2755-
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
2756-
<< " warps per CTA, but the context requires "
2757-
<< *moduleWarpsPerCTA << " warps per CTA.";
2758-
}
2759-
2760-
// Number of CTAs per CGA.
2761-
auto kBlock = StringAttr::get(module.getContext(), "block");
2762-
int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
2763-
if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
2764-
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
2765-
<< " CTAs per CGA, but the context requires "
2766-
<< moduleCTAsPerCGA << " CTAs per CGA.";
2699+
if (isa<triton::gpu::SharedEncodingTrait>(layout))
2700+
return makeErr() << "Shared layout is not allowed on tensor type.";
2701+
// TODO(jlebar): Currently this only checks blocked layouts, but other
2702+
// layouts also have invariants!
2703+
2704+
// TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
2705+
if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
2706+
ModuleOp module = op->getParentOfType<ModuleOp>();
2707+
2708+
// A different verifier should have checked that the layout itself is
2709+
// valid, including that threads-per-warp has the same rank as
2710+
// warps-per-block etc.
2711+
if (blocked.getRank() != rankedTy.getRank()) {
2712+
return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
2713+
<< ", but the tensor it's attached to has rank "
2714+
<< rankedTy.getRank() << ".";
2715+
}
2716+
2717+
int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
2718+
int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
2719+
if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
2720+
return makeErr() << layout << ".\nLayout has a total of "
2721+
<< layoutThreadsPerWarp
2722+
<< " threads per warp, but the module specifies "
2723+
<< moduleThreadsPerWarp << " threads per warp.";
2724+
}
2725+
2726+
std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
2727+
if (!moduleWarpsPerCTA) {
2728+
return makeErr()
2729+
<< "Could not determine the number of warps per CTA. Operation "
2730+
"is not in a context with `ttg.num-warps`.";
2731+
}
2732+
int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
2733+
if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
2734+
return makeErr() << layout << ".\nLayout has a total of "
2735+
<< layoutWarpsPerCTA
2736+
<< " warps per CTA, but the context requires "
2737+
<< *moduleWarpsPerCTA << " warps per CTA.";
2738+
}
2739+
2740+
if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
2741+
int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
2742+
int64_t layoutCTAsPerCGA =
2743+
product(blocked.getCTALayout().getCTAsPerCGA());
2744+
if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
2745+
return makeErr() << layout << ".\nLayout has a total of "
2746+
<< layoutCTAsPerCGA
2747+
<< " CTAs per CGA, but the module specifies "
2748+
<< moduleCTAsPerCGA << " CTAs per CGA.";
2749+
}
2750+
}
27672751
}
2752+
27682753
return success();
27692754
}
27702755
};

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,14 @@ def test_warp_specialize():
356356
c = ttgl.arange(0, 4, layout=layout)
357357
pair = Pair(a, b)
358358
e: ttgl.constexpr = 42
359-
a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
359+
a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
360360
[warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
361361
anchor(a)
362362
anchor(b)
363363

364364
# CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
365365
# CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
366-
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])
366+
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
367367

368368

369369
@gluon.jit

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,8 +3114,6 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
31143114
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
31153115
if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
31163116
pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
3117-
if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
3118-
pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")
31193117

31203118
if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
31213119
src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,11 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
302302

303303

304304
@builtin
305-
def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
305+
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, #
306306
_semantic=None, _generator=None):
307307
worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
308308
worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
309-
return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
309+
return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
310310
worker_num_regs, _generator)
311311

312312

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
239239
self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
240240
for i in range(len(inputs)))
241241

242-
def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
243-
worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
242+
def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
243+
worker_num_regs: Sequence[int], generator):
244244
num_partitions = len(worker_partitions)
245245
assert num_partitions == len(
246246
worker_num_warps
@@ -255,7 +255,7 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
255255
# Emit the default partition to get the result types.
256256
default_block = builder.new_block()
257257
builder.set_insertion_point_to_start(default_block)
258-
default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
258+
default_results = generator.call_JitFunction(default_partition, args, kwargs={})
259259
mlir_results = []
260260
if default_results is not None:
261261
mlir_results = flatten_values_to_ir(default_results)
@@ -264,7 +264,7 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
264264

265265
# Create the warp specialize op.
266266
builder.restore_insertion_point(insert_pt)
267-
mlir_args = flatten_values_to_ir(worker_args)
267+
mlir_args = flatten_values_to_ir(args)
268268
ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
269269
ws_op.get_default_region().push_back(default_block)
270270
ws_op.set_requested_registers(worker_num_regs)
@@ -276,7 +276,7 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p
276276
for i in range(num_partitions):
277277
block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
278278
block_args = [block.get_argument(j) for j in range(len(mlir_args))]
279-
block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
279+
block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
280280
generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
281281
builder.create_warp_return()
282282

python/tutorials/gluon/01-attention-forward.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def consume_result(self, tile):
589589

590590
@gluon.jit
591591
def _attn_fwd_load(config, #
592-
infos, k_load_ctx, v_load_ctx, #
592+
m_is, infos, k_load_ctx, v_load_ctx, #
593593
STAGE: gl.constexpr):
594594
prog = config.get_program()
595595
lo, hi = prog.get_loop_bounds(STAGE)
@@ -609,7 +609,7 @@ def _attn_fwd_load(config, #
609609

610610
@gluon.jit
611611
def _attn_fwd_mma(config, #
612-
infos, k_load_ctx, v_load_ctx, #
612+
m_is, infos, k_load_ctx, v_load_ctx, #
613613
STAGE: gl.constexpr):
614614
prog = config.get_program()
615615
lo, hi = prog.get_loop_bounds(STAGE)
@@ -684,7 +684,7 @@ def _attn_fwd_correction_compute(config, mi_consumer, o_consumer, m_i):
684684

685685
@gluon.jit
686686
def _attn_fwd_correction(config, #
687-
m_is, infos, #
687+
m_is, infos, k_load_ctx, v_load_ctx, #
688688
STAGE: gl.constexpr):
689689
prog = config.get_program()
690690
lo, hi = prog.get_loop_bounds(STAGE)
@@ -757,14 +757,14 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):
757757

758758
@gluon.jit
759759
def _attn_fwd_softmax0(config, #
760-
infos, k_load_ctx, v_load_ctx, #
760+
m_is, infos, k_load_ctx, v_load_ctx, #
761761
STAGE: gl.constexpr):
762762
_softmax_tile(0, config, infos[0], STAGE)
763763

764764

765765
@gluon.jit
766766
def _attn_fwd_softmax1(config, #
767-
infos, k_load_ctx, v_load_ctx, #
767+
m_is, infos, k_load_ctx, v_load_ctx, #
768768
STAGE: gl.constexpr):
769769
_softmax_tile(1, config, infos[1], STAGE)
770770

@@ -781,14 +781,10 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
781781
config,
782782
(m_i0, m_i1),
783783
(info0, info1),
784-
STAGE,
785-
), _attn_fwd_correction, (
786-
config,
787-
(info0, info1),
788784
k_load_ctx,
789785
v_load_ctx,
790786
STAGE,
791-
), [
787+
), _attn_fwd_correction, [
792788
_attn_fwd_softmax0,
793789
_attn_fwd_softmax1,
794790
_attn_fwd_mma,

test/Analysis/test-alias.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
66
#A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
77
#B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
8-
#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
8+
#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
99
#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
1010
#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
1111

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
5959
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
6060
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
6161
#smem = #ttg.shared_memory
62-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
62+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
6363
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
6464
// COMMON-LABEL: @local_loads_with_token_from_async_wait
6565
tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
9898
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
9999
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
100100
#smem = #ttg.shared_memory
101-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
101+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
102102
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
103103
// COMMON-LABEL: @local_loads_without_token_from_async_wait
104104
tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
137137
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
138138
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
139139
#smem = #ttg.shared_memory
140-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
140+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
141141
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
142142
// COMMON-LABEL: @local_loads_with_loop_carried_token
143143
tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s
22

3-
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
4-
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
3+
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
4+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
55
#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
66
#smem = #ttg.shared_memory
7-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {
7+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
88
// CHECK-LABEL: @local_load_offset
99
tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
1010
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2167,7 +2167,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x
21672167

21682168
// -----
21692169

2170-
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
2170+
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
21712171
#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
21722172
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
21732173

0 commit comments

Comments
 (0)