
Commit 4db79f5

Revert "Revert "[LAYOUTS] Enable generic swizzling on AMD (#7225)"" (#4862)
Fixes #4564. Please do not squash and merge this PR.

2 parents f1307bd + 4630ec2, commit 4db79f5

39 files changed: +192 -165 lines

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 80 additions & 54 deletions
@@ -23,6 +23,7 @@
 #include "triton/Tools/LayoutUtils.h"
 #include "triton/Tools/LinearLayout.h"
 #include "triton/Tools/StrUtil.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/MathExtras.h"
@@ -153,8 +154,10 @@ SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
   // Return the order that represents that the batch is in row-major or
   // column-major order for a batch of matrices of shape [*, m, n] with
   // len(shape) == rank.
-  assert(rank >= 2);
   SmallVector<unsigned> order(rank);
+  if (rank < 2) {
+    return order;
+  }
   std::iota(order.rbegin(), order.rend(), 0);
   if (!rowMajor) {
     std::swap(order[0], order[1]);
@@ -396,6 +399,21 @@ BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                           "order must all have the same rank.";
   }
+  if (llvm::any_of(sizePerThread,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in sizePerThread must be a power of two.";
+  }
+  if (llvm::any_of(threadsPerWarp,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in threadsPerWarp must be a power of two.";
+  }
+  if (llvm::any_of(warpsPerCTA,
+                   [](unsigned x) { return !llvm::isPowerOf2_64(x); })) {
+    return emitError()
+           << "Every element in warpsPerCTA must be a power of two.";
+  }

   // Empty CTALayout is allowed, but if it's present its rank must match the
   // BlockedEncodingAttr's rank.
@@ -2246,6 +2264,8 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
 SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
   if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
     return mma.getRepOrderForOperand(getOpIdx());
+  } else if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
+    return to_vector(blocked.getOrder());
   }
   llvm::report_fatal_error(
       "getRepOrder not implemented for DotOperandEncodingAttr");
@@ -2958,60 +2978,66 @@ struct TritonGPUVerifyTensorLayoutInterface
   LogicalResult verifyTensorLayout(
       Attribute layout, RankedTensorType rankedTy, Operation *op,
       function_ref<InFlightDiagnostic()> makeErr) const override {
-    if (isa<triton::gpu::SharedEncodingTrait>(layout))
-      return makeErr() << "Shared layout is not allowed on tensor type.";
-    // TODO(jlebar): Currently this only checks blocked layouts, but other
-    // layouts also have invariants!
-
-    // TODO(jlebar): Handle the case when the encoding is nested within tt.ptr.
-    if (auto blocked = dyn_cast<BlockedEncodingAttr>(layout)) {
-      ModuleOp module = op->getParentOfType<ModuleOp>();
-
-      // A different verifier should have checked that the layout itself is
-      // valid, including that threads-per-warp has the same rank as
-      // warps-per-block etc.
-      if (blocked.getRank() != rankedTy.getRank()) {
-        return makeErr() << layout << ".\nLayout has rank " << blocked.getRank()
-                         << ", but the tensor it's attached to has rank "
-                         << rankedTy.getRank() << ".";
-      }
-
-      int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
-      int64_t layoutThreadsPerWarp = product(blocked.getThreadsPerWarp());
-      if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutThreadsPerWarp
-                         << " threads per warp, but the module specifies "
-                         << moduleThreadsPerWarp << " threads per warp.";
-      }
-
-      std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
-      if (!moduleWarpsPerCTA) {
-        return makeErr()
-               << "Could not determine the number of warps per CTA. Operation "
-                  "is not in a context with `ttg.num-warps`.";
-      }
-      int64_t layoutWarpsPerCTA = product(blocked.getWarpsPerCTA());
-      if (layoutWarpsPerCTA != *moduleWarpsPerCTA) {
-        return makeErr() << layout << ".\nLayout has a total of "
-                         << layoutWarpsPerCTA
-                         << " warps per CTA, but the context requires "
-                         << *moduleWarpsPerCTA << " warps per CTA.";
-      }
-
-      if (blocked.getCTALayout().getCTAsPerCGA().size() > 0) {
-        int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
-        int64_t layoutCTAsPerCGA =
-            product(blocked.getCTALayout().getCTAsPerCGA());
-        if (layoutCTAsPerCGA != moduleCTAsPerCGA) {
-          return makeErr() << layout << ".\nLayout has a total of "
-                           << layoutCTAsPerCGA
-                           << " CTAs per CGA, but the module specifies "
-                           << moduleCTAsPerCGA << " CTAs per CGA.";
-        }
-      }
+    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
+    if (!distr)
+      return makeErr()
+             << "Non-distributed layout is not allowed in tensor type.";
+    if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH"))
+      return success();
+    auto rank = distr.getRepOrder().size();
+    if (rank != rankedTy.getRank())
+      return makeErr() << "Layout has rank " << rank
+                       << ", but the tensor it's attached to has rank "
+                       << rankedTy.getRank() << ".";
+    if (llvm::any_of(rankedTy.getShape(),
+                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
+      return makeErr() << "Layout has shape " << rankedTy.getShape()
+                       << ", but the tensor it's attached to has shape "
+                       << rankedTy.getShape()
+                       << " which is not a power of two.";
+    }
+    auto ll = toLinearLayout(rankedTy);
+    ModuleOp module = op->getParentOfType<ModuleOp>();
+
+    // Number of threads per warp.
+    auto kLane = StringAttr::get(module.getContext(), "lane");
+    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
+    // FIXME: ll.getInDimSize(kLane) does not return the correct threads per
+    // warp. https://github.com/intel/intel-xpu-backend-for-triton/issues/4861
+    unsigned layoutThreadsPerWarp = ll.getInDimSize(kLane);
+    if (auto dotOperandLayout =
+            dyn_cast<DotOperandEncodingAttr>(rankedTy.getEncoding()))
+      if (auto dpasLayout =
+              dyn_cast<intel::DpasEncodingAttr>(dotOperandLayout.getParent()))
+        layoutThreadsPerWarp = dpasLayout.getThreadsPerWarp();
+    if (layoutThreadsPerWarp != moduleThreadsPerWarp) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
+                       << " threads per warp, but the module specifies "
+                       << moduleThreadsPerWarp << " threads per warp.";
+    }
+
+    // Number of warps per CTA.
+    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
+    if (!moduleWarpsPerCTA) {
+      return makeErr()
+             << "Could not determine the number of warps per CTA. Operation "
+                "is not in a context with `ttg.num-warps`.";
+    }
+    auto kWarp = StringAttr::get(module.getContext(), "warp");
+    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
+                       << " warps per CTA, but the context requires "
+                       << *moduleWarpsPerCTA << " warps per CTA.";
    }
+
+    // Number of CTAs per CGA.
+    auto kBlock = StringAttr::get(module.getContext(), "block");
+    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
+    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
+      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
+                       << " CTAs per CGA, but the context requires "
+                       << moduleCTAsPerCGA << " CTAs per CGA.";
    }
-
    return success();
  }
};
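For context on the new BlockedEncodingAttr checks above: every entry of sizePerThread, threadsPerWarp, and warpsPerCTA must now be a power of two. The snippet below is a standalone Python mirror of that rule, illustrative only; is_power_of_two and check_blocked_layout are hypothetical helpers, not part of the commit.

# Standalone mirror of the power-of-two rule the new verifier applies to
# sizePerThread, threadsPerWarp, and warpsPerCTA (hypothetical helpers).
def is_power_of_two(x: int) -> bool:
    return x > 0 and (x & (x - 1)) == 0

def check_blocked_layout(size_per_thread, threads_per_warp, warps_per_cta):
    for name, dims in (("sizePerThread", size_per_thread),
                       ("threadsPerWarp", threads_per_warp),
                       ("warpsPerCTA", warps_per_cta)):
        if any(not is_power_of_two(d) for d in dims):
            raise ValueError(f"Every element in {name} must be a power of two.")

check_blocked_layout([4, 1], [2, 16], [4, 1])     # accepted
# check_blocked_layout([4, 1], [3, 16], [4, 1])   # would raise: 3 is not a power of two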

python/test/gluon/test_frontend.py

Lines changed: 2 additions & 2 deletions
@@ -393,14 +393,14 @@ def test_warp_specialize():
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
     e: ttgl.constexpr = 42
-    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default,
+    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
                                 [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
     anchor(a)
     anchor(b)

     # CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
     # CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
-    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, [warp_specialize_worker1], [4], [48])
+    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])


 @gluon.jit
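The updated calls reflect the new API shape: the default partition's arguments and the worker partitions' arguments are passed as separate tuples. Below is a commented restatement of the first call, with positions mapped to the new parameter names from _core.py; it assumes the Gluon test environment above and is illustrative only.

a, b = ttgl.warp_specialize(
    (pair, c, e),                                        # default_args
    warp_specialize_default,                             # default_partition
    (pair, c, e),                                        # worker_args
    [warp_specialize_worker0, warp_specialize_worker1],  # worker_partitions
    [4, 4],                                              # worker_num_warps
    [24, 48],                                            # worker_num_regs
)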

python/test/unit/intel/test_block_load.py

Lines changed: 4 additions & 1 deletion
@@ -21,21 +21,24 @@ def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pa
         A_width = 2
         B_width = 4
         layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 4, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 2]}>"
+        num_warps = 4
     elif dtype_str == "float32":
         A_width = 1
         B_width = 1
         layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}>"
+        num_warps = 32
     else:
         A_width = 1
         B_width = 2
         layouts = "#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}>"
+        num_warps = 32

     block_io = "\"column_major\"" if transpose else "\"row_major\""

     ty = {"float32": "f32", "float16": "f16", "int8": "i8"}[dtype_str]

     ir = layouts + f"""
-    module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{
+    module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{
     tt.func public @block_load_dpas_layout(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) attributes {{noinline = false}} {{
         %0 = tt.get_program_id x : i32
         %M_i64 = arith.constant {M} : i64
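The per-branch num_warps values track the DPAS layouts' warpsPerCTA, since the stricter tensor-layout verifier (see the Dialect.cpp change above) requires the module-level "ttg.num-warps" to match the layout's total warp count. A quick arithmetic check of the values used here, illustrative only:

import math

# float16 branch: warpsPerCTA = [1, 4]  ->  "ttg.num-warps" = 4
assert math.prod([1, 4]) == 4
# float32 and int8 branches: warpsPerCTA = [8, 4]  ->  "ttg.num-warps" = 32
assert math.prod([8, 4]) == 32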

python/test/unit/language/test_core.py

Lines changed: 5 additions & 1 deletion
@@ -6704,13 +6704,17 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
     elif dtype == "float8e5":
         mlir_dtype = "f8E5M2"

+    num_warps = 4
+    if isinstance(dist_layout, DotOperandLayout) and isinstance(dist_layout.parent, DpasLayout):
+        num_warps = math.prod(dist_layout.parent.warps_per_cta)
+
     layouts = f"""
     #dist = {dist_layout}
     #shared = {shared_layout}
     #smem = #ttg.shared_memory
     """
     ir = layouts + f"""
-    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}} {{
+    module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = {num_warps} : i32, "ttg.threads-per-warp" = 32 : i32}} {{
     tt.func public @kernel(%arg0: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{mlir_dtype}> {{tt.divisibility = 16 : i32}}) attributes {{noinline = false}} {{
         %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #dist>
         %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>

python/triton/experimental/gluon/language/_core.py

Lines changed: 2 additions & 2 deletions
@@ -447,7 +447,7 @@ def set_auto_layout(value, layout, _semantic=None):


 @builtin
-def warp_specialize(args, default_partition, worker_partitions, worker_num_warps, worker_num_regs, #
+def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
                     _semantic=None, _generator=None):
     """
     Create a warp-specialized execution region, partitioning work across warps.
@@ -465,7 +465,7 @@ def warp_specialize(args, default_partition, worker_partitions, worker_num_warps
     """
     worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(args, default_partition, worker_partitions, worker_num_warps, #
+    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
                                      worker_num_regs, _generator)

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 5 additions & 5 deletions
@@ -296,8 +296,8 @@ def reduction(self, inputs: Sequence[TensorTy], axis: int, region_builder_fn) ->
             self._wrap_handle_infer_layout(reduce_op.get_result(i), inputs[i].type.scalar, ret_shape)
             for i in range(len(inputs)))

-    def warp_specialize(self, args, default_partition, worker_partitions, worker_num_warps: Sequence[int],
-                        worker_num_regs: Sequence[int], generator):
+    def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions,
+                        worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator):
         num_partitions = len(worker_partitions)
         assert num_partitions == len(
             worker_num_warps
@@ -312,7 +312,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
         # Emit the default partition to get the result types.
         default_block = builder.new_block()
         builder.set_insertion_point_to_start(default_block)
-        default_results = generator.call_JitFunction(default_partition, args, kwargs={})
+        default_results = generator.call_JitFunction(default_partition, default_args, kwargs={})
         mlir_results = []
         if default_results is not None:
             mlir_results = flatten_values_to_ir(default_results)
@@ -321,7 +321,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num

         # Create the warp specialize op.
         builder.restore_insertion_point(insert_pt)
-        mlir_args = flatten_values_to_ir(args)
+        mlir_args = flatten_values_to_ir(worker_args)
         ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps)
         ws_op.get_default_region().push_back(default_block)
         ws_op.set_requested_registers(worker_num_regs)
@@ -334,7 +334,7 @@ def warp_specialize(self, args, default_partition, worker_partitions, worker_num
             caller_context = GluonCallerContext(num_warps=worker_num_warps[i])
             block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types)
             block_args = [block.get_argument(j) for j in range(len(mlir_args))]
-            block_args = unflatten_ir_values(block_args, [arg.type for arg in args])
+            block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args])
             generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context)
             builder.create_warp_return()

test/Analysis/test-alias.mlir

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 #A_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #A_SHARED_T = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
 #B_SHARED = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
-#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#C = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1], instrShape = [16, 8]}>
 #A_DOT = #ttg.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
 #B_DOT = #ttg.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

test/Conversion/amd/async-ops-alias-scopes.mlir

Lines changed: 3 additions & 3 deletions
@@ -59,7 +59,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_token_from_async_wait
   tt.func public @local_loads_with_token_from_async_wait(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -98,7 +98,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_without_token_from_async_wait
   tt.func public @local_loads_without_token_from_async_wait(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},
@@ -137,7 +137,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 8, maxPhase = 2, order = [0, 1]}>
 #shared1 = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 16, order = [1, 0]}>
 #smem = #ttg.shared_memory
-#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [32, 32], isTransposed = true}>
+#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [32, 32], isTransposed = true}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
   // COMMON-LABEL: @local_loads_with_loop_carried_token
   tt.func public @local_loads_with_loop_carried_token(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32},

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 3 deletions
@@ -1,10 +1,10 @@
 // RUN: triton-opt %s --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo| FileCheck %s

-#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [1, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
-module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 64 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
     %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
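These AMD lit-test layouts were adjusted so they satisfy the stricter verifier: the product of threadsPerWarp must equal the module's "ttg.threads-per-warp" and the product of warpsPerCTA must equal "ttg.num-warps". A small consistency check of the updated values, illustrative only:

import math

# compute-base-ptr.mlir (64-wide wavefronts):
assert math.prod([4, 16]) == 64   # blocked threadsPerWarp vs. "ttg.threads-per-warp" = 64
assert math.prod([4, 1]) == 4     # blocked warpsPerCTA vs. "ttg.num-warps" = 4
assert math.prod([1, 4]) == 4     # mfma warpsPerCTA vs. "ttg.num-warps" = 4

# async-ops-alias-scopes.mlir declares "ttg.num-warps" = 1, hence warpsPerCTA = [1, 1]:
assert math.prod([1, 1]) == 1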
