Commit 8b97ac5

Merge commit '8d42d211841b4241a08d9d0d2bb6b77fe6e261c0'
2 parents: acea966 + 8d42d21

38 files changed: +800 -590 lines

include/triton/Dialect/TritonGPU/IR/TritonGPUTypes.td

Lines changed: 20 additions & 4 deletions

@@ -49,24 +49,27 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
     "Type":$elementType,
     "Attribute":$encoding,
     "Attribute":$memorySpace,
-    "bool":$mutable_memory
+    "bool":$mutableMemory,
+    ArrayRefParameter<"int64_t">:$allocShape
   );
+
   let extraClassDeclaration = [{
     MemDescType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                           Type elementType) const {
-      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory());
+      return MemDescType::get(shape.value_or(getShape()), elementType, getEncoding(), getMemorySpace(), getMutableMemory(), getAllocShape());
     }
 
     bool hasRank() const { return true; }
   }];
+
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
       "Type":$elementType,
       "Attribute":$encoding,
       "Attribute":$memorySpace
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, /*mutableMemory=*/false, /*allocShape=*/shape);
     }]>,
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>":$shape,
@@ -75,10 +78,23 @@ def TTG_MemDescType : TTG_TypeDef<"MemDesc", "memdesc", [ShapedTypeInterface]> {
       "Attribute":$memorySpace,
       "bool":$mutableMemory
     ), [{
-      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory);
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, /*allocShape=*/shape);
+    }]>,
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>":$shape,
+      "Type":$elementType,
+      "Attribute":$encoding,
+      "Attribute":$memorySpace,
+      "bool":$mutableMemory,
+      "llvm::ArrayRef<int64_t>":$allocShape
+    ), [{
+      return $_get(elementType.getContext(), shape, elementType, encoding, memorySpace, mutableMemory, allocShape);
     }]>
+
   ];
+
   let hasCustomAssemblyFormat = 1;
+  let genVerifyDecl = 1;
 }
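In effect, MemDescType now carries two shapes: shape describes the view itself, while the new allocShape records the full shape of the allocation the view was taken from (for a fresh allocation the two coincide, which is what the two shorter builders encode by forwarding shape as allocShape). A minimal sketch of the extended assembly format, with illustrative layout attributes and sizes that are not taken from this commit:

#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
#smem = #ttg.shared_memory
// Fresh 3-stage buffer: allocShape == shape, so no trailing alloc shape is printed.
%buf = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable>
// A single-stage view of %buf would print its originating alloc shape:
//   !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>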

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 6 additions & 0 deletions

@@ -2505,6 +2505,7 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
 
   AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
+    // Encoding attributes
     if (auto mmaAttr = mlir::dyn_cast<MmaEncodingTrait>(attr)) {
       os << "mma";
       return AliasResult::FinalAlias;
@@ -2524,6 +2525,11 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
       os << "slice";
       return AliasResult::FinalAlias;
     } */
+    // Memory space attributes
+    if (auto smem = mlir::dyn_cast<SharedMemorySpaceAttr>(attr)) {
+      os << "smem";
+      return AliasResult::FinalAlias;
+    }
     return OpAsmDialectInterface::getAlias(attr, os);
   }
 };
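With this hook, the shared memory space attribute gets an OpAsm alias, so printed IR declares it once per module as #smem instead of spelling #ttg.shared_memory at every use. A hedged sketch of the effect (the %t value and the #shared/#blocked attributes are illustrative):

// Before: the memory space is spelled out inline at each use:
//   ... -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory>
// After: the printer hoists it to a module-level alias:
#smem = #ttg.shared_memory
%0 = ttg.local_alloc %t : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem>

This is the printer change that the test updates further down in this commit are tracking.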

lib/Dialect/TritonGPU/IR/Types.cpp

Lines changed: 48 additions & 24 deletions

@@ -30,47 +30,54 @@ void TokenType::print(AsmPrinter &printer) const {
 static constexpr llvm::StringRef kMutableMemory = "mutable";
 
 Type MemDescType::parse(AsmParser &parser) {
-  if (parser.parseLess())
+  if (failed(parser.parseLess()))
     return Type();
 
-  SmallVector<int64_t> dimensions;
-  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false))
+  SmallVector<int64_t> dimensions; // required
+  if (failed(parser.parseDimensionList(dimensions, /*allowDynamic=*/false)))
     return Type();
 
-  // Parse the element type.
-  Type elementType;
-  if (parser.parseType(elementType))
+  Type elementType; // required
+  if (failed(parser.parseType(elementType)))
    return Type();
 
-  Attribute encoding;
-  if (succeeded(parser.parseOptionalComma())) {
-    if (parser.parseAttribute(encoding))
-      return Type();
-  }
-  bool mutableMemory = false;
-  Attribute memorySpace;
+  Attribute encoding; // required
+  if (failed(parser.parseComma()) || failed(parser.parseAttribute(encoding)))
+    return Type();
+
+  Attribute memorySpace; // required
+  if (failed(parser.parseComma()) || failed(parser.parseAttribute(memorySpace)))
+    return Type();
+
+  bool mutableMemory = false;      // optional
+  SmallVector<int64_t> allocShape; // optional
   if (succeeded(parser.parseOptionalComma())) {
-    if (failed(parser.parseOptionalKeyword(kMutableMemory))) {
-      if (parser.parseAttribute(memorySpace))
-        return Type();
-    } else {
+    if (succeeded(parser.parseOptionalKeyword(kMutableMemory))) {
       mutableMemory = true;
-    }
-  }
-  if (mutableMemory == false && succeeded(parser.parseOptionalComma())) {
-    if (parser.parseOptionalKeyword(kMutableMemory))
+      if (succeeded(parser.parseOptionalComma())) {
+        if (failed(parser.parseDimensionList(allocShape, /*allowDynamic=*/false,
+                                             /*withTrailingX=*/false))) {
+          return Type();
+        }
+      }
+    } else if (failed(parser.parseDimensionList(allocShape,
+                                                /*allowDynamic=*/false,
+                                                /*withTrailingX=*/false))) {
       return Type();
-    mutableMemory = true;
+    }
   }
+
   if (parser.parseGreater())
     return Type();
+
   return MemDescType::get(parser.getContext(), dimensions, elementType,
-                          encoding, memorySpace, mutableMemory);
+                          encoding, memorySpace, mutableMemory, dimensions);
 }
 
 void MemDescType::print(AsmPrinter &printer) const {
   printer << "<";
-  for (auto dim : getShape())
+  auto shape = getShape();
+  for (auto dim : shape)
     printer << dim << "x";
   printer << getElementType();
   if (getEncoding())
@@ -79,9 +86,26 @@ void MemDescType::print(AsmPrinter &printer) const {
   printer << ", " << getMemorySpace();
   if (getMutableMemory())
     printer << ", " << kMutableMemory;
+  auto allocShape = getAllocShape();
+  if (allocShape != shape) {
+    printer << ", " << allocShape[0];
+    for (auto dim : allocShape.drop_front(1)) {
+      printer << "x" << dim;
+    }
+  }
   printer << ">";
 }
 
+LogicalResult MemDescType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                  ArrayRef<int64_t> shape, Type elementType,
+                                  Attribute encoding, Attribute memorySpace,
+                                  bool mutableMemory,
+                                  ArrayRef<int64_t> allocShape) {
+  if (allocShape.size() < shape.size())
+    emitError() << "alloc shape must have at least as many dimensions as shape";
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Triton Dialect
//===----------------------------------------------------------------------===//
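For reference, these are the syntactic forms the reworked parser accepts; the encoding and memory space are now mandatory, while the mutable keyword and a trailing alloc shape stay optional (the #shared/#smem attribute names are illustrative):

!ttg.memdesc<32x32xf16, #shared, #smem>                      // immutable
!ttg.memdesc<32x32xf16, #shared, #smem, mutable>             // mutable
!ttg.memdesc<32x32xf16, #shared, #smem, 3x32x32>             // immutable, explicit alloc shape
!ttg.memdesc<32x32xf16, #shared, #smem, mutable, 3x32x32>    // mutable, explicit alloc shape

Note that the old parser treated the encoding and memory space as optional; those shorter forms are no longer accepted.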

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 19 additions & 16 deletions

@@ -144,7 +144,8 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
       triton::gpu::SharedMemorySpaceAttr::get(forOp.getContext());
   ttg::MemDescType subviewTy = ttg::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
-      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
+      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
+      /*allocShape=*/allocTy.getAllocShape());
   auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
       loc, stage, clusterId, subviewTy, alloc, copyOffsets);
   Operation *copy = builder.createWithStage<ttg::AsyncCopyGlobalToLocalOp>(
@@ -232,7 +233,8 @@ createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp,
   copyOffsets[0] = insertIdx;
   ttg::MemDescType subviewTy = ttg::MemDescType::get(
       allocTy.getShape().drop_front(), allocTy.getElementType(),
-      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true);
+      allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true,
+      /*allocShape=*/allocTy.getAllocShape());
   auto view = builder.createWithStage<ttg::MemDescSubviewOp>(
       loc, stage, clusterId, subviewTy, alloc, copyOffsets);
 
@@ -526,7 +528,7 @@ static Value createAlloc(scf::ForOp &forOp, Operation *loadOp,
   bufferShape.insert(bufferShape.begin(), distance);
   Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(),
                                            sharedEnc, sharedMemorySpace,
-                                           /*mutableMemory*/ true);
+                                           /*mutableMemory=*/true);
   Value alloc =
       builder.create<ttg::LocalAllocOp>(loadOp->getLoc(), memdescType, Value());
   return alloc;
@@ -544,12 +546,13 @@ static Value createBarrierAlloc(scf::ForOp &forOp, unsigned distance) {
                               /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
   auto barrierEncoding =
       ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout);
-  Type barrierMemDescType = ttg::MemDescType::get(
+  auto barrierMemDescType = ttg::MemDescType::get(
       {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
       /*mutableMemory=*/true);
-  Type singleBarrierMemDescType =
-      ttg::MemDescType::get({1}, builder.getI64Type(), barrierEncoding,
-                            sharedMemorySpace, /*mutableMemory=*/true);
+  Type singleBarrierMemDescType = ttg::MemDescType::get(
+      {1}, builder.getI64Type(), barrierEncoding, sharedMemorySpace,
+      /*mutableMemory=*/true,
+      /*allocShape=*/barrierMemDescType.getAllocShape());
   Value barrierAlloc =
       builder.create<ttg::LocalAllocOp>(loc, barrierMemDescType, Value());
   for (unsigned i = 0; i < distance; i++) {
@@ -650,11 +653,11 @@ static void createTMABarrierAndWait(
   OpBuilderWithStage builder(forOp);
   Attribute sharedMemorySpace =
       ttg::SharedMemorySpaceAttr::get(builder.getContext());
+  auto allocTy = cast<ttg::MemDescType>(barrierAlloc.getType());
   ttg::MemDescType barrierTy = ttg::MemDescType::get(
-      {1}, builder.getI64Type(),
-      cast<ttg::MemDescType>(barrierAlloc.getType()).getEncoding(),
-      sharedMemorySpace,
-      /*mutableMemory=*/true);
+      {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
+      /*mutableMemory=*/true,
+      /*allocShape=*/allocTy.getAllocShape());
   builder.setInsertionPoint(group[0]->loadOp);
   Value barrier = builder.createWithStage<ttg::MemDescSubviewOp>(
       loc, stage, cluster, barrierTy, barrierAlloc,
@@ -835,14 +838,14 @@ static void invalidateBarriers(OpBuilder &builder,
   Attribute sharedMemorySpace =
       ttg::SharedMemorySpaceAttr::get(builder.getContext());
   for (Value barrier : barriers) {
-    int numBarriers = cast<ttg::MemDescType>(barrier.getType()).getShape()[0];
+    auto allocTy = cast<ttg::MemDescType>(barrier.getType());
+    int numBarriers = allocTy.getShape()[0];
     for (int i = 0; i < numBarriers; i++) {
       Value idx = builder.create<arith::ConstantIntOp>(barrier.getLoc(), i, 32);
       ttg::MemDescType barrierTy = ttg::MemDescType::get(
-          {1}, builder.getI64Type(),
-          cast<ttg::MemDescType>(barrier.getType()).getEncoding(),
-          sharedMemorySpace,
-          /*mutableMemory=*/true);
+          {1}, builder.getI64Type(), allocTy.getEncoding(), sharedMemorySpace,
+          /*mutableMemory=*/true,
+          /*allocShape=*/allocTy.getShape());
       Value barrierView = builder.create<ttg::MemDescSubviewOp>(
           barrier.getLoc(), barrierTy, barrier, idx);
       builder.create<ttng::InvalBarrierOp>(barrier.getLoc(), barrierView);
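The net effect on pipelined IR is that each per-stage subview now remembers the multibuffered allocation it was carved from. A hypothetical sketch of the resulting IR (shapes, layouts, and value names are illustrative, not from this commit):

// 3-stage shared-memory buffer created by the pipeliner:
%alloc = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable>
// Per-stage view; the trailing 3x128x64 is the preserved alloc shape:
%view = ttg.memdesc_subview %alloc[%idx, %c0, %c0] :
    !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> ->
    !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 3x128x64>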

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

Lines changed: 3 additions & 2 deletions

@@ -136,8 +136,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
         builder.create<arith::ConstantIntOp>(v.getLoc(), off, 32));
   Value newSmem = builder.create<triton::gpu::MemDescSubviewOp>(
       v.getLoc(),
-      triton::gpu::MemDescType::get(shape, elementType, type.getEncoding(),
-                                    type.getMemorySpace()),
+      triton::gpu::MemDescType::get(
+          shape, elementType, type.getEncoding(), type.getMemorySpace(),
+          type.getMutableMemory(), type.getAllocShape()),
       v, offsetsVal);
 
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(

python/test/unit/language/test_core.py

Lines changed: 9 additions & 6 deletions

@@ -5378,20 +5378,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
     layouts = f"""
     #src = {src_layout}
     #dst = {dst_layout}
+    #smem = #ttg.shared_memory
     """ if interm_layout is None else f"""
     #src = {src_layout}
     #interm = {interm_layout}
     #dst = {dst_layout}
+    #smem = #ttg.shared_memory
     """
 
     conversion = f"""
     %12 = ttg.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
     %13 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
     """ if interm_layout is None else f"""
-    %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory>
-    %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xi32, #src>
-    %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory>
-    %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #ttg.shared_memory> -> tensor<{M}x{N}xf16, #src>
+    %15 = ttg.local_alloc %9 : (tensor<{M}x{N}xi32, #src>) -> !ttg.memdesc<{M}x{N}xi32, #interm, #smem>
+    %16 = ttg.local_load %15 : !ttg.memdesc<{M}x{N}xi32, #interm, #smem> -> tensor<{M}x{N}xi32, #src>
+    %17 = ttg.local_alloc %11 : (tensor<{M}x{N}xf16, #src>) -> !ttg.memdesc<{M}x{N}xf16, #interm, #smem>
+    %18 = ttg.local_load %17 : !ttg.memdesc<{M}x{N}xf16, #interm, #smem> -> tensor<{M}x{N}xf16, #src>
 
     %12 = ttg.convert_layout %16 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
     %13 = ttg.convert_layout %18 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
@@ -5455,6 +5457,7 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
     layouts = f"""
     #dist = {dist_layout}
     #shared = {shared_layout}
+    #smem = #ttg.shared_memory
     """
     ir = layouts + f"""
     module attributes {{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
@@ -5483,8 +5486,8 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
     %17 = tt.broadcast %15 : tensor<{M}x1x1xi32, #dist> -> tensor<{M}x{N}x{K}xi32, #dist>
     %18 = tt.addptr %16, %17 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>, tensor<{M}x{N}x{K}xi32, #dist>
     %19 = tt.load %18 : tensor<{M}x{N}x{K}x!tt.ptr<i32>, #dist>
-    %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory>
-    %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #ttg.shared_memory> -> tensor<{M}x{N}x{K}xi32, #dist>
+    %20 = ttg.local_alloc %19 : (tensor<{M}x{N}x{K}xi32, #dist>) -> !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem>
+    %21 = ttg.local_load %20 : !ttg.memdesc<{M}x{N}x{K}xi32, #shared, #smem> -> tensor<{M}x{N}x{K}xi32, #dist>
     %22 = tt.make_range {{end = {K} : i32, start = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>>
     %23 = tt.expand_dims %22 {{axis = 0 : i32}} : tensor<{K}xi32, #ttg.slice<{{dim = 0, parent = #ttg.slice<{{dim = 1, parent = #dist}}>}}>> -> tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>>
     %24 = tt.expand_dims %23 {{axis = 1 : i32}} : tensor<1x{K}xi32, #ttg.slice<{{dim = 1, parent = #dist}}>> -> tensor<1x1x{K}xi32, #dist>

test/Conversion/amd/compute-base-ptr.mlir

Lines changed: 3 additions & 2 deletions

@@ -3,14 +3,15 @@
 #blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
 #shared = #ttg.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 544 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @local_load_offset
   tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
     %0 = ttg.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
-    %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> loc(#loc2)
+    %1 = ttg.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !ttg.memdesc<16x16xf16, #shared, #smem> loc(#loc2)
     // This exercises the base-pointer calculation in computeBasePtr and checks that the GEP has the correct element type.
     // CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0
-    %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #ttg.shared_memory> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
+    %2 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #shared, #smem> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
     tt.return
   }
 }

test/Conversion/amd/decompose-unsupported-conversions.mlir

Lines changed: 4 additions & 2 deletions

@@ -5,10 +5,11 @@
 // CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}>
 // CHECK-LABEL: wmma_to_wmma_dot_op
 #mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1130", "ttg.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot_op(%arg0: tensor<16x16xf16, #mma>) {
     // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<16x16xf16, #[[$WMMA]]> -> tensor<16x16xf16, #[[$BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #ttg.shared_memory>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<16x16xf16, #[[$SHARED]], #smem>
     // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
@@ -22,10 +23,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // CHECK: #[[$SHARED:.+]] = #ttg.shared<{{.*}}>
 // CHECK-LABEL: wmma_to_wmma_dot3d_op
 #mma = #ttg.amd_wmma<{version = 1, warpsPerCTA = [2, 2, 2]}>
+#smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func @wmma_to_wmma_dot3d_op(%arg0: tensor<2x16x16xf16, #mma>) {
     // CHECK: %[[SRC_BLOCKED:.+]] = ttg.convert_layout %{{.*}} : tensor<2x16x16xf16, #[[$WMMA]]> -> tensor<2x16x16xf16, #[[$BLOCKED]]>
-    // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #ttg.shared_memory>
+    // CHECK-NEXT: %[[INT_SHARED:.+]] = ttg.local_alloc %[[SRC_BLOCKED]] : {{.*}} -> !ttg.memdesc<2x16x16xf16, #[[$SHARED]], #smem>
     // CHECK-NEXT: %[[DST_DOT_OP:.+]] = ttg.local_load %[[INT_SHARED]] : {{.*}} -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$WMMA]], kWidth = 16}>>
     %0 = ttg.convert_layout %arg0 : tensor<2x16x16xf16, #mma> -> tensor<2x16x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
     tt.return
