Skip to content

Commit 4eb5d61

Browse files
Merge commit '68a24ff70cf59001e9fd216374620cc1a6071c5a'
2 parents 50fc4c3 + 68a24ff commit 4eb5d61

34 files changed

+166
-147
lines changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 69 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
143143
// Return the order that represents that the batch is in row-major or
144144
// column-major order for a batch of matrices of shape [*, m, n] with
145145
// len(shape) == rank.
146-
assert(rank >= 2);
147146
SmallVector<unsigned> order(rank);
147+
if (rank < 2) {
148+
return order;
149+
}
148150
std::iota(order.rbegin(), order.rend(), 0);
149151
if (!rowMajor) {
150152
std::swap(order[0], order[1]);
@@ -397,6 +399,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
397399
return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
398400
"order must all have the same rank.";
399401
}
402+
if (llvm::any_of(sizePerThread,
403+
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
404+
return emitError()
405+
<< "Every element in sizePerThread must be a power of two.";
406+
}
407+
if (llvm::any_of(threadsPerWarp,
408+
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
409+
return emitError()
410+
<< "Every element in threadsPerWarp must be a power of two.";
411+
}
412+
if (llvm::any_of(warpsPerCTA,
413+
[](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
414+
return emitError()
415+
<< "Every element in warpsPerCTA must be a power of two.";
416+
}
400417

401418
// Empty CTALayout is allowed, but if it's present its rank must match the
402419
// BlockedEncodingAttr's rank.
@@ -1996,6 +2013,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
19962013
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
19972014
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
19982015
return mma.getRepOrderForOperand(getOpIdx());
2016+
} else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
2017+
return to_vector(blocked.getOrder());
19992018
}
20002019
llvm::report_fatal_error(
20012020
"getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2696,60 +2715,56 @@ struct TritonGPUVerifyTensorLayoutInterface
26962715
LogicalResult verifyTensorLayout(
26972716
Attribute layout, RankedTensorType rankedTy, Operation *op,
26982717
function_ref<InFlightDiagnostic()> makeErr) const override {
2699-
if (isa<triton::gpu::SharedEncodingTrait>(layout))
2700-
return makeErr() << "Shared layout is not allowed on tensor type.";
2701-
// TODO(jlebar): Currently this only checks blocked layouts, but other
2702-
// layouts also have invariants!
2703-
2704-
// TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
2705-
if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
2706-
ModuleOp module = op->getParentOfType<ModuleOp>();
2707-
2708-
// A different verifier should have checked that the layout itself is
2709-
// valid, including that threads-per-warp has the same rank as
2710-
// warps-per-block etc.
2711-
if (blocked.getRank() != rankedTy.getRank()) {
2712-
return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
2713-
<< ", but the tensor it's attached to has rank "
2714-
<< rankedTy.getRank() << ".";
2715-
}
2716-
2717-
int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
2718-
int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
2719-
if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
2720-
return makeErr() << layout << ".\nLayout has a total of "
2721-
<< layoutThreadsPerWarp
2722-
<< " threads per warp, but the module specifies "
2723-
<< moduleThreadsPerWarp << " threads per warp.";
2724-
}
2725-
2726-
std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
2727-
if (!moduleWarpsPerCTA) {
2728-
return makeErr()
2729-
<< "Could not determine the number of warps per CTA. Operation "
2730-
"is not in a context with `ttg.num-warps`.";
2731-
}
2732-
int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
2733-
if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
2734-
return makeErr() << layout << ".\nLayout has a total of "
2735-
<< layoutWarpsPerCTA
2736-
<< " warps per CTA, but the context requires "
2737-
<< *moduleWarpsPerCTA << " warps per CTA.";
2738-
}
2739-
2740-
if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
2741-
int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
2742-
int64_t layoutCTAsPerCGA =
2743-
product(blocked.getCTALayout().getCTAsPerCGA());
2744-
if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
2745-
return makeErr() << layout << ".\nLayout has a total of "
2746-
<< layoutCTAsPerCGA
2747-
<< " CTAs per CGA, but the module specifies "
2748-
<< moduleCTAsPerCGA << " CTAs per CGA.";
2749-
}
2750-
}
2718+
auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
2719+
if (!distr)
2720+
return makeErr()
2721+
<< "Non-distributed layout is not allowed in tensor type.";
2722+
auto rank = distr.getRepOrder().size();
2723+
if (rank != rankedTy.getRank())
2724+
return makeErr() << "Layout has rank " << rank
2725+
<< ", but the tensor it's attached to has rank "
2726+
<< rankedTy.getRank() << ".";
2727+
if (llvm::any_of(rankedTy.getShape(),
2728+
[](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
2729+
return makeErr() << "Layout has shape " << rankedTy.getShape()
2730+
<< ", but the tensor it's attached to has shape "
2731+
<< rankedTy.getShape()
2732+
<< " which is not a power of two.";
2733+
}
2734+
auto ll = toLinearLayout(rankedTy.getShape(), layout);
2735+
ModuleOp module = op->getParentOfType<ModuleOp>();
2736+
2737+
// Number of threads per warp.
2738+
auto kLane = StringAttr::get(module.getContext(), "lane");
2739+
int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
2740+
if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
2741+
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
2742+
<< " threads per warp, but the module specifies "
2743+
<< moduleThreadsPerWarp << " threads per warp.";
2744+
}
2745+
2746+
// Number of warps per CTA.
2747+
std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
2748+
if (!moduleWarpsPerCTA) {
2749+
return makeErr()
2750+
<< "Could not determine the number of warps per CTA. Operation "
2751+
"is not in a context with `ttg.num-warps`.";
2752+
}
2753+
auto kWarp = StringAttr::get(module.getContext(), "warp");
2754+
if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
2755+
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
2756+
<< " warps per CTA, but the context requires "
2757+
<< *moduleWarpsPerCTA << " warps per CTA.";
2758+
}
2759+
2760+
// Number of CTAs per CGA.
2761+
auto kBlock = StringAttr::get(module.getContext(), "block");
2762+
int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
2763+
if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
2764+
return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
2765+
<< " CTAs per CGA, but the context requires "
2766+
<< moduleCTAsPerCGA << " CTAs per CGA.";
27512767
}
2752-
27532768
return success();
27542769
}
27552770
};

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,14 +356,14 @@ def test_warp_specialize():
356356
c = ttgl.arange(0, 4, layout=layout)
357357
pair = Pair(a, b)
358358
e: ttgl.constexpr = 42
359-
a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
359+
a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
360360
[warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
361361
anchor(a)
362362
anchor(b)
363363

364364
# CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
365365
# CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
366-
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
366+
ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])
367367

368368

369369
@gluon.jit

python/test/unit/language/test_core.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
31143114
pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
31153115
if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
31163116
pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
3117+
if isinstance(src_layout, LinearLayout) and THREADS_PER_WARP != (1 << len(src_layout.lane)):
3118+
pytest.xfail(f"Skipping. This LinearLayout assumes {1 << len(src_layout.lane)} threads per warp")
31173119

31183120
if isinstance(src_layout, MmaLayout) and src_layout.version == 3:
31193121
src_layout.instr_shape[2] = 16 if dtype_str == "float16" else 8

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,11 @@ def allocate_shared_memory(element_ty, shape, layout, value=None, _semantic=None
302302

303303

304304
@builtin
305-
def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, #
305+
def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
306306
_semantic=None, _generator=None):
307307
worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
308308
worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
309-
return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
309+
return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
310310
worker_num_regs, _generator)
311311

312312

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
239239
self.wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape, ret_layout)
240240
for i in range(len(inputs)))
241241

242-
def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
243-
worker_num_regs: Sequence[int], generator):
242+
def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
243+
worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
244244
num_partitions = len(worker_partitions)
245245
assert num_partitions == len(
246246
worker_num_warps
@@ -255,7 +255,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
255255
# Emit the default partition to get the result types.
256256
default_block = builder.new_block()
257257
builder.set_insertion_point_to_start(default_block)
258-
default_results = generator.call_JitFunction(default_partition, args, kwargs={})
258+
default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
259259
mlir_results = []
260260
if default_results is not None:
261261
mlir_results = flatten_values_to_ir(default_results)
@@ -264,7 +264,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
264264

265265
# Create the warp specialize op.
266266
builder.restore_insertion_point(insert_pt)
267-
mlir_args = flatten_values_to_ir(args)
267+
mlir_args = flatten_values_to_ir(worker_args)
268268
ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
269269
ws_op.get_default_region().push_back(default_block)
270270
ws_op.set_requested_registers(worker_num_regs)
@@ -276,7 +276,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
276276
for i in range(num_partitions):
277277
block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
278278
block_args = [block.get_argument(j) for j in range(len(mlir_args))]
279-
block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
279+
block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
280280
generator.call_JitFunction(worker_partitions[i], block_args, kwargs={})
281281
builder.create_warp_return()
282282

python/tutorials/gluon/01-attention-forward.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def consume_result(self, tile):
589589

590590
@gluon.jit
591591
def _attn_fwd_load(config, #
592-
m_is, infos, k_load_ctx, v_load_ctx, #
592+
infos, k_load_ctx, v_load_ctx, #
593593
STAGE: gl.constexpr):
594594
prog = config.get_program()
595595
lo, hi = prog.get_loop_bounds(STAGE)
@@ -609,7 +609,7 @@ def _attn_fwd_load(config, #
609609

610610
@gluon.jit
611611
def _attn_fwd_mma(config, #
612-
m_is, infos, k_load_ctx, v_load_ctx, #
612+
infos, k_load_ctx, v_load_ctx, #
613613
STAGE: gl.constexpr):
614614
prog = config.get_program()
615615
lo, hi = prog.get_loop_bounds(STAGE)
@@ -684,7 +684,7 @@ def _attn_fwd_correction_compute(config, mi_consumer, o_consumer, m_i):
684684

685685
@gluon.jit
686686
def _attn_fwd_correction(config, #
687-
m_is, infos, k_load_ctx, v_load_ctx, #
687+
m_is, infos, #
688688
STAGE: gl.constexpr):
689689
prog = config.get_program()
690690
lo, hi = prog.get_loop_bounds(STAGE)
@@ -757,14 +757,14 @@ def _softmax_tile(tile_id: gl.constexpr, config, info, STAGE: gl.constexpr):
757757

758758
@gluon.jit
759759
def _attn_fwd_softmax0(config, #
760-
m_is, infos, k_load_ctx, v_load_ctx, #
760+
infos, k_load_ctx, v_load_ctx, #
761761
STAGE: gl.constexpr):
762762
_softmax_tile(0, config, infos[0], STAGE)
763763

764764

765765
@gluon.jit
766766
def _attn_fwd_softmax1(config, #
767-
m_is, infos, k_load_ctx, v_load_ctx, #
767+
infos, k_load_ctx, v_load_ctx, #
768768
STAGE: gl.constexpr):
769769
_softmax_tile(1, config, infos[1], STAGE)
770770

@@ -781,10 +781,14 @@ def _attn_fwd_inner(config, info0, info1, m_i0, m_i1, #
781781
config,
782782
(m_i0, m_i1),
783783
(info0, info1),
784+
STAGE,
785+
), _attn_fwd_correction, (
786+
config,
787+
(info0, info1),
784788
k_load_ctx,
785789
v_load_ctx,
786790
STAGE,
787-
), _attn_fwd_correction, [
791+
), [
788792
_attn_fwd_softmax0,
789793
_attn_fwd_softmax1,
790794
_attn_fwd_mma,

test/Analysis/test-alias.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
66
#A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
77
#B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
8-
#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
8+
#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
99
#A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
1010
#B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
1111

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
5959
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
6060
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
6161
#smem = #ttg.shared_memory
62-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
62+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
6363
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
6464
// COMMON-LABEL: @local_loads_with_token_from_async_wait
6565
tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
9898
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
9999
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
100100
#smem = #ttg.shared_memory
101-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
101+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
102102
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
103103
// COMMON-LABEL: @local_loads_without_token_from_async_wait
104104
tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
137137
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
138138
#shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
139139
#smem = #ttg.shared_memory
140-
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
140+
#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
141141
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
142142
// COMMON-LABEL: @local_loads_with_loop_carried_token
143143
tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s
22

3-
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
4-
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
3+
#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
4+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
55
#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
66
#smem = #ttg.shared_memory
7-
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
7+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {
88
// CHECK-LABEL: @local_load_offset
99
tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
1010
%0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2167,7 +2167,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x
21672167

21682168
// -----
21692169

2170-
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
2170+
#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
21712171
#dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
21722172
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
21732173

0 commit comments

Comments
 (0)