Reland "[Dialect] Layout attr cleanup and tighten invariants (#7714)" #4919

Merged · 4 commits · Aug 19, 2025
Changes from all commits
133 changes: 54 additions & 79 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td"

//===----------------------------------------------------------------------===//
// TritonGPU Attribute Definitions
// Traits and Interfaces
//===----------------------------------------------------------------------===//
def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> {
let cppNamespace = "::mlir::triton::gpu";

def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;

def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";
let description = [{
Common trait for all TTGIR layouts.
}];
let methods = [
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
"getCTAsPerCGA", (ins), [{}], [{
return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA());
}]>,
InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
"SmallVector<unsigned>",
"getCTAOrder", (ins), [{}], [{
return llvm::to_vector($_attr.getCTALayout().getCTAOrder());
}]>,
InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
"SmallVector<unsigned>",
"getCTASplitNum", (ins), [{}], [{
return llvm::to_vector($_attr.getCTALayout().getCTASplitNum());
}]>,
InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{
return $_attr.getCTAOrder().size();
}]>
];
}
def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods<
LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>;
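A minimal sketch of what the defaulted interface buys C++ callers (the helper below is illustrative, not part of this PR): any layout attribute can be queried uniformly, with the default methods forwarding to its CTALayoutAttr.

```cpp
// Sketch only: consuming the generated LayoutEncodingTrait interface.
// `numCTAsPerCGA` is an illustrative helper, not code from this PR.
#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

unsigned numCTAsPerCGA(mlir::Attribute layout) {
  // Any attribute implementing the interface works; the default methods
  // above read the attribute's CTALayoutAttr.
  auto enc = mlir::cast<mlir::triton::gpu::LayoutEncodingTrait>(layout);
  unsigned n = 1;
  for (unsigned ctas : enc.getCTAsPerCGA())
    n *= ctas;
  return n;
}
```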

def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">;
def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";

def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">;
let description = [{
Common trait describing shared memory.
}];
let methods = [
InterfaceMethod<"Return the default alignment for the layout.",
"int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>,
];
}
def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods<
SharedEncodingTrait, ["getAlignment"]>;
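The same pattern applies to shared encodings: getAlignment() defaults to 16, and attributes that need a different value opt into DeclareSharedEncodingMethods and override it. A hedged sketch of a caller (names illustrative):

```cpp
// Sketch only: querying shared-memory alignment through the interface.
#include "mlir/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

int32_t sharedAlignment(mlir::Attribute enc) {
  if (auto shared = mlir::dyn_cast<mlir::triton::gpu::SharedEncodingTrait>(enc))
    return shared.getAlignment(); // 16 unless the attribute overrides it
  return 0; // illustrative fallback for non-shared encodings
}
```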

//===----------------------------------------------------------------------===//
// Base Attribute
//===----------------------------------------------------------------------===//

class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [],
Dialect dialect = TritonGPU_Dialect,
string baseCppClass = "::mlir::Attribute">
: AttrDef<dialect, name, !listconcat([TritonGPU_AttrTrait], traits), baseCppClass> {
class TritonGPU_Attr<string name, string attrMnemonic, list<Trait> traits = [], Dialect dialect = TritonGPU_Dialect>
: AttrDef<dialect, name, traits> {

let description = [{
TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines
@@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
CTAOrder.push_back(i);
return get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
}
unsigned getRank() const {
return getCTAOrder().size();
}
unsigned getRank() const { return getCTAOrder().size(); }
}];

let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}
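A usage sketch of the canonicalizing builder above, assuming it is exposed as the usual CTALayoutAttr::getDefault helper:

```cpp
// Sketch only: the canonical single-CTA layout for a rank-2 tensor,
// assuming the builder above is exposed as CTALayoutAttr::getDefault.
#include <cassert>
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

void singleCTAExample(mlir::MLIRContext *ctx) {
  auto cta = mlir::triton::gpu::CTALayoutAttr::getDefault(ctx, /*rank=*/2);
  // CTAsPerCGA == CTASplitNum == [1, 1]; CTAOrder == [1, 0].
  assert(cta.getRank() == 2); // getRank() is getCTAOrder().size()
}
```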


def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";
let description = [{
Common trait for all TTGIR layouts.
}];
let methods = [
InterfaceMethod<"Get the shape of the CTAs per CGA.",
"SmallVector<unsigned>",
"getCTAsPerCGA">,
InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first",
"SmallVector<unsigned>",
"getCTAOrder">,
InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.",
"SmallVector<unsigned>",
"getCTASplitNum">,
];
}

//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//

def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";

let description = [{
Common trait describing shared memory.
}];
let methods = [
InterfaceMethod<"Return the default alignment for the layout.",
"int32_t",
"getAlignment">,
];
}

def SwizzledSharedEncodingAttr
: TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding",
[SharedEncodingTrait, LayoutEncodingTrait]> {
@@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
}]>,
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
unsigned getRank() const { return getCTAOrder().size(); }
int32_t getAlignment() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
}];
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}
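For orientation, one common reading of vec/perPhase/maxPhase (a paraphrase of the swizzling description above, not code from this PR): elements move in vec-sized groups, and each group's index is XOR-ed with a row-derived phase.

```cpp
// Sketch only: swizzle arithmetic under the stated reading of the
// parameters; this is not the implementation from the PR.
unsigned swizzledCol(unsigned row, unsigned col, unsigned vec,
                     unsigned perPhase, unsigned maxPhase) {
  unsigned phase = (row / perPhase) % maxPhase; // perPhase rows share a phase
  unsigned group = col / vec;                   // vec elements move together
  return (group ^ phase) * vec + col % vec;     // XOR whole groups
}
```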
@@ -433,27 +430,19 @@ attributes too, for example,
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
unsigned getRank() const { return getOrder().size(); }
int32_t getAlignment() const { return 16; }

unsigned getMinInterval() const {
return *llvm::min_element(getIntervals());
}

// Returns the total number of elements including padding given the input
// tensor shape.
int64_t getPaddedSize(ArrayRef<int64_t> shape) const;

SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
}];
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}
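getPaddedSize is only declared here; a hedged sketch of what it presumably computes, assuming `padding` extra elements are inserted after every `interval` elements (the real implementation lives in the .cpp, not shown in this diff):

```cpp
// Sketch only: a plausible reading of getPaddedSize, assuming `padding`
// extra elements are inserted after every `interval` elements.
#include <cstdint>
#include <utility>
#include "llvm/ADT/ArrayRef.h"

int64_t paddedSizeSketch(llvm::ArrayRef<int64_t> shape,
                         llvm::ArrayRef<std::pair<unsigned, unsigned>> pads) {
  int64_t elems = 1;
  for (int64_t d : shape)
    elems *= d;
  int64_t total = elems;
  for (auto [interval, padding] : pads)
    total += (elems / interval) * padding; // one pad block per full interval
  return total;
}
```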

def NVMMASharedEncodingAttr :
TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> {
def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> {
let mnemonic = "nvmma_shared";

let description = [{
@@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr :
];

let extraClassDeclaration = extraBaseClassDeclaration # [{
unsigned getRank() const { return getCTAOrder().size(); }
int32_t getAlignment() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
int getPerPhase() const;
int getMaxPhase() const;
int getVec() const;
@@ -619,20 +603,14 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1):
"CTALayoutAttr":$CTALayout
);

let extraClassDeclaration = extraBaseClassDeclaration # [{
unsigned getRank() const { return getCTAOrder().size(); }
int32_t getAlignment() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;
}];
let hasCustomAssemblyFormat = 1;
}


//===----------------------------------------------------------------------===//
// Distributed Layout Encoding
//===----------------------------------------------------------------------===//

def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> {
let cppNamespace = "::mlir::triton::gpu";

@@ -719,12 +697,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];

code extraDistributedDeclaration = extraBaseClassDeclaration # [{
unsigned getRank() const { return getCTAOrder().size(); }
// Implemented in subclasses
SmallVector<unsigned> getRepOrder() const;
SmallVector<unsigned> getCTAsPerCGA() const;
SmallVector<unsigned> getCTAOrder() const;
SmallVector<unsigned> getCTASplitNum() const;

LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
}];
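toLinearLayout is the uniform escape hatch for distributed encodings; a small sketch of calling the member declared above on a concrete encoding (the dump helper is illustrative):

```cpp
// Sketch only: flattening a concrete distributed encoding to a
// LinearLayout for a given tensor shape.
#include "llvm/Support/raw_ostream.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"

void dumpAsLinearLayout(mlir::triton::gpu::BlockedEncodingAttr enc,
                        llvm::ArrayRef<int64_t> shape) {
  mlir::triton::LinearLayout ll = enc.toLinearLayout(shape);
  llvm::errs() << ll << "\n"; // prints the layout's bases
}
```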
@@ -739,7 +713,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout",
let cppAccessorType = "const LinearLayout &";
}

def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> {
let mnemonic = "linear";

let description = [{
@@ -1376,7 +1350,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
let hasCustomAssemblyFormat = 1;
}

def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> {
let mnemonic = "slice";

let description = [{
@@ -1419,9 +1393,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> {
}];

let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
}

def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> {
def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> {
let mnemonic = "dot_op";

let description = [{
14 changes: 11 additions & 3 deletions include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h
@@ -27,11 +27,13 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"

// TritonNvidiaGPU depends on Triton
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc"

namespace mlir::triton::nvidia_gpu::impl {
@@ -61,13 +63,19 @@ struct TMemAllocation {

TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType);

Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
RankedTensorType oltType, unsigned numWarps);
gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N,
RankedTensorType oltType,
unsigned numWarps);
gpu::DistributedEncodingTrait
getTmemLoadLayoutSplitLongM(RankedTensorType tensorType,
gpu::MemDescType memType, int numWarps);
SmallVector<gpu::DistributedEncodingTrait>
getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType,
gpu::MemDescType memType);

bool isDistributedLayoutTMemCompatible(Operation *op,
RankedTensorType tensorType,
gpu::MemDescType memType);

bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType,
gpu::MemDescType memType,
int numWarps);
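These signature changes are the substance of the reland: returning gpu::DistributedEncodingTrait rather than a bare Attribute lets callers use the result without an unchecked cast. A hedged sketch, assuming the declarations live in mlir::triton::nvidia_gpu as the header suggests:

```cpp
// Sketch only: the tightened return type removes the cast that the old
// Attribute-returning signature forced on callers.
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

mlir::RankedTensorType retypeForTmem(unsigned M, unsigned N,
                                     mlir::RankedTensorType oltType,
                                     unsigned numWarps) {
  auto layout = mlir::triton::nvidia_gpu::getTmemCompatibleLayout(
      M, N, oltType, numWarps);
  // A DistributedEncodingTrait is still an Attribute, so it can serve as
  // the tensor encoding directly.
  return mlir::RankedTensorType::get(oltType.getShape(),
                                     oltType.getElementType(), layout);
}
```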