Skip to content

Commit 4048f31

Browse files
authored
[BACKEND] Add LocalLoadTrait to group local load-like ops (#7511)
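This introduces a LocalLoadTrait op trait so that common code can recognize local load-like ops through the trait instead of by concrete op type. This is useful to avoid changing common code too much when adding support for the AMD LocalLoadPackedTransposedOp.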
1 parent 3854ae8 commit 4048f31

File tree

6 files changed

+15
-4
lines changed

6 files changed

+15
-4
lines changed

include/triton/Dialect/TritonGPU/IR/Traits.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ class MemDescViewTrait
1616
// Optional: Add methods or verification logic here
1717
};
1818

19+
// Marker op trait that groups "local load-like" operations. In this change it
// is attached to TTG_LocalLoadOp and to the AMD LocalLoadPackedTransposedOp,
// and shared code tests for it with `op->hasTrait<OpTrait::LocalLoadTrait>()`
// instead of naming each concrete op type.
template <typename ConcreteType>
20+
class LocalLoadTrait
21+
: public mlir::OpTrait::TraitBase<ConcreteType, LocalLoadTrait> {
22+
// Intentionally empty for now: the trait only tags the op. Methods or
// verification logic may be added here if they become necessary.
23+
};
24+
1925
} // namespace OpTrait
2026
} // namespace mlir
2127

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
1717

1818
// TableGen binding for the native C++ op trait class `MemDescViewTrait`.
def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
1919

20+
// TableGen binding for the native C++ op trait class `LocalLoadTrait`. List it
// in an op's trait list to mark that op as local load-like.
def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
2021

2122
class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
2223
Dialect dialect = TritonGPU_Dialect,

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
322322
let hasFolder = 1;
323323
}
324324

325-
def TTG_LocalLoadOp : TTG_Op<"local_load"> {
325+
def TTG_LocalLoadOp : TTG_Op<"local_load", [LocalLoadTrait]> {
326326
let summary = "Load a buffer from local memory into a distributed tensor";
327327

328328
let description = [{

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op,
173173
return op;
174174
if (isa<ttg::AsyncCommitGroupOp, ttg::AsyncWaitOp>(op))
175175
return op;
176-
if (isa<ttg::LocalLoadOp, ttg::LocalStoreOp>(op))
176+
if (op->hasTrait<OpTrait::LocalLoadTrait>())
177+
return op;
178+
if (isa<ttg::LocalStoreOp>(op))
177179
return op;
178180
if (isa<ttng::TMEMAllocOp, ttng::TMEMLoadOp>(op))
179181
return op;

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
3232
include "triton/Dialect/Triton/IR/TritonInterfaces.td"
3333
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
3434
include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td"
35+
include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
3536

3637
include "mlir/IR/EnumAttr.td"
3738
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -566,7 +567,7 @@ def InThreadTransposeOp : TT_AMDGPU_Op<"in_thread_transpose", [Pure]> {
566567
// LocalLoadPackedTransposedOp
567568
//===----------------------------------------------------------------------===//
568569

569-
def LocalLoadPackedTransposedOp : TT_AMDGPU_Op<"local_load_packed_tranposed"> {
570+
def LocalLoadPackedTransposedOp : TT_AMDGPU_Op<"local_load_packed_tranposed", [LocalLoadTrait]> {
570571
let summary = "Load a transposed packed tensor from shared memory into a distributed tensor";
571572
let description = [{
572573
Requires a M/N packed and M/N contiguous tensor in shared memory and will yield a K packed K contiguous tensor in registers.

third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,8 @@ getSharedEncIfAllUsersAreDotEnc(Value loadedValue) {
399399
if (!getSharedEncIfAllUsersAreDotEnc(userResult).has_value())
400400
return std::nullopt;
401401
} else {
402-
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
402+
if (!(isa<ttg::ConvertLayoutOp>(user) ||
403+
user->hasTrait<OpTrait::LocalLoadTrait>()))
403404
return std::nullopt;
404405

405406
auto srcTy = cast<ttg::TensorOrMemDesc>(loadedValue.getType());

0 commit comments

Comments
 (0)