Commit 8065e6a

Revert "Revert "[LAYOUTS] Enable generic swizzling on AMD (#7225)""
This reverts commit dc3c13d.
1 parent f1307bd commit 8065e6a

31 files changed: 150 additions & 137 deletions

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 69 additions & 54 deletions
@@ -153,8 +153,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
@@ -396,6 +398,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }

   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
@@ -2246,6 +2263,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2958,60 +2977,56 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy.getShape(), layout);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
+    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
     }
-
     return success();
   }
 };
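
For reference, the hunks above tighten BlockedEncodingAttr::verify so that every element of sizePerThread, threadsPerWarp, and warpsPerCTA must be a power of two, and rewrite verifyTensorLayout to check any distributed layout through its linear layout. A minimal Python sketch of the power-of-two rule, for illustration only (the helper names here are hypothetical, not part of the patch):

    def is_power_of_two(x: int) -> bool:
        # Same predicate as llvm::isPowerOf2_64: non-zero with exactly one bit set.
        return x > 0 and (x & (x - 1)) == 0

    def check_blocked_layout(size_per_thread, threads_per_warp, warps_per_cta):
        # Mirrors the new verifier: reject any per-dimension factor that is not a power of two.
        for name, dims in (("sizePerThread", size_per_thread),
                           ("threadsPerWarp", threads_per_warp),
                           ("warpsPerCTA", warps_per_cta)):
            if any(not is_power_of_two(d) for d in dims):
                raise ValueError(f"Every element in {name} must be a power of two.")

    check_blocked_layout([4, 1], [8, 8], [4, 1])    # accepted
    # check_blocked_layout([3, 1], [8, 8], [4, 1])  # raises ValueError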

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
@@ -393,14 +393,14 @@ def test_warp_specialize():
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
     e: ttgl.constexpr = 42
-    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
+    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
                                 [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)

     # CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
     # CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
-    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
+    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])


 @gluon.jit

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 2 deletions
@@ -447,7 +447,7 @@ def set_auto_layout(value, layout, _semantic=None):


 @builtin
-def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, #
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
     """
     Create a warp-specialized execution region, partitioning work across warps.
@@ -465,7 +465,7 @@ def warp_specialize(args, default_partition, worker_partitions, worker_num_warps
     """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)

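
Under the updated signature, the default partition's arguments and the workers' arguments are passed as separate tuples. A minimal usage sketch based on the test change above (illustrative only; the partition and kernel names below are placeholders, not part of this commit):

    @gluon.jit
    def default_partition(x, c):
        pass

    @gluon.jit
    def worker0(x, c):
        pass

    @gluon.jit
    def worker1(x, c):
        pass

    @gluon.jit
    def example_kernel(layout: ttgl.constexpr):
        x = ttgl.arange(0, 4, layout=layout)
        # The first tuple feeds default_partition; the second feeds every worker partition.
        ttgl.warp_specialize((x, 42), default_partition, (x, 42),
                             [worker0, worker1],  # worker_partitions
                             [4, 4],              # worker_num_warps
                             [24, 48])            # worker_num_regs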

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 5 deletions
@@ -296,8 +296,8 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
             self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
             for i in range(len(inputs)))

-    def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
-                        worker_num_regs: Sequence[int], generator):
+    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
+                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
         assert num_partitions == len(
             worker_num_warps
@@ -312,7 +312,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
         # Emit the default partition to get the result types.
         default_block = builder.new_block()
         builder.set_insertion_point_to_start(default_block)
-        default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
         mlir_results = []
         if default_results is not None:
             mlir_results = flatten_values_to_ir(default_results)
@@ -321,7 +321,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num

         # Create the warp specialize op.
         builder.restore_insertion_point(insert_pt)
-        mlir_args = flatten_values_to_ir(args)
+        mlir_args = flatten_values_to_ir(worker_args)
         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
         ws_op.get_default_region().push_back(default_block)
         ws_op.set_requested_registers(worker_num_regs)
@@ -334,7 +334,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
             caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
-            block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+            block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
             generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
             builder.create_warp_return()

test/Analysis/test-alias.mlir

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
 #B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
-#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 #A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
 #B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_token_from_async_wait
   tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_without_token_from_async_wait
   tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_loop_carried_token
   tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

-#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
     %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -2234,7 +2234,7 @@ tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4x

 // -----

-#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
+#mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [1, 1]}>
 #dot = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>


test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' | FileCheck %s

 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
@@ -123,7 +123,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 #mma1 = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 8}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: @dot_reg_operand_A_fp8
   // Generate a wgmma where the first operand is a struct.
// CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, i1) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>
@@ -469,7 +469,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 // -----

 #mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
-module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
+module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: test_fp8_to_fp16_dot_operand
   // CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2
   tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) {
