Skip to content

Commit 9e7dfc6

Browse files
authored
[Dialect] Verify local/tmem store/load/alloc reg shape and type matches mem shape and type (#7144)
1 parent b66b9ce commit 9e7dfc6

File tree

7 files changed

+41
-23
lines changed

7 files changed

+41
-23
lines changed

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,12 @@ bool isInnermostContiguous(MemDescType type, unsigned numElems);
283283
LinearLayout inferReshapeLinearLayout(ArrayRef<int64_t> srcShape,
284284
Attribute srcEnc,
285285
ArrayRef<int64_t> dstShape);
286+
287+
// Verify the types of operations that operate on memory.
288+
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
289+
ShapedType dstTy);
290+
// Verify a memory allocation operation.
291+
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy);
286292
} // namespace mlir::triton::gpu
287293

288294
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,6 @@ def TTG_LocalAllocOp : TTG_Op<"local_alloc", [DeclareOpInterfaceMethods<MemoryEf
163163
];
164164

165165
let extraClassDeclaration = [{
166-
static LogicalResult verifyAllocOp(Operation *op, Value src,
167-
MemDescType dstTy);
168166
bool isSharedMemoryAlloc() {
169167
return isa_and_nonnull<SharedMemorySpaceAttr>(getType().getMemorySpace());
170168
}
@@ -312,6 +310,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load"> {
312310
Arg<TTG_MemDescType, "", [MemRead<SharedMemory>]>:$src,
313311
Optional<TTG_AsyncToken>:$token
314312
);
313+
let results = (outs TT_Tensor:$result);
315314

316315
let builders = [
317316
OpBuilder<(ins "Type":$retType, "Value":$src),
@@ -321,8 +320,7 @@ def TTG_LocalLoadOp : TTG_Op<"local_load"> {
321320

322321
// Use qualified() otherwise "!ttg.memdesc<X>" is printed as "<X>".
323322
let assemblyFormat = [{$src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)}];
324-
325-
let results = (outs TT_Tensor:$result);
323+
let hasVerifier = 1;
326324
}
327325

328326
def TTG_LocalStoreOp : TTG_Op<"local_store"> {

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,22 @@ OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) {
529529
return loadSrc;
530530
}
531531

532-
LogicalResult LocalAllocOp::verifyAllocOp(Operation *op, Value src,
533-
MemDescType dstTy) {
532+
// Verify that the register (tensor) side of a memory operation agrees with
// the memory (memdesc) side in both element type and shape.
//
// \p srcTy is the source operand's type and \p dstTy the destination's; both
// are ShapedType so tensor and memdesc types can be compared uniformly.
// Returns failure() with a diagnostic on \p op when they disagree.
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
                                  ShapedType dstTy) {
  if (srcTy.getElementType() != dstTy.getElementType()) {
    // Print the element type (not the whole source type) so the diagnostic
    // matches its own label and is symmetric with the destination side.
    return op->emitOpError("source element type ")
           << srcTy.getElementType() << " must match destination element type "
           << dstTy.getElementType();
  }
  if (srcTy.getShape() != dstTy.getShape()) {
    // Bracket each shape individually: "source shape [..] must match
    // destination shape [..]" (the brackets were previously misnested).
    return op->emitOpError("source shape [")
           << srcTy.getShape() << "] must match destination shape ["
           << dstTy.getShape() << "]";
  }
  return success();
}
546+
547+
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy) {
534548
if (dstTy.getShape() != dstTy.getAllocShape())
535549
return op->emitOpError("result shape and its alloc shape must match");
536550

@@ -542,12 +556,7 @@ LogicalResult LocalAllocOp::verifyAllocOp(Operation *op, Value src,
542556
return success();
543557
}
544558

545-
auto srcTy = cast<RankedTensorType>(src.getType());
546-
if (srcTy.getElementType() != dstTy.getElementType())
547-
return op->emitOpError("result element type must source element type");
548-
if (srcTy.getShape() != dstTy.getShape())
549-
return op->emitOpError("result shape must match source shape");
550-
return success();
559+
return verifyMemoryOpTypes(op, cast<RankedTensorType>(src.getType()), dstTy);
551560
}
552561

553562
LogicalResult LocalAllocOp::verify() {
@@ -561,7 +570,12 @@ LogicalResult LocalAllocOp::verify() {
561570
// A local_store must target mutable shared memory, and the stored tensor must
// match the destination memdesc in shape and element type.
LogicalResult LocalStoreOp::verify() {
  MemDescType dstTy = getDst().getType();
  if (!dstTy.getMutableMemory())
    return emitOpError("Cannot store into immutable memory");
  return verifyMemoryOpTypes(*this, getSrc().getType(), dstTy);
}
575+
576+
// LocalLoadOp
577+
LogicalResult LocalLoadOp::verify() {
578+
return verifyMemoryOpTypes(*this, getSrc().getType(), getType());
565579
}
566580

567581
// AsyncCopyGlobalToLocalOp

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,8 @@ LogicalResult TMEMStoreOp::verify() {
455455
if (!getDst().getType().getMutableMemory()) {
456456
return emitOpError("Cannot store into an immutable alloc");
457457
}
458-
return success();
458+
return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(),
459+
getDst().getType());
459460
}
460461

461462
// -- TMEMLoadOp --
@@ -466,7 +467,7 @@ LogicalResult TMEMLoadOp::verify() {
466467
if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
467468
getSrc().getType().getEncoding()))
468469
return emitOpError("should use tensor memory encoding.");
469-
return success();
470+
return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), getType());
470471
}
471472

472473
// -- TMEMAllocOp --
@@ -476,8 +477,7 @@ LogicalResult TMEMAllocOp::verify() {
476477
if (!isa<TensorMemoryEncodingAttr, TensorMemoryScalesEncodingAttr>(
477478
getType().getEncoding()))
478479
return emitOpError("should use tensor memory encoding");
479-
480-
return LocalAllocOp::verifyAllocOp(*this, getSrc(), getType());
480+
return triton::gpu::verifyAllocOp(*this, getSrc(), getType());
481481
}
482482

483483
void TMEMAllocOp::getEffects(

test/Analysis/test-membar.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ tt.func @convert_layout5(%A : !tt.ptr<f16>) {
605605
// CHECK: ttg.local_load
606606
// CHECK-NEXT: gpu.barrier
607607
// CHECK: ttg.local_load
608-
%3 = ttg.local_load %0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL>
608+
%3 = ttg.local_load %0 : !ttg.memdesc<32x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<32x16xf16, #AL>
609609
%4 = ttg.local_load %1 : !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x16xf16, #AL>
610610
tt.return
611611
}

test/TritonGPU/optimize-partition-warps.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ tt.func @tmem_min_4_warps(%tensor_desc: !ttg.memdesc<64x64xf32, #tmem, #ttng.ten
148148
}
149149
// CHECK: partition1{{.*}} num_warps(4)
150150
partition1(%desc: !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>) num_warps(8) {
151-
%cst = arith.constant dense<0> : tensor<64x64xi32, #blocked2d_8>
151+
%cst = arith.constant dense<0.0> : tensor<64x64xf32, #blocked2d_8>
152152
%true = arith.constant true
153-
ttng.tmem_store %cst, %desc, %true : tensor<64x64xi32, #blocked2d_8> -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
153+
ttng.tmem_store %cst, %desc, %true : tensor<64x64xf32, #blocked2d_8> -> !ttg.memdesc<64x64xf32, #tmem, #ttng.tensor_memory, mutable>
154154
ttg.warp_return
155155
}
156156
// CHECK: partition2{{.*}} num_warps(4)

test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
1414
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
1515
%cst0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #blocked>
1616
%cst1 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #blocked>
17-
%cst2 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #blocked>
17+
%cst2 = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked>
1818
%cst3 = arith.constant dense<0> : tensor<64x4xi8, #linear>
1919
%cst4 = arith.constant dense<0.000000e+00> : tensor<64x128xf16, #blocked>
2020

@@ -39,8 +39,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
3939
// CHECK: ttng.tmem_alloc %{{.+}} {tensor_memory_col_offset = 128 : i32, tensor_memory_row_offset = 0 : i32}
4040
%6 = ttng.tmem_alloc %cst : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
4141

42-
ttng.tmem_store %cst2, %4, %true : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable>
43-
ttng.tmem_store %cst2, %5, %true : tensor<64x256xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable>
42+
ttng.tmem_store %cst2, %4, %true : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable>
43+
ttng.tmem_store %cst2, %5, %true : tensor<64x128xf16, #blocked> -> !ttg.memdesc<64x128xf16, #tmem2, #ttng.tensor_memory, mutable>
4444
ttng.tmem_store %cst, %6, %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
4545

4646
%7 = ttng.tmem_alloc : () -> !ttg.memdesc<64x4xi8, #tmem_scales, #ttng.tensor_memory, mutable>

0 commit comments

Comments
 (0)