diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp
index b70abe322b..6a73e7a8ad 100644
--- a/bin/triton-tensor-layout.cpp
+++ b/bin/triton-tensor-layout.cpp
@@ -22,7 +22,7 @@ using namespace mlir;
 // clang-format off
 // Example usage:
 //
-// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
+// triton-tensor-layout -l "#ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 32]}>" -t "tensor<128x256xf16>"
 //
 // triton-tensor-layout -i input.mlir -t "tensor<1x128x128xf16>" -o output.txt
 //
diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index ecd1e00c92..b8d8af3f71 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -489,7 +489,6 @@ using ::mlir::LLVM::delinearize;
 using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
 using ::mlir::triton::gpu::AMDWmmaEncodingAttr;
 using ::mlir::triton::gpu::BlockedEncodingAttr;
-using ::mlir::triton::gpu::CTALayoutAttr;
 using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
diff --git a/include/triton/Dialect/TritonGPU/IR/Attributes.h b/include/triton/Dialect/TritonGPU/IR/Attributes.h
index 1f93b3d935..77e3283a5a 100644
--- a/include/triton/Dialect/TritonGPU/IR/Attributes.h
+++ b/include/triton/Dialect/TritonGPU/IR/Attributes.h
@@ -2,6 +2,7 @@
 #define TRITON_DIALECT_TRITONGPU_IR_ATTRIBUTES_H_
 
 #include "mlir/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
index 6cca0ebf81..436bbdc830 100644
--- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
+++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt
@@ -15,11 +15,17 @@ set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td)
 mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
 mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
 mlir_tablegen(AttrDefs.h.inc -gen-attrdef-decls)
-mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
 mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
+
+set(LLVM_TARGET_DEFINITIONS TritonGPUAttrImpls.td)
+mlir_tablegen(AttrDefs.cpp.inc -gen-attrdef-defs)
 add_public_tablegen_target(TritonGPUAttrDefsIncGen)
 
+set(LLVM_TARGET_DEFINITIONS CTAEncodingAttr.td)
+mlir_tablegen(CTAEncodingAttr.h.inc -gen-attrdef-decls)
+add_public_tablegen_target(TritonGPUCTAAttrIncGen)
+
 set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
 mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
diff --git a/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h b/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h
new file mode 100644
index 0000000000..3ad60e8646
--- /dev/null
+++ b/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h
@@ -0,0 +1,11 @@
+#ifndef TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
+#define TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
+
+#include "mlir/IR/Attributes.h"
+#include "triton/Tools/LinearLayout.h"
+
+#define GET_ATTRDEF_CLASSES
+#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h.inc"
+#undef GET_ATTRDEF_CLASSES
+
+#endif // TRITON_DIALECT_TRITONGPU_IR_CTAENCODINGATTR_H_
diff --git a/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td b/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td
new file mode 100644
index 0000000000..7f159c01c8
--- /dev/null
+++ b/include/triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td
@@ -0,0 +1,40 @@
+//===----------------------------------------------------------------------===//
+// CTA encoding attribute definition emitted early to break interface cycles.
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITONGPU_CTAENCODING_ATTR_TD
+#define TRITONGPU_CTAENCODING_ATTR_TD
+
+include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td"
+
+//===----------------------------------------------------------------------===//
+// CTA Layout
+//===----------------------------------------------------------------------===//
+
+def CTAEncodingAttr : TritonGPU_Attr<"CTAEncoding", "cta_encoding"> {
+  let parameters = (ins LinearLayoutParam:$linearLayout);
+
+  let description = [{
+Describes how blocks (CTAs) in a cooperative grid array (CGA) map onto logical
+tensor dimensions. The `LinearLayout` maps from `block` into `dim0`, `dim1`...
+  }];
+
+  let extraClassDeclaration = [{
+    static CTAEncodingAttr getDefault(MLIRContext *context, int rank);
+    // Legacy, we should kill this! Note that it is not true in general that
+    // fromSplitParams(enc.getCTAsPerCGA(), enc.getCTASplitNum(), enc.getCTAOrder()) == enc!!
+    static CTAEncodingAttr fromSplitParams(MLIRContext *context,
+                                           ArrayRef<unsigned> CTAsPerCGA,
+                                           ArrayRef<unsigned> CTASplitNum,
+                                           ArrayRef<unsigned> CTAOrder);
+
+    unsigned getRank() const { return getLinearLayout().getNumOutDims(); }
+    SmallVector<unsigned> getCTAsPerCGA() const;
+    SmallVector<unsigned> getCTASplitNum() const;
+    SmallVector<unsigned> getCTAOrder() const;
+  }];
+
+  let genVerifyDecl = 1;
+}
+
+#endif // TRITONGPU_CTAENCODING_ATTR_TD
diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index a6b1bd7c06..a8f7b14c7a 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -210,7 +210,7 @@ inline SmallVector<unsigned> getThreadOrder(RankedTensorType type) {
                         type.getShape());
 }
 
-CTALayoutAttr getCTALayout(Attribute layout);
+CTAEncodingAttr getCTALayout(Attribute layout);
 
 SmallVector<unsigned> getCTAsPerCGA(Attribute layout);
 
diff --git a/include/triton/Dialect/TritonGPU/IR/LayoutUtility.h b/include/triton/Dialect/TritonGPU/IR/LayoutUtility.h
deleted file mode 100644
index a7b9129289..0000000000
--- a/include/triton/Dialect/TritonGPU/IR/LayoutUtility.h
+++ /dev/null
@@ -1,8 +0,0 @@
-#include
-#include
-
-namespace mlir::triton::gpu {
-
-CTALayoutAttr permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout,
-                               ArrayRef order);
-} // namespace mlir::triton::gpu
diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
index 70ccb03ac3..3086bfcdc5 100644
--- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
+++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h
@@ -17,7 +17,7 @@ class SwizzledSharedEncodingAttr;
 class NVMMASharedEncodingAttr;
 class TensorOrMemDesc;
 class MemDescType;
-class CTALayoutAttr;
+class CTAEncodingAttr;
 
 // - BlockedEncodingAttrs have the following input dimensions.
 //
@@ -77,9 +77,9 @@ LinearLayout getLayoutWithinBlock(const LinearLayout &layout);
 // given shape.
// // See the nomenclature note at the top of LinearLayoutConversions.cpp for why -// the variable with type CTALayoutAttr is called cgaLayoutAttr. +// the variable with type CTAEncodingAttr is called cgaLayoutAttr. LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, - CTALayoutAttr cgaLayoutAttr, + CTAEncodingAttr cgaLayoutAttr, ArrayRef shape); // In this function, we construct a linear layout representing the @@ -133,7 +133,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, ArrayRef shape, int opIdx, ArrayRef warpsPerCTA, - CTALayoutAttr ctaLayout); + CTAEncodingAttr ctaLayout); // Create LinearLayout for nvidia mma tile. LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef tileShape, @@ -149,16 +149,5 @@ std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType); // Create the core layout (atom in the PTX manual) a given nvmma shared encoding LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, bool disableSwizzle); - -// Make a LinearLayout that maps a block-id to an N-dimensional index. -// -// The tensor is split up into CTAsPerCGA pieces, which are distributed among -// the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). -// -// See the nomenclature note at the top of the LinearLayoutConversions.cpp file -// for an explanation of why this is called makeCgaLayout when it accepts a -// CTALayoutAttr. -LinearLayout makeCgaLayout(CTALayoutAttr layout); - } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td new file mode 100644 index 0000000000..fa0d582b7b --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td @@ -0,0 +1,54 @@ +//===----------------------------------------------------------------------===// +// Base definitions shared by TritonGPU attribute TableGen files. +// Splitting these out lets us emit certain attributes (e.g. CTAEncodingAttr) +// before interface headers without creating circular dependencies. +//===----------------------------------------------------------------------===// + +#ifndef TRITONGPU_ATTRBASE_TD +#define TRITONGPU_ATTRBASE_TD + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" + +// Traits used across several attrs. +def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; +def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; + +// Common parameter helpers. +def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", + "linear layout"> { + let cppAccessorType = "const LinearLayout &"; +} + +// Base class for all TritonGPU attributes. +class TritonGPU_Attr traits = []> + : AttrDef { + + let description = [{ +TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines +how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function +\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding +to the indices of the CUDA threads allowed to access some data at index $i$. 
+ +For example, let us consider the layout function: +\mathcal{L}(0, 0) = {0, 4} +\mathcal{L}(0, 1) = {1, 5} +\mathcal{L}(1, 0) = {2, 6} +\mathcal{L}(1, 1) = {3, 7} + +Then, attaching $\mathcal{L} to a tensor $T$ would mean that: +- T[0,0] is owned by both cuda thread 0 and 4 +- T[0,1] is owned by both cuda thread 1 and 5 +- T[1,0] is owned by both cuda thread 2 and 6 +- T[1,1] is owned by both cuda thread 3 and 7 + +Right now, Triton implements two main classes of layouts: shared, and distributed. + }]; + let attrName = "triton.gpu." # attrMnemonic; + + code extraBaseClassDeclaration = [{ + }]; +} + +#endif // TRITONGPU_ATTRBASE_TD diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 1fcb1ed754..7c97f172eb 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1,45 +1,28 @@ #ifndef TRITONGPU_ATTRDEFS #define TRITONGPU_ATTRDEFS -include "mlir/IR/AttrTypeBase.td" -include "triton/Dialect/Triton/IR/TritonInterfaces.td" -include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrBase.td" //===----------------------------------------------------------------------===// // Traits, Interfaces and shared Parameters //===----------------------------------------------------------------------===// -def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; -def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; - def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { let cppNamespace = "::mlir::triton::gpu"; let description = [{ Common trait for all TTGIR layouts. }]; let methods = [ - InterfaceMethod<"Get the shape of the CTAs per CGA.", - "SmallVector", - "getCTAsPerCGA", (ins), [{}], [{ - return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA()); - }]>, - InterfaceMethod<"Get the order of the CTAs per CGA. 
The fastest-changing axis first", - "SmallVector", - "getCTAOrder", (ins), [{}], [{ - return llvm::to_vector($_attr.getCTALayout().getCTAOrder()); - }]>, - InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", - "SmallVector", - "getCTASplitNum", (ins), [{}], [{ - return llvm::to_vector($_attr.getCTALayout().getCTASplitNum()); - }]>, - InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{ - return $_attr.getCTAOrder().size(); + InterfaceMethod<"Get the CTA layout backing this encoding.", + "CTAEncodingAttr", "getCTALayout">, + InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", + (ins), [{}], [{ + return $_attr.getCTALayout().getRank(); }]> ]; } def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods< - LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>; + LayoutEncodingTrait, ["getCTALayout"]>; def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { let cppNamespace = "::mlir::triton::gpu"; @@ -55,131 +38,14 @@ def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods< SharedEncodingTrait, ["getAlignment"]>; -def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", - "linear layout"> { - let cppAccessorType = "const LinearLayout &"; -} - -//===----------------------------------------------------------------------===// -// Base Attribute -//===----------------------------------------------------------------------===// - -class TritonGPU_Attr traits = []> - : AttrDef { - - let description = [{ -TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines -how the data should be partitioned across CUDA threads. Formally speaking, we define a layout as a function -\mathcal{L} that maps a multi-dimensional tensor index $i \in \mathbb{Z}^d$ to a set of integers T corresponding -to the indices of the CUDA threads allowed to access some data at index $i$. - -For example, let us consider the layout function: -\mathcal{L}(0, 0) = {0, 4} -\mathcal{L}(0, 1) = {1, 5} -\mathcal{L}(1, 0) = {2, 6} -\mathcal{L}(1, 1) = {3, 7} - -Then, attaching $\mathcal{L} to a tensor $T$ would mean that: -- T[0,0] is owned by both cuda thread 0 and 4 -- T[0,1] is owned by both cuda thread 1 and 5 -- T[1,0] is owned by both cuda thread 2 and 6 -- T[1,1] is owned by both cuda thread 3 and 7 - -Right now, Triton implements two main classes of layouts: shared, and distributed. - }]; - let attrName = "triton.gpu." # attrMnemonic; - - code extraBaseClassDeclaration = [{ - }]; -} - -//===----------------------------------------------------------------------===// -// CTA Layout -//===----------------------------------------------------------------------===// - -def CTALayoutAttr : TritonGPU_Attr<"CTALayout", "cta_layout"> { - let parameters = ( - ins - ArrayRefParameter<"unsigned">:$CTAsPerCGA, - ArrayRefParameter<"unsigned">:$CTASplitNum, - ArrayRefParameter<"unsigned">:$CTAOrder - ); - - let description = [{ -Describes how blocks are distributed among the cooperate thread arrays (aka -CTAs, aka thread blocks) in a cooperate thread group (aka CTG, aka thread group -cluster). CGAs were introduced in Hopper (sm90). - -The tensor is divided up into CTASplitNum pieces, which are distributed among -the CTAsPerCGA thread blocks. Each CTA processes a subtensor of shape -`tensor_shape / CTASplitNum`. - -Example 0: The tensor shape is [64, 128] and, there are two CTAs, each -processing half the tensor [64, 64]. 
Then CTAsPerCGA = [1, 2] and -CTASplitNum = [1, 2]. - -Example 1: The tensor shape is [64, 128] and, there are two CTAs, both -processing the complete tensor [64, 128]. This happens when multicast is -enabled. In this case, CTAsPerCTA = [1, 2] but CTASplitNum = [1, 1]. - -Example 2: Consider a matmul AxB=C, where A=[M,K], B=[K,N], C=[M,N]. The -CTAsPerCGA for A, B, C are the same, [SplitM, SplitN], but the CTASplitNum are -different. CTASplitNum_A = [SplitM, 1], which means multicast on dim1, -CTASplitNum_B = [1, SplitN], which means multicast on dim0, CTASplitNum_C = -[SplitM, SplitN] which means no multicast. - -Currently programs with multiple CTAs per CGA are an experimental feature in -Triton, not enabled by default. - -You can leave off the CTALayout properties in the textual IR and Triton will -fill in the "default" CTALayout of CTAsPerCGA = CTASplitNum = [1...1]. In -addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to -[n-1,...,0] (it doesn't matter in this case). - }]; - - // CTALayout::get canonicalizes CTAOrder to [n,n-1,...,0] if CTAsPerCGA is - // [1...1]. The CTAOrder doesn't matter in this case. - // - // This is a little weird because if you write textual IR with a one order and - // then print it back out, you might get a different order. But it seems this - // is the best way to canonicalize an attribute in MLIR. - let builders = [ - AttrBuilder<(ins "ArrayRef":$CTAsPerCGA, - "ArrayRef":$CTASplitNum, - "ArrayRef":$CTAOrder), [{ - if (llvm::all_of(CTAsPerCGA, [](unsigned x) { return x == 1; })) { - SmallVector order; - for (int i = CTAsPerCGA.size() - 1; i >= 0; --i) - order.push_back(i); - return $_get(context, CTAsPerCGA, CTASplitNum, order); - } - return $_get(context, CTAsPerCGA, CTASplitNum, CTAOrder); - }]>, - ]; - - let extraClassDeclaration = [{ - static CTALayoutAttr getDefault(MLIRContext *context, int rank) { - SmallVector CTAsPerCGA(rank, 1); - SmallVector CTASplitNum(rank, 1); - SmallVector CTAOrder; - for (int i = rank - 1; i >= 0; --i) - CTAOrder.push_back(i); - return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); - } - unsigned getRank() const { return getCTAOrder().size(); } - }]; - - let genVerifyDecl = 1; - let skipDefaultBuilders = 1; -} - //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// def SwizzledSharedEncodingAttr : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", - [SharedEncodingTrait, LayoutEncodingTrait]> { + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { let mnemonic = "swizzled_shared"; let description = [{ @@ -265,14 +131,14 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at "unsigned":$perPhase, "unsigned":$maxPhase, ArrayRefParameter<"unsigned">:$order, - "CTALayoutAttr":$CTALayout + "CTAEncodingAttr":$CTALayout ); let builders = [ AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ bool needTrans = false; // default value return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); @@ -284,7 +150,7 @@ When vec=2, elements are swizzled in pairs of 2. 
In other words, the element at AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "unsigned":$typeWidthInBit, "bool":$needTrans), [{ @@ -323,7 +189,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at "unsigned":$kWidth, "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "unsigned":$bitwidth, "bool":$needTrans), [{ int K = getShapePerCTA(CTALayout.getCTASplitNum(), shape)[order[0]]; @@ -351,7 +217,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "Type":$eltTy), [{ unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); @@ -360,7 +226,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "Type":$eltTy, "bool":$needTrans), [{ unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); @@ -465,7 +331,7 @@ For identity mappings a short form based on order and shape is used to increase // Builder to create an identity mapping as the linear component AttrBuilder<(ins "ArrayRef>":$intervalPads, "ArrayRef":$order, "ArrayRef":$shape, - "CTALayoutAttr":$ctaLayout)>, + "CTAEncodingAttr":$ctaLayout)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -495,7 +361,8 @@ For identity mappings a short form based on order and shape is used to increase def SharedLinearEncodingAttr : TritonGPU_Attr<"SharedLinearEncoding", "shared_linear_encoding", - [SharedEncodingTrait, DeclareLayoutEncodingMethods]> { + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { let mnemonic = "shared_linear"; let description = [{ @@ -523,7 +390,9 @@ def SharedLinearEncodingAttr let hasCustomAssemblyFormat = 1; } -def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> { +def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", + [DeclareSharedEncodingMethods, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { let mnemonic = "nvmma_shared"; let description = [{ @@ -544,13 +413,13 @@ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_share "bool":$transposed, "unsigned":$elementBitWidth, "bool":$fp4Padded, - "CTALayoutAttr":$CTALayout + "CTAEncodingAttr":$CTALayout ); let builders = [ AttrBuilder<(ins "ArrayRef":$shape, "ArrayRef":$order, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "Type":$eltTy, "bool": $fp4Padded), [{ auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape); @@ -592,7 +461,8 @@ def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_share def AMDRotatingSharedEncodingAttr : TritonGPU_Attr<"AMDRotatingSharedEncoding", "amd_rotating_shared_encoding", - [SharedEncodingTrait, LayoutEncodingTrait]> { + [SharedEncodingTrait, LayoutEncodingTrait, + DeclareLayoutEncodingMethods]> { let mnemonic = "amd_rotating_shared"; let description = [{ @@ -681,7 +551,7 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. 
columns*rows-1): "unsigned":$perPhase, "unsigned":$maxPhase, ArrayRefParameter<"unsigned">:$order, - "CTALayoutAttr":$CTALayout + "CTAEncodingAttr":$CTALayout ); let hasCustomAssemblyFormat = 1; @@ -741,7 +611,10 @@ We call each individual tile "rep". } class DistributedEncoding traits = []> - : TritonGPU_Attr { + : TritonGPU_Attr { let description = [{ Distributed encodings have a layout function L that is entirely characterized @@ -788,7 +661,7 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, // Linear Layout Encoding //===----------------------------------------------------------------------===// -def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> { +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { let mnemonic = "linear"; let description = [{ @@ -867,9 +740,7 @@ for #ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} - warpsPerCTA = {1, 2} - CTAsPerCGA = {1, 1} - CTASplitNum = {1, 1} + blocked = {{0, 1}} }> Example 2, a row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) as follows: @@ -893,9 +764,7 @@ for #ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} - warpsPerCTA = {1, 2} - CTAsPerCGA = {1, 1} - CTASplitNum = {1, 1} + blocked = {{0, 1}} }> Example 3, A row-major coalesced layout may partition a 32x32 tensor over 2 warps (i.e. 64 threads) and @@ -923,9 +792,7 @@ for #ttg.blocked_layout<{ sizePerThread = {2, 2} threadsPerWarp = {8, 4} - warpsPerCTA = {1, 2} - CTAsPerCGA = {2, 2} - CTASplitNum = {2, 2} + blocked = {{0, 1}, {1, 0}} }> }]; @@ -937,9 +804,8 @@ for ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first // CTALayout is optional in the textual IR. If omitted, we infer it to be a - // single CTA (so CTAsPerCGA = [1,...,1], CTASplitNum = [1,...,1], - // CTAOrder=[n,n-1,...,0]). - "CTALayoutAttr":$CTALayout + // single CTA (i.e. the trivial map onto dim0..dimn-1) + "CTAEncodingAttr":$CTALayout ); let genVerifyDecl = 1; @@ -949,7 +815,7 @@ for "ArrayRef":$order, "unsigned":$numWarps, "unsigned":$numThreadsPerWarp, - "CTALayoutAttr":$CTALayout), [{ + "CTAEncodingAttr":$CTALayout), [{ unsigned rank = sizePerThread.size(); SmallVector threadsPerWarp(rank); SmallVector warpsPerCTA(rank); @@ -1004,7 +870,7 @@ for CTAsPerCGA[rank - 1] *= remainingCTAs; // wrap at CTA level - CTALayoutAttr CTALayout = CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + CTAEncodingAttr CTALayout = CTAEncodingAttr::fromSplitParams(context, CTAsPerCGA, CTASplitNum, CTAOrder); return get(context, shape, sizePerThread, order, numWarps, numThreadsPerWarp, CTALayout); }]> ]; @@ -1158,7 +1024,7 @@ w2 w2 w3 w3 ArrayRefParameter<"unsigned">:$warpsPerCTA, ArrayRefParameter<"unsigned">:$instrShape, "bool":$isTransposed, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, ArrayRefParameter<"unsigned">:$tilesPerWarp, "unsigned":$elementBitWidth ); @@ -1168,7 +1034,7 @@ w2 w2 w3 w3 "ArrayRef":$warpsPerCTA, "ArrayRef":$instrShape, "bool":$isTransposed, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, CArg<"ArrayRef", "{}">:$tpw, CArg<"unsigned", "0">:$elementBitWidth), [{ SmallVector tilesPerWarp(tpw); @@ -1191,7 +1057,7 @@ w2 w2 w3 w3 // Returns a swizzled shared layout matching this MFMA layout for the // dot operand at the given |operandIdx| with |operandShape|. 
SwizzledSharedEncodingAttr composeSharedLayoutForOperand( - CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + CTAEncodingAttr ctaLayout, int operandIdx, ArrayRef operandShape, ArrayRef sharedOrder, unsigned vectorSize, unsigned elemBitWidth, bool needTrans) const; }]; @@ -1328,7 +1194,7 @@ w2 w2 w3 w3 "bool":$isTransposed, ArrayRefParameter<"unsigned">:$warpsPerCTA, ArrayRefParameter<"unsigned">:$tilesPerWarp, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, ArrayRefParameter<"unsigned">:$instrShape ); @@ -1339,7 +1205,7 @@ w2 w2 w3 w3 AttrBuilder<(ins "unsigned":$version, "bool":$isTransposed, "ArrayRef":$warpsPerCTA, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, "ArrayRef":$instrShape), [{ SmallVector tilesPerWarp(warpsPerCTA.size(), 1); return $_get(context, version, isTransposed, warpsPerCTA, tilesPerWarp, CTALayout, instrShape); @@ -1360,7 +1226,7 @@ w2 w2 w3 w3 // Returns a swizzled shared layout matching this WMMA layout for the // dot operand at the given |operandIdx| with |operandShape|. SwizzledSharedEncodingAttr composeSharedLayoutForOperand( - CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + CTAEncodingAttr ctaLayout, int operandIdx, ArrayRef operandShape, ArrayRef sharedOrder, unsigned kWidth, unsigned elemBitWidth, bool needTrans) const; }]; @@ -1455,7 +1321,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: "unsigned":$versionMajor, "unsigned":$versionMinor, ArrayRefParameter<"unsigned">:$warpsPerCTA, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, ArrayRefParameter<"unsigned">:$instrShape ); @@ -1475,7 +1341,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: let hasCustomAssemblyFormat = 1; } -def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> { +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { let mnemonic = "slice"; let description = [{ @@ -1521,7 +1387,7 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [ let genVerifyDecl = 1; } -def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> { +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { let mnemonic = "dot_op"; let description = [{ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td new file mode 100644 index 0000000000..8138b8df0a --- /dev/null +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrImpls.td @@ -0,0 +1,13 @@ +//===----------------------------------------------------------------------===// +// Aggregated attr definitions (including CTA) for implementation emission. +// This file exists to generate AttrDefs.cpp.inc once, without duplicating +// CTAEncodingAttr while still making CTA available before LayoutEncodingTrait. 
+//===----------------------------------------------------------------------===// + +#ifndef TRITONGPU_ATTRIMPLS_TD +#define TRITONGPU_ATTRIMPLS_TD + +include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.td" +include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td" + +#endif // TRITONGPU_ATTRIMPLS_TD diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h index 4d368100bc..32d8ff94dc 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h @@ -2,11 +2,12 @@ #define TRITON_GPU_DIALECT_INTERFACES_H #include "mlir/IR/OpDefinition.h" +#include "triton/Dialect/TritonGPU/IR/CTAEncodingAttr.h" // clang-format off #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" -#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc" #include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc" +#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc" // clang-format on #endif // TRITON_GPU_DIALECT_INTERFACES_H diff --git a/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h b/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h index 2dd66e6c8f..c79f44f747 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h +++ b/include/triton/Dialect/TritonGPU/Transforms/CoalesceUtils.h @@ -4,12 +4,13 @@ #include "mlir/Support/LLVM.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace mlir::triton::gpu { BlockedEncodingAttr buildCoalescedEncoding( MLIRContext *context, ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, int numWarps, int threadsPerWarp, - triton::gpu::CTALayoutAttr CTALayout, SmallVector shapePerCTA); + triton::gpu::CTAEncodingAttr CTALayout, SmallVector shapePerCTA); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_COALESCINGUTILS_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index 72ea164e20..fa1bec63ad 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -130,12 +130,12 @@ bool isDistributedLayoutTMemCompatible(Operation *op, gpu::DistributedEncodingTrait getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, - gpu::CTALayoutAttr ctaLayout); + gpu::CTAEncodingAttr ctaLayout); std::optional getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, unsigned numWarps, - gpu::CTALayoutAttr ctaLayout); + gpu::CTAEncodingAttr ctaLayout); } // namespace mlir::triton::nvidia_gpu diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h index da26061590..2dace4fd9c 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h @@ -20,8 +20,8 @@ SmallVector translateTMAIndices(OpBuilder &builder, Location loc, Attribute encoding, SmallVector indices); -gpu::CTALayoutAttr updateCTALayoutForShape(gpu::CTALayoutAttr ctaLayout, - ArrayRef shape); +gpu::CTAEncodingAttr updateCTALayoutForShape(gpu::CTAEncodingAttr ctaLayout, + ArrayRef shape); gpu::SharedEncodingTrait updateEncodingForShape(Operation *op, gpu::SharedEncodingTrait encoding, diff --git a/include/triton/Tools/LayoutUtils.h b/include/triton/Tools/LayoutUtils.h index 15fc661f90..7ea612fb02 100644 --- 
a/include/triton/Tools/LayoutUtils.h +++ b/include/triton/Tools/LayoutUtils.h @@ -182,6 +182,9 @@ largestVectorisation(MLIRContext *ctx, const LinearLayout &cvt, int bitwidth, std::optional getReps(const LinearLayout &cvt, const LinearLayout &tile); +// Given a layout mapping onto dim0..dimn, remove a dimension `dim` +// and rename the rest as dim0..dimn-1 +LinearLayout removeStandardDim(const LinearLayout &layout, int dim); } // namespace mlir::triton #endif // TRITON_TOOLS_LAYOUTUTILS_H diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index 77e3a80dba..5e9eb94661 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -6,7 +6,6 @@ #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" namespace { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8a8c1c3d3e..727393dbdc 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -6,7 +6,6 @@ #include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/GenericSwizzling.h" diff --git a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp index 3584878121..75ef836873 100644 --- a/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp +++ b/lib/Conversion/TritonInstrumentToLLVM/InstrumentationToLLVM.cpp @@ -27,7 +27,7 @@ Value createMemDescToI64(RewriterBase &rewriter, Location loc, const LLVMTypeConverter *typeConverter, ttg::MemDescType memDescTy, Value sharedMemStruct) { TritonLLVMOpBuilder b(loc, rewriter); - if (isa(memDescTy.getEncoding())) { + if (isa(memDescTy.getMemorySpace())) { return b.ptrtoint(rewriter.getIntegerType(64), sharedMemStruct); } assert(isa(memDescTy.getEncoding()) && diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp index 0c37a52d96..9b4e0ccea9 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp @@ -144,7 +144,7 @@ static RankedTensorType getNewIndicesType(RankedTensorType type, std::array warpsPerCta = {1, numWarps}; MLIRContext *ctx = type.getContext(); - auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/2); + auto ctaLayout = CTAEncodingAttr::getDefault(ctx, /*rank=*/2); auto parentEncoding = BlockedEncodingAttr::get( ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, ctaLayout); auto newEncoding = SliceEncodingAttr::get(ctx, /*dim=*/0, parentEncoding); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 2e8ed7ffa4..fe6a76e92f 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -9,6 +9,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include 
"triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" namespace mlir::triton { #define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPU @@ -146,6 +147,7 @@ struct TritonExpandDimsPattern // return shape auto retShape = argType.getShape().vec(); retShape.insert(retShape.begin() + op.getAxis(), 1); + auto newRank = retShape.size(); // return encoding auto retSizePerThread = llvm::to_vector(argEncoding.getSizePerThread()); retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1); @@ -156,14 +158,18 @@ struct TritonExpandDimsPattern SmallVector retOrder(retShape.size()); std::iota(retOrder.begin(), retOrder.end(), 0); - auto argCTALayout = argEncoding.getCTALayout(); - auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis()); - auto retCTASplitNum = - insertOne(argCTALayout.getCTASplitNum(), op.getAxis()); - auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis()); - auto retCTALayout = triton::gpu::CTALayoutAttr::get( - getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder); - + auto ctaLl = argEncoding.getCTALayout().getLinearLayout(); + auto kBlock = *ctaLl.getInDimNames().begin(); + auto *ctx = kBlock.getContext(); + auto newDim = standardOutDimNames(ctx, newRank)[newRank - 1]; + ctaLl *= LinearLayout::identity1D(1, kBlock, newDim); + // Move last dim to op.getAxis(). nb is this a std::rotate? + auto newOrder = to_vector(llvm::seq(newRank)); + for (int i = newRank - 1; i >= op.getAxis() + 1; --i) { + std::swap(newOrder[i], newOrder[i - 1]); + } + ctaLl = transposeLinearLayout(ctaLl, newOrder); + auto retCTALayout = CTAEncodingAttr::get(ctx, std::move(ctaLl)); triton::gpu::BlockedEncodingAttr retEncoding = triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread, retThreadsPerWarp, retWarpsPerCTA, @@ -374,15 +380,16 @@ struct TritonSplitOpPattern : public OpConversionPattern { return res; }; + auto layout = defaultEnc.getCTALayout().getLinearLayout(); + auto kBlock = StringAttr::get(getContext(), "block"); + auto newDim = standardOutDimNames(getContext(), rank)[rank - 1]; + layout *= LinearLayout::identity1D(1, kBlock, newDim); srcEnc = BlockedEncodingAttr::get( getContext(), append(defaultEnc.getSizePerThread(), 2), append(defaultEnc.getThreadsPerWarp(), 1), append(defaultEnc.getWarpsPerCTA(), 1), prepend(defaultEnc.getOrder(), rank - 1), - CTALayoutAttr::get(getContext(), - append(defaultEnc.getCTAsPerCGA(), 1), - append(defaultEnc.getCTASplitNum(), 1), - prepend(defaultEnc.getCTAOrder(), rank - 1))); + CTAEncodingAttr::get(getContext(), layout)); srcTy = srcTy.cloneWithEncoding(srcEnc); src = ConvertLayoutOp::create(rewriter, op.getLoc(), srcTy, src); } diff --git a/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp b/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp index 346c4a3d9d..d736b676e1 100644 --- a/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp +++ b/lib/Dialect/Gluon/Transforms/InferCoalescedEncodings.cpp @@ -25,12 +25,12 @@ namespace mlir::triton::gluon { namespace { -ttg::CTALayoutAttr getDefaultCTALayout(RankedTensorType refTensorType, - int numCTAs) { +ttg::CTAEncodingAttr getDefaultCTALayout(RankedTensorType refTensorType, + int numCTAs) { // TODO support numCTAs > 1 assert(numCTAs == 1 && "only numCTAs == 1 is supported for now"); - return ttg::CTALayoutAttr::getDefault(refTensorType.getContext(), - refTensorType.getShape().size()); + return 
ttg::CTAEncodingAttr::getDefault(refTensorType.getContext(), + refTensorType.getShape().size()); } bool isCoalescedEncodingTensorType(Type ty) { diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index f6783f1bc3..79648359f6 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,11 +1,11 @@ add_triton_library(TritonGPUIR Dialect.cpp LinearLayoutConversions.cpp - LayoutUtility.cpp Ops.cpp Types.cpp DEPENDS + TritonGPUCTAAttrIncGen TritonGPUTableGen TritonGPUAttrDefsIncGen TritonGPUTypeInterfacesIncGen diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 47fbd2e022..48b6c62972 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/DialectImplementation.h" @@ -16,7 +17,6 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonGPU/IR/Types.h" @@ -39,6 +39,10 @@ using namespace mlir; using namespace mlir::triton; using namespace mlir::triton::gpu; +static SmallVector +basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName, + size_t rank, bool skipBroadcast = true); + // Utility namespace mlir { namespace triton { @@ -244,26 +248,23 @@ SmallVector getWarpOrder(DistributedEncodingTrait layout, return toLinearEncoding(layout, shape).getWarpOrder(); } -CTALayoutAttr getCTALayout(Attribute layout) { - if (auto ttgLayout = mlir::dyn_cast(layout)) { - return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(ttgLayout), - getCTASplitNum(ttgLayout), - getCTAOrder(ttgLayout)); - } +CTAEncodingAttr getCTALayout(Attribute layout) { + if (auto ttgLayout = mlir::dyn_cast(layout)) + return ttgLayout.getCTALayout(); llvm::report_fatal_error("Unimplemented usage of getCTALayout"); return {}; } SmallVector getCTAsPerCGA(Attribute layout) { if (auto ttgLayout = mlir::dyn_cast(layout)) - return ttgLayout.getCTAsPerCGA(); + return ttgLayout.getCTALayout().getCTAsPerCGA(); llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); } SmallVector getCTASplitNum(Attribute layout) { SmallVector res; if (auto ttgLayout = mlir::dyn_cast(layout)) { - return ttgLayout.getCTASplitNum(); + return ttgLayout.getCTALayout().getCTASplitNum(); } else if (auto tmemLayout = mlir::dyn_cast( layout)) { @@ -284,7 +285,7 @@ SmallVector getCTASplitNum(Attribute layout) { SmallVector getCTAOrder(Attribute layout) { SmallVector res; if (auto ttgLayout = mlir::dyn_cast(layout)) { - res = ttgLayout.getCTAOrder(); + res = ttgLayout.getCTALayout().getCTAOrder(); } else { llvm::report_fatal_error("Unimplemented usage of getCTAOrder"); } @@ -386,36 +387,91 @@ verifyLayoutOrder(function_ref emitError, return success(); } -LogicalResult CTALayoutAttr::verify( - function_ref emitError, ArrayRef CTAsPerCGA, - ArrayRef CTASplitNum, ArrayRef CTAOrder) { - if (!llvm::all_equal( - {CTAsPerCGA.size(), CTASplitNum.size(), CTAOrder.size()})) { - return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have " - "the same rank."; +LogicalResult +CTAEncodingAttr::verify(function_ref emitError, + LinearLayout linearLayout) { + if 
(linearLayout.getNumInDims() != 1) {
+    return emitError() << "CTA encoding must have exactly one input dimension " "named 'block'.";
   }
-
-  if (failed(verifyLayoutOrder(emitError, CTAOrder)))
-    return failure();
-
-  if (llvm::any_of(CTAsPerCGA, [](unsigned x) { return x == 0; })) {
-    return emitError() << "Every element in CTAsPerCGA must be greater than 0.";
+  auto dim = *linearLayout.getInDimNames().begin();
+  auto ctx = dim.getContext();
+  if (dim != StringAttr::get(ctx, "block")) {
+    return emitError() << "CTA encoding must have exactly one input dimension " "named 'block'.";
   }
-  if (llvm::any_of(CTASplitNum, [](unsigned x) { return x == 0; })) {
-    return emitError()
-           << "Every element in CTASplitNum must be greater than 0.";
+  auto outDimNames = linearLayout.getOutDimNames();
+  auto expected = standardOutDimNames(ctx, linearLayout.getNumOutDims());
+  if (!llvm::equal(outDimNames, expected)) {
+    return emitError() << "CTA encoding output dims must be [dim0, dim1, ...], " "but got [" << outDimNames << "].";
   }
   return success();
 }
 
-LogicalResult
-BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                            ArrayRef<unsigned> sizePerThread,
-                            ArrayRef<unsigned> threadsPerWarp,
-                            ArrayRef<unsigned> warpsPerCTA,
-                            ArrayRef<unsigned> order, CTALayoutAttr CTALayout) {
+CTAEncodingAttr CTAEncodingAttr::getDefault(MLIRContext *ctx, int rank) {
+  auto kBlock = StringAttr::get(ctx, "block");
+  LinearLayout::BasesT bases;
+  bases[kBlock] = {};
+  auto dims = standardOutDimNames(ctx, rank);
+  return get(ctx, LinearLayout(bases, dims));
+}
+
+CTAEncodingAttr CTAEncodingAttr::fromSplitParams(MLIRContext *ctx,
+                                                 ArrayRef<unsigned> CTAsPerCGA,
+                                                 ArrayRef<unsigned> CTASplitNum,
+                                                 ArrayRef<unsigned> CTAOrder) {
+  int rank = CTAOrder.size();
+  auto outDimNames = standardOutDimNames(ctx, rank);
+  StringAttr kBlock = StringAttr::get(ctx, "block");
+
+  LinearLayout layout = LinearLayout::empty();
+  SmallVector<unsigned> splitNums(CTASplitNum.begin(), CTASplitNum.end());
+  SmallVector<unsigned> ctas(CTAsPerCGA.begin(), CTAsPerCGA.end());
+
+  for (int i = 0; i < rank; ++i) {
+    int dim = CTAOrder[i];
+    unsigned split = splitNums[dim];
+    unsigned total = ctas[dim];
+    assert(total % split == 0 && "invalid CTA encoding parameters");
+    layout *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) *
+              LinearLayout::zeros1D(total / split, kBlock, outDimNames[dim]);
+  }
+
+  layout = layout.transposeOuts(outDimNames);
+  return CTAEncodingAttr::get(ctx, layout);
+}
+
+SmallVector<unsigned> CTAEncodingAttr::getCTAsPerCGA() const {
+  auto ll = getLinearLayout();
+  auto rank = ll.getNumOutDims();
+  return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"),
+                         rank, /*skipBroadcast=*/false);
+}
+
+SmallVector<unsigned> CTAEncodingAttr::getCTASplitNum() const {
+  auto ll = getLinearLayout();
+  auto rank = ll.getNumOutDims();
+  return basesPerDimImpl(ll.getBases(), StringAttr::get(getContext(), "block"),
+                         rank);
+}
+
+SmallVector<unsigned> CTAEncodingAttr::getCTAOrder() const {
+  auto rank = getRank();
+  SmallVector<unsigned> defaultOrder(rank);
+  std::iota(defaultOrder.begin(), defaultOrder.end(), 0);
+  return orderPerDimImpl(getLinearLayout(),
+                         StringAttr::get(getContext(), "block"), defaultOrder);
+}
+
+LogicalResult BlockedEncodingAttr::verify(
+    function_ref<InFlightDiagnostic()> emitError,
+    ArrayRef<unsigned> sizePerThread, ArrayRef<unsigned> threadsPerWarp,
+    ArrayRef<unsigned> warpsPerCTA, ArrayRef<unsigned> order,
+    CTAEncodingAttr CTALayout) {
   if (!llvm::all_equal({sizePerThread.size(), threadsPerWarp.size(),
                         warpsPerCTA.size(), order.size()})) {
     return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
@@ -681,15 +737,29 @@ static void 
printLinearLayout(AsmPrinter &printer, const LinearLayout &ll) { }); } -// Print the CTALayout if it's not equal to the default. +// Print the CTA encoding as `CGALayout = [[...]]` when the layout is +// non-trivial. static void maybePrintCTALayout(mlir::MLIRContext *context, - mlir::AsmPrinter &printer, CTALayoutAttr layout, - unsigned rank) { - if (layout != CTALayoutAttr::getDefault(context, rank)) { - printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" - << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" - << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; - } + mlir::AsmPrinter &printer, + CTAEncodingAttr layout, unsigned rank) { + if (layout == CTAEncodingAttr::getDefault(context, rank)) + return; + + auto kBlock = StringAttr::get(context, "block"); + const auto &basesMap = layout.getLinearLayout().getBases(); + auto it = basesMap.find(kBlock); + assert(it != basesMap.end()); + const auto &bases = it->second; + // This is the default layout + assert(!bases.empty()); + + printer << ", CGALayout = ["; + llvm::interleaveComma(bases, printer, [&](const std::vector &vec) { + printer << "["; + llvm::interleaveComma(vec, printer); + printer << "]"; + }); + printer << "]"; } //===----------------------------------------------------------------------===// @@ -697,27 +767,54 @@ static void maybePrintCTALayout(mlir::MLIRContext *context, //===----------------------------------------------------------------------===// #include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" + #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" +#undef GET_ATTRDEF_CLASSES //===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// -static std::optional getCTALayoutOrError( - AsmParser &parser, std::optional> CTAsPerCGA, - std::optional> CTASplitNum, - std::optional> CTAOrder, unsigned rank) { - if (CTAsPerCGA && CTASplitNum && CTAOrder) { - return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, - *CTAOrder); +std::optional parseCTAAttr(AsmParser &parser, Attribute attr, + unsigned rank) { + if (!attr) + return CTAEncodingAttr::getDefault(parser.getContext(), rank); + + auto array = llvm::dyn_cast(attr); + if (!array) { + parser.emitError(parser.getNameLoc(), + "expected array value for 'CGALayout'"); + return {}; } - if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { - return CTALayoutAttr::getDefault(parser.getContext(), rank); + + auto ctx = parser.getContext(); + auto cgaName = StringAttr::get(ctx, "CGALayout"); + std::vector> bases; + bases.reserve(array.size()); + for (Attribute vecAttr : array) { + SmallVector basisValues; + NamedAttribute basisAttr(cgaName, vecAttr); + if (parseIntArrayAttr(parser, basisAttr, basisValues, "CGALayout entry") + .failed()) + return {}; + if (basisValues.size() != rank) { + parser.emitError(parser.getNameLoc()) + << "'CGALayout' entry length does not match rank " << rank; + return {}; + } + std::vector basis; + basis.reserve(basisValues.size()); + for (unsigned value : basisValues) + basis.push_back(static_cast(value)); + bases.push_back(std::move(basis)); } - parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " - "must all be present or all be absent"); - return std::nullopt; + + LinearLayout::BasesT namedBases; + namedBases.insert( + std::make_pair(StringAttr::get(ctx, "block"), std::move(bases))); + LinearLayout ll(namedBases, 
standardOutDimNames(ctx, rank)); + return CTAEncodingAttr::get(ctx, std::move(ll)); } Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { @@ -734,9 +831,7 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { SmallVector threadsPerWarp; SmallVector warpsPerCTA; SmallVector order; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "sizePerThread") { @@ -757,18 +852,8 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { } else if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; - } else if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } else if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } else if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; + } else if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") << attr.getName().strref(); @@ -776,8 +861,8 @@ Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/sizePerThread.size()); + std::optional CTALayout = + parseCTAAttr(parser, ctaAttr, /*rank=*/sizePerThread.size()); if (!CTALayout.has_value()) return {}; @@ -886,7 +971,7 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { static SmallVector basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName, - size_t rank, bool skipBroadcast = true) { + size_t rank, bool skipBroadcast) { const auto &bases = namedBases.find(dimName)->second; if (bases.empty()) { @@ -920,6 +1005,31 @@ LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const { return basesPerDimImpl(ll.getBases(), dimName, rank, skipBroadcast); } +CTAEncodingAttr linearToCTAEncodingAttr(const LinearLayout &ll, + ArrayRef cgaLogicalShape) { + // Compute the shapePerCTA + auto shape = ll.getOutDims(); + for (int i = 0; i < shape.size(); ++i) { + shape[i].second /= cgaLogicalShape[i]; + } + auto inDims = to_vector(ll.getInDimNames()); + auto kBlock = inDims.back(); + assert(kBlock.str() == "block"); + inDims.pop_back(); + auto outDims = to_vector(ll.getOutDimNames()); + auto subLl = ll.sublayout(inDims, outDims); + // sublayout returns the same output size. 
We trim it to the + // real size + subLl = LinearLayout(subLl.getBases(), shape, false); + // The ctaLayout is what we get after dividing on the left by + // the layout in a single CTA + auto maybeCtaLayout = divideLeft(ll, subLl); + assert(maybeCtaLayout.has_value()); + auto *ctx = inDims[0].getContext(); + auto ctaLayout = maybeCtaLayout->sublayout({kBlock}, outDims); + return CTAEncodingAttr::get(ctx, std::move(ctaLayout)); +} + SmallVector LinearEncodingAttr::orderPerDim(StringAttr dimName, ArrayRef defaultOrder) const { @@ -941,16 +1051,9 @@ SmallVector LinearEncodingAttr::getRepOrder() const { return getOrder(); } -SmallVector LinearEncodingAttr::getCTAsPerCGA() const { - // CTAs are split into an identity part (SplitNum) and a broadcast part - return basesPerDim(StringAttr::get(getContext(), "block"), - /*skipBroadcast=*/false); -} -SmallVector LinearEncodingAttr::getCTAOrder() const { - return orderPerDim(StringAttr::get(getContext(), "block"), getOrder()); -} -SmallVector LinearEncodingAttr::getCTASplitNum() const { - return basesPerDim(StringAttr::get(getContext(), "block")); +CTAEncodingAttr LinearEncodingAttr::getCTALayout() const { + auto splitNum = basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCTAEncodingAttr(getLinearLayout(), splitNum); } SmallVector LinearEncodingAttr::getWarpsPerCTA() const { return basesPerDim(StringAttr::get(getContext(), "warp")); @@ -970,13 +1073,13 @@ SmallVector LinearEncodingAttr::getSizePerThread() const { auto ll = getLinearLayout(); auto ctx = getContext(); auto kRegister = StringAttr::get(ctx, "register"); + auto splitNum = getCTALayout().getCTASplitNum(); // We canonicalize on the spot, as if we use CGAs the regs are not in // canonical form The order is [reg, lane, warp, rep, block], so we first // remove the blocks llvm::SmallVector ctaShape; - for (auto [shape, cgaNum] : - llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { + for (auto [shape, cgaNum] : llvm::zip(ll.getOutDimSizes(), splitNum)) { ctaShape.push_back(shape / cgaNum); } LinearLayout::BasesT bases = ll.getBases(); @@ -1096,10 +1199,8 @@ Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { unsigned versionMajor = 0; unsigned versionMinor = 0; SmallVector warpsPerCTA; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; SmallVector instrShape; + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "versionMajor") { @@ -1114,20 +1215,9 @@ Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed()) return {}; } - if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } - if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } - if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; + if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); + continue; } if (attr.getName() == "instrShape") { if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed()) { @@ -1136,8 +1226,8 @@ Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + std::optional CTALayout = + parseCTAAttr(parser, 
ctaAttr, /*rank=*/warpsPerCTA.size()); if (!CTALayout.has_value()) return {}; @@ -1175,11 +1265,9 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { SmallVector warpsPerCTA; SmallVector instrShape; bool isTransposed; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; SmallVector tilesPerWarp = {}; unsigned elementBitWidth = 32; + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "version") { @@ -1198,20 +1286,9 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) return {}; } - if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } - if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } - if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; + if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); + continue; } if (attr.getName() == "tilesPerWarp") { if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp") @@ -1224,8 +1301,8 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + std::optional CTALayout = + parseCTAAttr(parser, ctaAttr, /*rank=*/warpsPerCTA.size()); if (!CTALayout.has_value()) return {}; @@ -1263,8 +1340,8 @@ LogicalResult AMDMfmaEncodingAttr::verify( function_ref emitError, unsigned version, llvm::ArrayRef warpsPerCTA, llvm::ArrayRef instrShape, bool isTransposed, - mlir::triton::gpu::CTALayoutAttr, llvm::ArrayRef tilesPerWarp, - unsigned elementBitWidth) { + mlir::triton::gpu::CTAEncodingAttr, + llvm::ArrayRef tilesPerWarp, unsigned elementBitWidth) { if (!(version >= 0 && version <= 4)) { return emitError() << "version must be in the [0, 4] range"; } @@ -1303,11 +1380,9 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { unsigned version = 0; bool isTransposed = false; SmallVector warpsPerCTA; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; SmallVector tilesPerWarp = {}; SmallVector instrShape = getDefaultInstrShape(); + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "version") { @@ -1327,20 +1402,9 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { .failed()) return {}; } - if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } - if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } - if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; + if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); + continue; } if (attr.getName() == "instrShape") { instrShape.clear(); @@ -1350,8 +1414,8 @@ Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); + std::optional CTALayout = + parseCTAAttr(parser, ctaAttr, /*rank=*/warpsPerCTA.size()); if 
(!CTALayout.has_value()) return {}; @@ -1385,7 +1449,7 @@ void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const { LogicalResult AMDWmmaEncodingAttr::verify( function_ref emitError, unsigned version, bool isTransposed, llvm::ArrayRef warpsPerCTA, - llvm::ArrayRef tilesPerWarp, CTALayoutAttr ctaLayout, + llvm::ArrayRef tilesPerWarp, CTAEncodingAttr ctaLayout, llvm::ArrayRef instrShape) { if (!(version >= 1 && version <= 3)) return emitError() << "WMMA version must be in the [1, 3] range"; @@ -1439,7 +1503,7 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { LogicalResult SliceEncodingAttr::verify(function_ref emitError, unsigned dim, DistributedEncodingTrait parent) { - unsigned rank = cast(parent).getRank(); + unsigned rank = ::getCTALayout(parent).getRank(); if (rank <= 1) return emitError() << "parent layout must have at least rank >= 2"; if (dim >= rank) { @@ -1454,39 +1518,10 @@ SmallVector SliceEncodingAttr::getRepOrder() const { return eraseOrder(parentRepOrder, getDim()); } -SmallVector SliceEncodingAttr::getCTASplitNum() const { - SmallVector res = ::getCTASplitNum(getParent()); - res.erase(res.begin() + getDim()); - return res; -} - -SmallVector SliceEncodingAttr::getCTAOrder() const { - auto parentCTAOrder = ::getCTAOrder(getParent()); - return eraseOrder(parentCTAOrder, getDim()); -} - -SmallVector SliceEncodingAttr::getCTAsPerCGA() const { - auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); - if (parentCTAsPerCGA[getDim()] == 1) { - parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); - return parentCTAsPerCGA; - } - /* For getCTAsPerCGA of a slice layout, we have two choices: - * (1) Return CTAsPerCGA of its parent. This is not a perfect solution - * because the rank of the returned CTAsPerCGA does not match the rank of - * tensorShape. - * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a - * perfect solution because the product of the returned CTAsPerCGA might not - * match numCTAs. - * To avoid introducing inconsistencies to the shape and - * layout system, the usage of directly getting CTAsPerCGA of a slice layout - * in which the sliced dim is not 1 is banned. You should always consider - * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) - * in the branch where layout is an instance of SliceEncodingAttr. This is - * inconvenient but safe. 
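The new SliceEncodingAttr::getCTALayout below sidesteps this problem by dropping the sliced output dimension from the parent's block mapping instead of editing integer arrays. A minimal worked example, assuming removeStandardDim keeps only the remaining dimension's components of each block basis (values are for illustration only):

    // Parent CGA: 2x2 CTAs, both dimensions split.
    //   parent block bases: [[0, 1], [1, 0]]  over (dim0, dim1)
    // Slicing dim 1 keeps only the dim0 components of each basis:
    //   removeStandardDim(ll, /*dim=*/1)  ->  block bases [[0], [1]]  over dim0
    // The first basis becomes all-zero, i.e. those CTAs now broadcast along the
    // remaining dimension; this is why the result is rebuilt with
    // isSurjective = false rather than failing the way the old accessor did.
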
- */ - llvm::report_fatal_error( - "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +CTAEncodingAttr SliceEncodingAttr::getCTALayout() const { + auto layout = ::getCTALayout(getParent()).getLinearLayout(); + layout = removeStandardDim(layout, getDim()); + return CTAEncodingAttr::get(getContext(), layout); } template @@ -1509,39 +1544,6 @@ SliceEncodingAttr::paddedShape(ArrayRef shape) const; template SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const; -//===----------------------------------------------------------------------===// -// Helper shared encoding functions -//===----------------------------------------------------------------------===// - -std::optional -parseCTAAttrs(AsmParser &parser, NamedAttrList attrList, unsigned rank) { - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; - - for (const NamedAttribute &attr : attrList) { - if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } else if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } else if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; - } else { - parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.getName().strref(); - return {}; - } - } - - return getCTALayoutOrError(parser, CTAsPerCGA, CTASplitNum, CTAOrder, rank); -} - template Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { if (parser.parseLess().failed()) @@ -1557,7 +1559,7 @@ Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { unsigned perPhase = 0; unsigned maxPhase = 0; SmallVector order; - NamedAttrList remainingAttrs; + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "vec") { if (parseUInt(parser, attr, vec, "vec").failed()) @@ -1572,11 +1574,17 @@ Attribute parseSwizzledEncoding(AsmParser &parser, Type type) { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; } else { - remainingAttrs.push_back(attr); + if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); + } else { + parser.emitError(parser.getNameLoc(), "unexpected key: ") + << attr.getName().strref(); + return {}; + } } } - if (auto CTALayout = parseCTAAttrs(parser, remainingAttrs, order.size())) + if (auto CTALayout = parseCTAAttr(parser, ctaAttr, order.size())) return parser.getChecked( parser.getContext(), vec, perPhase, maxPhase, order, *CTALayout); return {}; @@ -1590,7 +1598,7 @@ LogicalResult SwizzledSharedEncodingAttr::verify(function_ref emitError, unsigned vec, unsigned perPhase, unsigned maxPhase, ArrayRef order, - CTALayoutAttr ctaLayout) { + CTAEncodingAttr ctaLayout) { if (order.size() != ctaLayout.getRank()) { return emitError() << "order size (" << order.size() << ") must match CTALayout rank (" << ctaLayout.getRank() @@ -1764,19 +1772,10 @@ SmallVector SharedLinearEncodingAttr::getOrder() const { return orderPerDim(StringAttr::get(getContext(), "offset"), defaultOrder); } -SmallVector SharedLinearEncodingAttr::getCTAsPerCGA() const { - return basesPerDim(StringAttr::get(getContext(), "block"), - /*skipBroadcast=*/false); -} - -SmallVector SharedLinearEncodingAttr::getCTAOrder() const { - return orderPerDim(StringAttr::get(getContext(), "block"), getOrder()); +CTAEncodingAttr SharedLinearEncodingAttr::getCTALayout() const { + auto splitNum = 
basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCTAEncodingAttr(getLinearLayout(), splitNum); } - -SmallVector SharedLinearEncodingAttr::getCTASplitNum() const { - return basesPerDim(StringAttr::get(getContext(), "block")); -} - LinearLayout SharedLinearEncodingAttr::toLinearLayout(ArrayRef shape) const { auto ll = getLinearLayout(); @@ -1863,7 +1862,8 @@ Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) { auto kOffset = StringAttr::get(parser.getContext(), "offset"); maybeLL = identityStandardND(kOffset, shape, order); maybeLL = combineCtaCgaWithShape( - *maybeLL, CTALayoutAttr::getDefault(parser.getContext(), shape.size()), + *maybeLL, + CTAEncodingAttr::getDefault(parser.getContext(), shape.size()), SmallVector(ArrayRef(shape))); } @@ -1992,7 +1992,7 @@ LogicalResult PaddedSharedEncodingAttr::verify( PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get( MLIRContext *context, ArrayRef> intervalPads, ArrayRef order, ArrayRef shape, - CTALayoutAttr ctaLayout) { + CTAEncodingAttr ctaLayout) { auto outDimNames = standardOutDimNames(context, shape.size()); StringAttr kOffset = StringAttr::get(context, "offset"); @@ -2055,19 +2055,10 @@ SmallVector PaddedSharedEncodingAttr::getOrder() const { return orderPerDim(StringAttr::get(getContext(), "offset"), order); } -// LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>; -SmallVector PaddedSharedEncodingAttr::getCTAsPerCGA() const { - // CTAs are split into an identity part (SplitNum) and a broadcast part - return basesPerDim(StringAttr::get(getContext(), "block"), - /*skipBroadcast=*/false); -} -SmallVector PaddedSharedEncodingAttr::getCTAOrder() const { - return orderPerDim(StringAttr::get(getContext(), "block"), getOrder()); -} -SmallVector PaddedSharedEncodingAttr::getCTASplitNum() const { - return basesPerDim(StringAttr::get(getContext(), "block")); +CTAEncodingAttr PaddedSharedEncodingAttr::getCTALayout() const { + auto splitNum = basesPerDim(StringAttr::get(getContext(), "block")); + return linearToCTAEncodingAttr(getLinearComponent(), splitNum); } - //===----------------------------------------------------------------------===// // NVMMAShared encoding //===----------------------------------------------------------------------===// @@ -2087,9 +2078,7 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { bool fp4Padded = false; unsigned elementBitWidth; unsigned layoutRank = 2; - std::optional> CTAsPerCGA; - std::optional> CTASplitNum; - std::optional> CTAOrder; + Attribute ctaAttr = nullptr; for (const NamedAttribute &attr : dict) { if (attr.getName() == "swizzlingByteWidth") { if (parseUInt(parser, attr, swizzlingByteWidth, "swizzlingByteWidth") @@ -2104,18 +2093,8 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { } else if (attr.getName() == "fp4Padded") { if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed()) return {}; - } else if (attr.getName() == "CTAsPerCGA") { - if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA") - .failed()) - return {}; - } else if (attr.getName() == "CTASplitNum") { - if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum") - .failed()) - return {}; - } else if (attr.getName() == "CTAOrder") { - if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder") - .failed()) - return {}; + } else if (attr.getName() == "CGALayout") { + ctaAttr = attr.getValue(); } else if (attr.getName() == "rank") { if (parseUInt(parser, attr, layoutRank, 
"rank").failed()) return {}; @@ -2126,8 +2105,8 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( - parser, CTAsPerCGA, CTASplitNum, CTAOrder, layoutRank); + std::optional CTALayout = + parseCTAAttr(parser, ctaAttr, layoutRank); if (!CTALayout.has_value()) return {}; @@ -2147,7 +2126,7 @@ void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const { } unsigned rank = getCTALayout().getCTAOrder().size(); auto *ctx = getContext(); - auto defaultLayout = CTALayoutAttr::getDefault(ctx, rank); + auto defaultLayout = CTAEncodingAttr::getDefault(ctx, rank); if (getCTALayout() == defaultLayout && rank != 2) { printer << ", rank = " << rank; } else { @@ -2266,7 +2245,7 @@ AMDMfmaEncodingAttr::getRepForOperand(ArrayRef operandShape, } SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( - CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + CTAEncodingAttr ctaLayout, int operandIdx, ArrayRef operandShape, ArrayRef sharedOrder, unsigned vectorSize, unsigned elemBitWidth, bool needTrans) const { int kDimIndex = operandIdx == 0 ? 1 : 0; @@ -2363,7 +2342,7 @@ AMDWmmaEncodingAttr::getRepForOperand(ArrayRef operandShape, int kDim, } SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand( - CTALayoutAttr ctaLayout, int operandIdx, ArrayRef operandShape, + CTAEncodingAttr ctaLayout, int operandIdx, ArrayRef operandShape, ArrayRef sharedOrder, unsigned kWidth, unsigned elemBitWidth, bool needTrans) const { int kDimIndex = operandIdx == 0 ? 1 : 0; @@ -2478,25 +2457,20 @@ SmallVector DotOperandEncodingAttr::getRepOrder() const { return {}; } -SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(getParent()); -} - -SmallVector DotOperandEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(getParent()); -} - -SmallVector DotOperandEncodingAttr::getCTASplitNum() const { - SmallVector res = ::getCTASplitNum(getParent()); - auto rank = res.size(); - assert(rank == 2 || rank == 3 && "Invalid dotLayout"); - - // Do not split CTA in K dimension +CTAEncodingAttr DotOperandEncodingAttr::getCTALayout() const { + auto layout = ::getCTALayout(getParent()).getLinearLayout(); + auto bases = layout.getBases(); + auto kBlock = StringAttr::get(getContext(), "block"); + auto &blockBases = bases[kBlock]; + auto rank = layout.getNumOutDims(); auto kDim = getOpIdx() == 0 ? 
rank - 1 : rank - 2; - res[kDim] = 1; - return res; + for (auto &basis : blockBases) { + basis[kDim] = 0; + } + auto dims = layout.getOutDims(); + dims[kDim].second = 1; + return CTAEncodingAttr::get(getContext(), LinearLayout(bases, dims, true)); } - LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned opIdx, Attribute parent, unsigned kWidth) { @@ -2531,9 +2505,9 @@ LogicalResult DotOperandEncodingAttr::verify( return emitError() << "ttg.dot_op kWidth parameter must be 4/8/16 for WMMA v2 " "(including packed cases for `scaled_dot`)"; - if (parentAttr.getVersion() == 3 && !llvm::is_contained({2, 8, 16}, kWidth)) + if (parentAttr.getVersion() == 3 && kWidth == 0) return emitError() - << "ttg.dot_op kWidth parameter must be 2/8/16 for WMMA v3"; + << "ttg.dot_op kWidth parameter is mandatory for WMMA v3 "; return success(); } @@ -2700,17 +2674,22 @@ struct TritonGPUInferLayoutInterface } return success(); }; - auto *ctx = getDialect()->getContext(); + + auto permuteCTALayout = [ctx](CTAEncodingAttr layout, + ArrayRef order) { + auto ll = transposeLinearLayout(layout.getLinearLayout(), order); + return CTAEncodingAttr::get(ctx, std::move(ll)); + }; + auto invOrder = inversePermutation(order); SmallVector invOrderUnsigned(invOrder.begin(), invOrder.end()); if (auto enc = dyn_cast(operandEncoding)) { - if (failed(checkRank(enc.getRank()))) + if (failed(checkRank(enc.getCTALayout().getRank()))) return failure(); - CTALayoutAttr ctaLayout = - permuteCTALayout(ctx, enc.getCTALayout(), order); + CTAEncodingAttr ctaLayout = permuteCTALayout(enc.getCTALayout(), order); resultEncoding = SwizzledSharedEncodingAttr::get( ctx, enc.getVec(), enc.getPerPhase(), enc.getMaxPhase(), applyPermutation(invOrderUnsigned, enc.getOrder()), ctaLayout); @@ -2719,11 +2698,10 @@ struct TritonGPUInferLayoutInterface if (auto enc = dyn_cast(operandEncoding)) { if (order == ArrayRef({1, 0})) { - if (failed(checkRank(enc.getRank()))) + if (failed(checkRank(enc.getCTALayout().getRank()))) return failure(); - CTALayoutAttr ctaLayout = - permuteCTALayout(ctx, enc.getCTALayout(), order); + CTAEncodingAttr ctaLayout = permuteCTALayout(enc.getCTALayout(), order); resultEncoding = NVMMASharedEncodingAttr::get( ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(), enc.getElementBitWidth(), enc.getFp4Padded(), ctaLayout); @@ -2732,11 +2710,10 @@ struct TritonGPUInferLayoutInterface } if (auto enc = dyn_cast(operandEncoding)) { - if (failed(checkRank(enc.getRank()))) + if (failed(checkRank(enc.getCTALayout().getRank()))) return failure(); - CTALayoutAttr ctaLayout = - permuteCTALayout(ctx, enc.getCTALayout(), order); + CTAEncodingAttr ctaLayout = permuteCTALayout(enc.getCTALayout(), order); resultEncoding = BlockedEncodingAttr::get( ctx, applyPermutation(enc.getSizePerThread(), order), applyPermutation(enc.getThreadsPerWarp(), order), @@ -2903,8 +2880,11 @@ struct TritonGPUInferLayoutInterface // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA // should be like the other fields in blocked encoding, but I'm not sure how // to handle CTASplitNum. 
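In the new representation a layout spans multiple CTAs exactly when its block mapping has at least one basis vector, and the guard below recovers the familiar split/broadcast views from that mapping. A few illustrative cases (hypothetical values, assuming the accessors mirror the legacy fields):

    // Single CTA:                block bases []             ->
    //   getCTAsPerCGA() == {1, 1}, getCTASplitNum() == {1, 1}   (guard passes)
    // 2x2 CGA, both dims split:  block bases [[0, 1], [1, 0]]  ->
    //   getCTAsPerCGA() == {2, 2}, getCTASplitNum() == {2, 2}   (bail out)
    // 2 CTAs broadcasting:       block bases [[0, 0]]          ->
    //   getCTASplitNum() == {1, 1} but getCTAsPerCGA() is not all ones,
    //   so the rewrite still bails out; this is why both views are checked.
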
- if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || - !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + auto srcCTALayout = src.getCTALayout(); + if (!all_of(srcCTALayout.getCTAsPerCGA(), + [](int32_t x) { return x == 1; }) || + !all_of(srcCTALayout.getCTASplitNum(), + [](int32_t x) { return x == 1; })) { return failure(); } @@ -3079,11 +3059,8 @@ struct TritonGPUInferLayoutInterface auto dstOrder = inversePermutation(dstInvOrder); // CTALayout can be all 1's because we bailed on multi-CTA layouts above. - auto CTALayout = CTALayoutAttr::get( - src.getContext(), - /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), - /*CTASplitNum=*/SmallVector(dstShape.size(), 1), - /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + auto CTALayout = + CTAEncodingAttr::getDefault(src.getContext(), dstShape.size()); dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread, dstThreadsPerWarp, dstWarpsPerCTA, @@ -3178,13 +3155,16 @@ struct TritonGPUInferLayoutInterface ret.insert(ret.begin(), ret.size()); return ret; }; + auto ctall = enc.getCTALayout().getLinearLayout(); + auto kBlock = StringAttr::get(enc.getContext(), "block"); + auto newDim = standardOutDimNames( + enc.getContext(), ctall.getNumOutDims() + 1)[ctall.getNumOutDims()]; + ctall *= LinearLayout::identity1D(1, kBlock, newDim); dstEnc = BlockedEncodingAttr::get( enc.getContext(), append(enc.getSizePerThread(), 2), append(enc.getThreadsPerWarp(), 1), append(enc.getWarpsPerCTA(), 1), appendMajorDim(enc.getOrder()), - CTALayoutAttr::get(enc.getContext(), append(enc.getCTAsPerCGA(), 1), - append(enc.getCTASplitNum(), 1), - appendMajorDim(enc.getCTAOrder()))); + CTAEncodingAttr::get(enc.getContext(), ctall)); return success(); } @@ -3217,22 +3197,22 @@ struct TritonGPUInferLayoutInterface bool isSimpleSplit = (enc && (enc.getSizePerThread().back() == 2) && (enc.getThreadsPerWarp().back() == 1) && (enc.getWarpsPerCTA().back() == 1) && - (enc.getCTAsPerCGA().back() == 1)); + (enc.getCTALayout().getCTAsPerCGA().back() == 1)); if (isSimpleSplit) { SmallVector newOrder(enc.getOrder()); + auto ctall = enc.getCTALayout().getLinearLayout(); int splitDim = newOrder.size() - 1; // Remove splitDim from order. newOrder.erase(std::remove(newOrder.begin(), newOrder.end(), splitDim), newOrder.end()); + // Remove last dimension from ctall. 
+ ctall = ctall.unsqueezeOut(to_vector(ctall.getOutDimNames()).back()); dstEnc = BlockedEncodingAttr::get( enc.getContext(), // ArrayRef(enc.getSizePerThread()).drop_back(1), ArrayRef(enc.getThreadsPerWarp()).drop_back(1), ArrayRef(enc.getWarpsPerCTA()).drop_back(1), ArrayRef(newOrder), - CTALayoutAttr::get(enc.getContext(), // - ArrayRef(enc.getCTAsPerCGA()).drop_back(1), - ArrayRef(enc.getCTASplitNum()).drop_back(1), - ArrayRef(enc.getCTAOrder()).drop_front(1))); + CTAEncodingAttr::get(enc.getContext(), ctall)); return success(); } diff --git a/lib/Dialect/TritonGPU/IR/LayoutUtility.cpp b/lib/Dialect/TritonGPU/IR/LayoutUtility.cpp deleted file mode 100644 index e4d6bcc364..0000000000 --- a/lib/Dialect/TritonGPU/IR/LayoutUtility.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include - -#include -#include -#include - -namespace mlir::triton::gpu { - -CTALayoutAttr permuteCTALayout(MLIRContext *ctx, CTALayoutAttr layout, - ArrayRef order) { - auto n = order.size(); - assert(n == layout.getRank() && "order and layout rank mismatch"); - - auto invOrder = inversePermutation(order); - llvm::SmallVector invOrderUnsigned(invOrder.begin(), - invOrder.end()); - return CTALayoutAttr::get( - ctx, applyPermutation(layout.getCTAsPerCGA(), order), - applyPermutation(layout.getCTASplitNum(), order), - applyPermutation(invOrderUnsigned, layout.getCTAOrder())); -} - -} // namespace mlir::triton::gpu diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 066552a05b..fe3839cf4d 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -28,10 +28,10 @@ namespace { // for register layouts, and input dims [offset] for shared layouts. // - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. // -// Note that this is inconsistent with the type name CTALayoutAttr. That type +// Note that this is inconsistent with the type name CTAEncodingAttr. That type // is equivalent to our cgaLayout. // -// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway, +// IMO the name CTAEncodingAttr is wrong. If we tried to be consistent anyway, // then we'd have to rename ctaLayout to "warpLayout". I think that's more // confusing than being inconsistent about "cgaLayout", especially when we have // to consider the size of the warpLayout (surely that's not the "warpSize"). @@ -157,28 +157,6 @@ sharedToLinearLayoutAMDRotating(ArrayRef shape, } // namespace -LinearLayout makeCgaLayout(CTALayoutAttr layout) { - MLIRContext *ctx = layout.getContext(); - StringAttr kBlock = S("block"); - - int rank = layout.getCTAOrder().size(); - SmallVector outDimNames = standardOutDimNames(ctx, rank); - - LinearLayout ret = LinearLayout::empty(); - for (int i = 0; i < rank; i++) { - // Start with the most minor dimension, which is order[0]. - int dim = layout.getCTAOrder()[i]; - int split = layout.getCTASplitNum()[dim]; - int ctas = layout.getCTAsPerCGA()[dim]; - assert(ctas % split == 0); - ret *= LinearLayout::identity1D(split, kBlock, outDimNames[dim]) * - LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]); - } - - // Transpose to standard order (dim0, dim1, ...). 
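This helper goes away because CTAEncodingAttr now stores exactly the block -> (dim0, dim1, ...) mapping it used to build. A worked example of the correspondence, assuming fromSplitParams reproduces the construction above (values are for illustration only):

    // CTAsPerCGA = {2, 4}, CTASplitNum = {2, 2}, CTAOrder = {1, 0}
    //   most-minor dim first (dim1): identity1D(2) -> basis [0, 1]
    //                                zeros1D(2)    -> basis [0, 0]  (broadcast)
    //   then dim0:                   identity1D(2) -> basis [1, 0]
    // Equivalent attribute:
    //   CTAEncodingAttr::fromSplitParams(ctx, {2, 4}, {2, 2}, {1, 0})
    //     -> block bases [[0, 1], [0, 0], [1, 0]]  over (dim0, dim1)
    // getCTASplitNum() skips the all-zero (broadcast) basis and recovers {2, 2};
    // the broadcast CTA only shows up in getCTAsPerCGA().
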
- return ret.transposeOuts(outDimNames); -} - // Returns the layout of a single core matrix which tiles the nvmma layout LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared, bool disableSwizzle) { @@ -1045,28 +1023,7 @@ LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { parentShape.insert(parentShape.begin() + getDim(), 1); LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent()); - // Remove dimension getDim() from the parent layout. - // - // 1. Construct a layout `transform` from parent-out-dims to slice-out-dims - // that removes the relevant out-dim. - // 2. Compute linearSlice = parent.compose(transform). Now linearSlice maps - // from parent in-dims to slice out-dims. - // 3. Fix up duplicate registers introduced by slicing. - auto outDimNames = standardOutDimNames(ctx, shape.size() + 1); - LinearLayout transform = LinearLayout::empty(); - for (auto [idx, outDim] : llvm::enumerate(parentLL.getOutDimNames())) { - if (idx == getDim()) { - // Because we're multiplying by all zeros, we could replace outDimNames[0] - // with any other valid out-dim; the layout will be the same. - transform *= LinearLayout::zeros1D(parentLL.getOutDimSize(outDim), outDim, - outDimNames[0]); - } else { - transform *= - LinearLayout::identity1D(parentLL.getOutDimSize(outDim), outDim, - outDimNames[idx - (idx < getDim() ? 0 : 1)]); - } - } - LinearLayout sliceLL = parentLL.compose(transform); + auto sliceLL = removeStandardDim(parentLL, getDim()); // Step 3: Along the "register" dim, remove any all-zero bases. auto bases = sliceLL.getBases(); @@ -1312,7 +1269,7 @@ LinearLayout getLayoutWithinBlock(const LinearLayout &layout) { } LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, - CTALayoutAttr cgaLayoutAttr, + CTAEncodingAttr cgaLayoutAttr, ArrayRef shape) { int rank = shape.size(); assert(ctaLayout.getNumOutDims() == rank); @@ -1327,7 +1284,7 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, } LinearLayout cgaLayout = - ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape) + ensureLayoutNotLargerThan(cgaLayoutAttr.getLinearLayout(), labeledShape) .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); // Calculate the shape of the ctaLayout, which is `shape` divided by the @@ -1460,7 +1417,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, warpLayout.transposeOuts(outDimNames); return combineCtaCgaWithShape( - ctaLayout, CTALayoutAttr::getDefault(ctx, /*rank=*/2), dotOperandShape); + ctaLayout, CTAEncodingAttr::getDefault(ctx, /*rank=*/2), dotOperandShape); } // PTX ISA - Warp-level MMA Block Scaling @@ -1476,7 +1433,7 @@ LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, ArrayRef shape, int opIdx, ArrayRef warpsPerCTA, - CTALayoutAttr ctaLayout) { + CTAEncodingAttr ctaLayout) { unsigned rank = shape.size(); auto outDims = standardOutDimNames(ctx, rank); StringAttr kRegister = StringAttr::get(ctx, "register"); @@ -1590,8 +1547,7 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * warpLayout.transposeOuts(outDimNames); - auto ctaLay = CTALayoutAttr::get(/*context=*/ctx, /*CTAsPerCGA=*/{1, 1}, - /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0}); + auto ctaLay = CTAEncodingAttr::getDefault(ctx, 2); auto finalLay = combineCtaCgaWithShape(ctaLayout, ctaLay, dotOperandShape); return finalLay; } diff --git 
a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 7b3091ed8b..f324eb99cb 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -591,11 +591,7 @@ static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef srcShape, // preserved. Otherwise fall back to the generic shared-linear encoding // logic below. if (innerDimDst == innerDimSrc) { - auto CTALayout = CTALayoutAttr::get( - ctx, - /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), - /*CTASplitNum=*/SmallVector(dstShape.size(), 1), - /*CTAOrder=*/llvm::to_vector(llvm::seq(dstShape.size()))); + auto CTALayout = CTAEncodingAttr::getDefault(ctx, dstShape.size()); auto candidateEncoding = NVMMASharedEncodingAttr::get( ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(), mmaEncoding.getElementBitWidth(), diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 3c0fd5c5e6..5321595b07 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -467,7 +467,7 @@ static bool canUseTwoCTAs(triton::DotOp dotOp) { static DistributedEncodingTrait replaceCTALayout(DistributedEncodingTrait layout, - const triton::gpu::CTALayoutAttr &newCTALayout) { + const triton::gpu::CTAEncodingAttr &newCTALayout) { if (auto blockedLayout = mlir::dyn_cast(layout)) { return BlockedEncodingAttr::get( layout.getContext(), blockedLayout.getSizePerThread(), @@ -494,8 +494,10 @@ static Value splitBOperand(Value b, mlir::PatternRewriter &rewriter) { "expected LoadOp"); RankedTensorType bType = cast(b.getType()); auto currentLayout = cast(bType.getEncoding()); + auto kBlock = StringAttr::get(ctx, "block"); + auto dims = standardOutDimNames(ctx, 2); auto newCTALayout = - CTALayoutAttr::get(ctx, {1, 2}, {1, 2}, getCTAOrder(currentLayout)); + CTAEncodingAttr::get(ctx, LinearLayout({{kBlock, {{0, 1}}}}, dims)); Attribute newLayout = replaceCTALayout(currentLayout, newCTALayout); rewriter.setInsertionPoint(loadOp); for (OpOperand &operand : loadOp->getOpOperands()) { @@ -561,7 +563,7 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern { MLIRContext *context = dotOp->getContext(); auto instrShape = mmaVersionToInstrShape( versionMajor, retShapePerCTA, oldAType.getElementType(), numWarps); - ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + auto CTASplitNum = CTALayout.getCTASplitNum(); auto bitwidth = oldRetType.getElementType().getIntOrFloatBitWidth(); unsigned colStride = 32 / bitwidth; Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( @@ -816,7 +818,7 @@ class ScaledBlockedToMMAv5 unsigned m = 128; unsigned n = retShapePerCTA[1] >= 256 ? 
256 : retShapePerCTA[1]; - ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + auto CTASplitNum = CTALayout.getCTASplitNum(); auto bitwidth = oldRetType.getElementType().getIntOrFloatBitWidth(); unsigned colStride = 32 / bitwidth; Attribute accEncoding = triton::nvidia_gpu::TensorMemoryEncodingAttr::get( diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index cdb8d7bf14..bd53c00058 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -96,7 +96,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase { int numWarps = lookupNumWarps(curr); auto tensorType = cast(ptr.getType()); - CTALayoutAttr ctaLayout = getCTALayout(tensorType.getEncoding()); + CTAEncodingAttr ctaLayout = getCTALayout(tensorType.getEncoding()); SmallVector shapePerCTA = getShapePerCTA(tensorType); auto layout = buildCoalescedEncoding(&getContext(), axisInfoAnalysis, curr, numWarps, threadsPerWarp, diff --git a/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp b/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp index 6f0c1a7365..4b5766663a 100644 --- a/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CoalesceUtils.cpp @@ -16,7 +16,7 @@ namespace mlir::triton::gpu { BlockedEncodingAttr buildCoalescedEncoding( MLIRContext *context, ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, int numWarps, int threadsPerWarp, - triton::gpu::CTALayoutAttr CTALayout, SmallVector shapePerCTA) { + triton::gpu::CTAEncodingAttr CTALayout, SmallVector shapePerCTA) { Value ptr = getMemAccessPtr(op); auto refTensorType = cast(ptr.getType()); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 8612801032..9306a1c1c2 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -7,7 +7,6 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -56,7 +55,9 @@ class SwizzleShmemConvert : public OpRewritePattern { // swizzling code. 
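Here the transposed CTA layout is rebuilt directly from the linear form rather than through the removed permuteCTALayout helper. A small sketch of the intended effect, assuming transposeLinearLayout applies `order` to the layout's output dimensions (example values only):

    // CTAs tile dim1 only:               block bases [[0, 1]]   (CTAsPerCGA = {1, 2})
    // transposeLinearLayout(ll, {1, 0})  swaps the per-dimension components:
    //                                    block bases [[1, 0]]   (CTAsPerCGA = {2, 1})
    // which matches what the old helper computed by permuting CTAsPerCGA and
    // CTASplitNum as plain integer arrays.
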
auto ctx = getContext(); auto oldCTALayout = triton::gpu::getCTALayout(srcTy.getEncoding()); - auto newCTALayout = permuteCTALayout(ctx, oldCTALayout, trans.getOrder()); + auto newLl = + transposeLinearLayout(oldCTALayout.getLinearLayout(), trans.getOrder()); + auto newCTALayout = CTAEncodingAttr::get(ctx, std::move(newLl)); auto newInnerCvtEnc = SwizzledSharedEncodingAttr::get(ctx, cvtEncoding, srcTy.getShape(), /*order=*/getOrderForMemory(srcTy), diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index d5dbd9ba8b..cc37951076 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -12,6 +12,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/LayoutUtils.h" namespace mlir { namespace triton { @@ -52,7 +53,7 @@ struct OptimizeReshapeLayoutPattern : public OpRewritePattern { // dimension in the same thread we can skip. if (blocked.getThreadsPerWarp()[*reductionAxis] == 1 && blocked.getWarpsPerCTA()[*reductionAxis] == 1 && - blocked.getCTAsPerCGA()[*reductionAxis] == 1) + blocked.getCTALayout().getCTAsPerCGA()[*reductionAxis] == 1) return failure(); } ArrayRef shape = tensorType.getShape(); @@ -191,9 +192,7 @@ static LogicalResult setOptimizedGatherLayout(GatherOp op, RewriterBase &b) { // Construct the new layout. MLIRContext *ctx = srcType.getContext(); auto baseLayout = cast(srcType.getEncoding()); - auto ctaLayout = - CTALayoutAttr::get(ctx, baseLayout.getCTAsPerCGA(), - baseLayout.getCTASplitNum(), baseLayout.getCTAOrder()); + auto ctaLayout = getCTALayout(baseLayout); auto newLayout = BlockedEncodingAttr::get(ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); @@ -551,14 +550,12 @@ class TritonGPUOptimizeThreadLocalityPass auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1); auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1); auto order3d = insertValue(blocked.getOrder(), 0, rank); - auto ctasPerCGA3d = - insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1); - auto ctasSplitNum3d = - insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1); - auto ctaOrder3d = - insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank); - auto ctaLayout3d = triton::gpu::CTALayoutAttr::get( - reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d); + auto ctaLl = blocked.getCTALayout().getLinearLayout(); + auto kBlocked = *ctaLl.getInDimNames().begin(); + auto *ctx = kBlocked.getContext(); + auto dim = standardOutDimNames(ctx, rank + 1)[rank]; + ctaLl *= LinearLayout::identity1D(1, kBlocked, dim); + auto ctaLayout3d = CTAEncodingAttr::get(ctx, ctaLl); auto blocked3d = triton::gpu::BlockedEncodingAttr::get( reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d, order3d, ctaLayout3d); diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index a8b6ebe4f1..4c870e73d1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -465,7 +465,7 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule, canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32; sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( forOp.getContext(), 1, 1, 1, 
{0}, - ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0})); + ttg::CTAEncodingAttr::getDefault(forOp.getContext(), 1)); if (canUseAsyncCp) { scalarLoads.push_back(&op); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index b9374fa96d..608bdab9b2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -15,6 +15,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h" +#include "triton/Tools/LayoutUtils.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include @@ -442,9 +443,13 @@ Value mlir::triton::createScalarAlloc(ImplicitLocOpBuilder &rewriter, Type type, rewriter.getBlock()->getParentOp()->getParentOfType()); Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(rewriter.getContext()); + auto kBlock = StringAttr::get(ctx, "block"); + LinearLayout::BasesT bases; + bases[kBlock] = + std::vector>(llvm::Log2_32(numCTAs), {0}); + auto dims = standardOutDimNames(ctx, 1); auto barrierCTALayout = - ttg::CTALayoutAttr::get(/*context=*/ctx, /*CTAsPerCGA=*/{numCTAs}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + ttg::CTAEncodingAttr::get(ctx, LinearLayout(bases, dims)); auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, {0}, barrierCTALayout); ttg::MemDescType memDescType = ttg::MemDescType::get( diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index efe0b890dc..3531c0bf6d 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -127,7 +127,7 @@ class LayoutRematerialization { } void cleanup(); - void backwardRematerialization(); + bool backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); // TODO: Merge the three hoistConvert*(); functions as they are duplicate code void hoistConvertDotOperand(); @@ -1019,7 +1019,8 @@ LogicalResult LayoutRematerialization::getRematerializableSlice( return success(); } -void LayoutRematerialization::backwardRematerialization() { +bool LayoutRematerialization::backwardRematerialization() { + bool changed = false; // Go through each ConvertLayoutOp. SmallVector convertOps; funcOp.walk( @@ -1031,8 +1032,11 @@ void LayoutRematerialization::backwardRematerialization() { // backward slices. addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), convertOp.getResult()); + } else { + changed = true; } } + return changed; } void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { @@ -1593,12 +1597,14 @@ void LayoutRematerialization::hoistConvertIntoConditionals( rewriteSlice(slice, layout, convertOp, mapping); } -void backwardRematerialization(ModuleOp module) { - module.walk([](FuncOp funcOp) { +bool backwardRematerialization(ModuleOp module) { + bool changed = false; + module.walk([&](FuncOp funcOp) { LayoutRematerialization layoutRemat(funcOp); - layoutRemat.backwardRematerialization(); + changed |= layoutRemat.backwardRematerialization(); layoutRemat.cleanup(); }); + return changed; } void hoistConvert(ModuleOp module) { @@ -1659,17 +1665,20 @@ class TritonGPURemoveLayoutConversionsPass cleanupConvertOps(); - // 2. 
For remaining convert ops, try to rematerialize the slice of producer - // operation to avoid having to convert. - backwardRematerialization(m); - LLVM_DEBUG({ - DBGS() << "Module after backward remat:\n"; - m.dump(); - }); - - // Cleanup dummy converts created during backward remat. - cleanupConvertOps(); - + bool changed = false; + do { + changed = false; + // 2. For remaining convert ops, try to rematerialize the slice of + // producer operation to avoid having to convert. + changed = backwardRematerialization(m); + LLVM_DEBUG({ + DBGS() << "Module after backward remat:\n"; + m.dump(); + }); + + // Cleanup dummy converts created during backward remat. + cleanupConvertOps(); + } while (changed); // 3. For remaining converts, try to hoist them above cast generating larger // size types in order to reduce the cost of the convert op. hoistConvert(m); diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 753a73fe7f..9ab16bb97f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1083,7 +1083,7 @@ std::optional getAMDArch(Operation *module) { } inline ttg::SwizzledSharedEncodingAttr -swizzleDotOperandLike(RankedTensorType type, ttg::CTALayoutAttr ctaLayout) { +swizzleDotOperandLike(RankedTensorType type, ttg::CTAEncodingAttr ctaLayout) { // We want to see if the linear layout has the same order as an mma microtile // of shape (8, 4*kWidth) or (4*kWidth, 8). If so, we return a // DotOperandEncodingAttr with a tile of this shape This works because diff --git a/lib/Dialect/TritonInstrument/IR/Utility.cpp b/lib/Dialect/TritonInstrument/IR/Utility.cpp index e19b87776c..44c7c1fb4e 100644 --- a/lib/Dialect/TritonInstrument/IR/Utility.cpp +++ b/lib/Dialect/TritonInstrument/IR/Utility.cpp @@ -18,7 +18,7 @@ namespace { BlockedEncodingAttr getThreadLocalBlockedEncoding(MLIRContext *ctx, unsigned int size, unsigned int warps) { - auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/1); + auto ctaLayout = CTAEncodingAttr::getDefault(ctx, /*rank=*/1); return BlockedEncodingAttr::get(ctx, /*sizePerThread=*/{size}, /*threadsPerWarp=*/{32}, @@ -30,7 +30,7 @@ BlockedEncodingAttr getThreadLocalBlockedEncoding(MLIRContext *ctx, unsigned int buffers, unsigned int barriers, unsigned int warps) { - auto ctaLayout = CTALayoutAttr::getDefault(ctx, /*rank=*/2); + auto ctaLayout = CTAEncodingAttr::getDefault(ctx, /*rank=*/2); return BlockedEncodingAttr::get(ctx, /*sizePerThread=*/{buffers, barriers}, /*threadsPerWarp=*/{1, 32}, diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index d69956e4fa..0c986c8499 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -133,7 +133,8 @@ LinearLayout getTileLayout(MLIRContext *ctx, TMemAccessAtom atom, bool unpacked, static std::optional getDistributedLayoutForTmemLdSt( const LinearLayout &ll, TMemAccessAtom atom, unsigned numWarps, - int bitwidth, std::optional ctaLayout = std::nullopt) { + int bitwidth, + std::optional ctaLayout = std::nullopt) { auto dims = to_vector(ll.getOutDimNames()); assert(dims.size() == 2); auto rowColDims = to_vector(ll.getInDimNames()); @@ -142,14 +143,12 @@ static std::optional getDistributedLayoutForTmemLdSt( if (ctaLayout) { // Get CTALayout without broadcasting to divide the ll // as the TMEM layout does not reflect CTA broadcasting - auto splitNum = ctaLayout->getCTASplitNum(); - // The cta order in TMEM is always [0, 1] - 
auto ctaBlockSplit = CTALayoutAttr::get(ctx, splitNum, splitNum, {0, 1}); - auto ctaBlockSplitLL = gpu::makeCgaLayout(ctaBlockSplit); - assert(ctaBlockSplitLL.getNumOutDims() == ll.getNumOutDims()); - // rename block into col + auto cgaShape = to_vector(ctaLayout->getLinearLayout().getOutDimSizes()); auto kBlock = StringAttr::get(ctx, "block"); - auto ctaCol = ctaBlockSplitLL.renameInDim(kBlock, rowColDims[1]); + // The cta order in TMEM is always [0, 1] + auto ctaCol = + LinearLayout::identity1D(cgaShape[0], rowColDims[1], dims[0]) * + LinearLayout::identity1D(cgaShape[1], rowColDims[1], dims[1]); auto quot = divideRight(ll, ctaCol); assert(quot.has_value()); auto maybeRet = @@ -157,8 +156,7 @@ static std::optional getDistributedLayoutForTmemLdSt( if (!maybeRet) return maybeRet; // Add the full ctaBlock layout (with broadcasting) - auto ctaBlock = gpu::makeCgaLayout(*ctaLayout); - return *maybeRet * ctaBlock; + return *maybeRet * ctaLayout->getLinearLayout(); } // This code is dual to the one in lowerTMemLdSt if (bitwidth != 32) { @@ -307,7 +305,7 @@ static std::optional getDistributedLayoutForTmemLdSt( std::optional getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, unsigned numWarps, - gpu::CTALayoutAttr ctaLayout) { + gpu::CTAEncodingAttr ctaLayout) { assert(memType.getMemorySpace() == TensorMemorySpaceAttr::get(memType.getContext())); assert(numWarps >= 4 && llvm::isPowerOf2_32(numWarps) && @@ -322,7 +320,7 @@ getDistributedLayoutForTmemLdSt(gpu::MemDescType memType, TMemAccessAtom atom, DistributedEncodingTrait getDefaultLayoutForTmemLdSt(gpu::MemDescType memType, unsigned numWarps, - gpu::CTALayoutAttr ctaLayout) { + gpu::CTAEncodingAttr ctaLayout) { auto *ctx = memType.getContext(); bool prefer16x256 = triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT"); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp index eb28c20e35..63a47418b7 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp @@ -29,9 +29,7 @@ class SyncMMALowering : public OpInterfaceRewritePattern { MLIRContext *ctx = op.getContext(); Location loc = op.getLoc(); Attribute sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(ctx); - auto barrierCTALayout = ttg::CTALayoutAttr::get( - /*context=*/ctx, /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierCTALayout = ttg::CTAEncodingAttr::getDefault(ctx, 1); auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( ctx, 1, 1, 1, {0}, barrierCTALayout); ttg::MemDescType barrierMemDescType = @@ -67,8 +65,8 @@ struct TCGen5MMAScaleSharedToTmemConversion auto oldType = cast(operand.get().getType()); auto numElems = product(oldType.getShape()); Type elType = oldType.getElementType(); - ttg::CTALayoutAttr CTALayout = ttg::getCTALayout(oldType.getEncoding()); - ArrayRef CTASplitNum = CTALayout.getCTASplitNum(); + ttg::CTAEncodingAttr CTALayout = ttg::getCTALayout(oldType.getEncoding()); + auto CTASplitNum = CTALayout.getCTASplitNum(); // Distribute the scales across the rows of the MMA operation. 
SmallVector shape = {rows, numElems / rows}; Attribute scaleEncoding = TensorMemoryScalesEncodingAttr::get( diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index 7028758c98..1feab6a223 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -23,7 +23,7 @@ struct UseInfo { Operation *use; Attribute desiredSharedEncoding; SmallVector shape; - ttg::CTALayoutAttr ctaLayout; + ttg::CTAEncodingAttr ctaLayout; }; static bool isTMACompatibleEncoding(Attribute enc) { @@ -96,7 +96,7 @@ std::optional getUseInfo(Operation *op) { struct EncodingInfo { Attribute desiredEncoding; - ttg::CTALayoutAttr ctaLayout; + ttg::CTAEncodingAttr ctaLayout; // Shape may be different from the descriptor block shape for gather/scatter // use case SmallVector shape; @@ -154,7 +154,7 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs, result.shape.push_back(std::min(lhs.shape[i], rhs.shape[i])); } - SetVector ctaLayouts; + SetVector ctaLayouts; if (lhs.ctaLayout) ctaLayouts.insert(lhs.ctaLayout); if (rhs.ctaLayout) @@ -164,7 +164,7 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs, case 2: // if we find clashing CTALayouts, fallback to default result.ctaLayout = - ttg::CTALayoutAttr::getDefault(lhs.ctaLayout.getContext(), rank); + ttg::CTAEncodingAttr::getDefault(lhs.ctaLayout.getContext(), rank); break; case 1: result.ctaLayout = ctaLayouts[0]; @@ -194,7 +194,7 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs, } Attribute getFallbackSharedEncoding(RankedTensorType tensorType, - ttg::CTALayoutAttr ctaLayout, + ttg::CTAEncodingAttr ctaLayout, ArrayRef usageShape) { auto ctx = tensorType.getContext(); SmallVector order; @@ -204,7 +204,7 @@ Attribute getFallbackSharedEncoding(RankedTensorType tensorType, ArrayRef shape = usageShape.empty() ? 
tensorType.getShape() : usageShape; if (!ctaLayout) - ctaLayout = ttg::CTALayoutAttr::getDefault(ctx, tensorType.getRank()); + ctaLayout = ttg::CTAEncodingAttr::getDefault(ctx, tensorType.getRank()); else if (ctaLayout.getRank() != tensorType.getRank()) ctaLayout = updateCTALayoutForShape(ctaLayout, shape); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp index c5414a12d8..1f9fad1ba5 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp @@ -64,7 +64,7 @@ Type replaceLayout(const Type &type, const Attribute &newLayout) { ttg::DistributedEncodingTrait replaceCTALayout(ttg::DistributedEncodingTrait layout, llvm::ArrayRef shape, int numWarps, - const ttg::CTALayoutAttr &newCTALayout) { + ttg::CTAEncodingAttr newCTALayout) { if (auto blockedLayout = mlir::dyn_cast(layout)) { return ttg::BlockedEncodingAttr::get( layout.getContext(), shape, blockedLayout.getSizePerThread(), @@ -259,8 +259,8 @@ bool CTAPlanner::processDot(triton::FuncOp &funcOp) { auto numThreads = ttg::lookupThreadsPerWarp(builder); auto numWarps = ttg::lookupNumWarps(dot); - auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN}, - {splitM, splitN}, {1, 0}); + auto newCTALayout = ttg::CTAEncodingAttr::fromSplitParams( + ctx, {splitM, splitN}, {splitM, splitN}, {1, 0}); auto newDLayout = ttg::BlockedEncodingAttr::get( ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(), numWarps, numThreads, newCTALayout); @@ -326,8 +326,8 @@ bool CTAPlanner::processReduce(triton::FuncOp &funcOp) { CTAsPerCGA[order[rank - 1]] *= remainingCTAs; auto numWarps = ttg::lookupNumWarps(reduce); - auto CTALayout = - ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder); + auto CTALayout = ttg::CTAEncodingAttr::fromSplitParams( + context, CTAsPerCGA, CTASplitNum, CTAOrder); if (!tiled) markTiled(); auto newSrcLayout = @@ -356,7 +356,7 @@ void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) { assert(stores.size() > 0 && "Cannot find store-like ops"); auto numWarps = ttg::lookupNumWarps(funcOp); - ttg::CTALayoutAttr CTALayout; + ttg::CTAEncodingAttr CTALayout; for (Operation *store : stores) { auto val = [store]() -> Value { if (auto descStore = diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp index 2c42416c68..b22d1e23c4 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PromoteLHSToTMem.cpp @@ -20,7 +20,7 @@ namespace nvidia_gpu { namespace { template Attribute getLHSTMemLayout(MMAOpTy tcGen5MMAOp, gpu::MemDescType lhsTMEMType, - ttg::CTALayoutAttr ctaLayout) { + ttg::CTAEncodingAttr ctaLayout) { int numWarps = ttg::lookupNumWarps(tcGen5MMAOp); return nvidia_gpu::getDefaultLayoutForTmemLdSt(lhsTMEMType, numWarps, ctaLayout); @@ -46,8 +46,7 @@ template class LHSToTMem : public OpRewritePattern { auto srcLayout = srcType.getEncoding(); auto accTMemEncoding = dyn_cast( tcGen5MMAOp.getD().getType().getEncoding()); - ArrayRef CTASplitNum = - triton::gpu::getCTALayout(srcLayout).getCTASplitNum(); + auto CTASplitNum = triton::gpu::getCTALayout(srcLayout).getCTASplitNum(); // TMem encoding for A operand is the same as for D (Acc), but packed for // bitwidth=16 unsigned elemBitWidth = diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 
beb79fceb3..a3cf67d91e 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -36,9 +36,8 @@ lowerTMALoad(Operation *op, RankedTensorType tensorType, Value desc, sharedMemorySpace, /*mutableMemory=*/true); auto alloc = gpu::LocalAllocOp::create(rewriter, loc, memDescType).getResult(); - auto barrierCTALayout = gpu::CTALayoutAttr::get( - /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierCTALayout = + gpu::CTAEncodingAttr::getDefault(tensorType.getContext(), 1); auto barrierEncoding = gpu::SwizzledSharedEncodingAttr::get( tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); gpu::MemDescType barrierMemDescType = diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index 2d843317f9..bab1de8d4f 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; @@ -16,36 +17,41 @@ SmallVector translateTMAIndices(OpBuilder &builder, Location loc, return indices; } -ttg::CTALayoutAttr updateCTALayoutForShape(ttg::CTALayoutAttr ctaLayout, - ArrayRef shape) { +ttg::CTAEncodingAttr updateCTALayoutForShape(ttg::CTAEncodingAttr ctaLayout, + ArrayRef shape) { auto rank = shape.size(); if (ctaLayout.getRank() == rank) return ctaLayout; auto ctx = ctaLayout.getContext(); if (ctaLayout.getRank() > rank) { + auto ll = ctaLayout.getLinearLayout(); + // Broadcast over the first rankDiff dims unsigned rankDiff = ctaLayout.getRank() - rank; - return ttg::CTALayoutAttr::get( - ctx, ctaLayout.getCTAsPerCGA().drop_front(rankDiff), - ctaLayout.getCTASplitNum().drop_front(rankDiff), - ctaLayout.getCTAOrder().drop_front(rankDiff)); + for (int i = 0; i < rankDiff; ++i) { + ll = removeStandardDim(ll, 0); + } + return ttg::CTAEncodingAttr::get(ctx, ll); } // For rank-reducing loads, we need to rank-increase the CTA Layout auto rankDiff = rank - ctaLayout.getRank(); for (unsigned i = 0; i < rankDiff; ++i) { assert(shape[i] == 1 && "Should only happen for rank-reducing loads"); } - SmallVector CTAsPerCGA(rank, 1); - SmallVector CTASplitNum(rank, 1); - SmallVector CTAOrder(rank, 1); - - llvm::copy(ctaLayout.getCTAsPerCGA(), CTAsPerCGA.begin() + rankDiff); - llvm::copy(ctaLayout.getCTASplitNum(), CTASplitNum.begin() + rankDiff); - for (unsigned i = 0; i < rankDiff; ++i) { - CTAOrder[i] = rank - i; + auto ll = ctaLayout.getLinearLayout(); + auto kBlock = *ll.getInDimNames().begin(); + auto standardOuts = standardOutDimNames(ctx, rank); + // Append to front + for (int i = ctaLayout.getRank(); i < rank; ++i) { + ll = LinearLayout::identity1D(1, kBlock, standardOuts[i]) * ll; + } + // Rename out dims to dim0..dimn-1 + auto dimSizes = ll.getOutDims(); + for (auto [i, dim] : llvm::enumerate(standardOuts)) { + dimSizes[i].first = dim; } - llvm::copy(ctaLayout.getCTAOrder(), CTAOrder.begin() + rankDiff); - return ttg::CTALayoutAttr::get(ctx, CTAsPerCGA, CTASplitNum, CTAOrder); + ll = LinearLayout(ll.getBases(), dimSizes, false); + return ttg::CTAEncodingAttr::get(ctx, ll); } ttg::SharedEncodingTrait diff --git a/lib/Tools/LayoutUtils.cpp b/lib/Tools/LayoutUtils.cpp index 02cd30781a..815bf6d4b3 100644 --- a/lib/Tools/LayoutUtils.cpp +++ b/lib/Tools/LayoutUtils.cpp @@ -562,4 +562,21 @@ std::optional getReps(const LinearLayout &cvt, 
/*requireSurjective=*/false); } +LinearLayout removeStandardDim(const LinearLayout &layout, int dim) { + auto rank = layout.getNumOutDims(); + assert(rank > 0); + assert(dim < rank); + auto *ctx = layout.getOutDimNames().begin()->getContext(); + auto dims = to_vector(layout.getOutDimNames()); + assert(dims == standardOutDimNames(ctx, rank)); + dims.erase(dims.begin() + dim); + auto newLayout = layout.sublayout(to_vector(layout.getInDimNames()), dims); + auto dimSizes = newLayout.getOutDims(); + auto newDims = standardOutDimNames(ctx, rank - 1); + for (auto [i, newDim] : llvm::enumerate(newDims)) { + dimSizes[i].first = newDim; + } + return LinearLayout(newLayout.getBases(), dimSizes, /*isSurjective*/ false); +} + } // namespace mlir::triton diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index e9e6aa5386..2b2ef71aa8 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -32,6 +32,29 @@ namespace ttng = triton::nvidia_gpu; namespace gluon = mlir::triton::gluon; namespace ttag = mlir::triton::amdgpu; +static ttg::CTAEncodingAttr +buildCtaLayoutAttr(MLIRContext *ctx, + const std::vector> &layout, + unsigned rank) { + auto kBlock = StringAttr::get(ctx, "block"); + tt::LinearLayout::BasesT bases; + bases[kBlock] = layout; + auto outDims = tt::standardOutDimNames(ctx, rank); + tt::LinearLayout ll(std::move(bases), outDims); + return ttg::CTAEncodingAttr::get(ctx, std::move(ll)); +} + +static std::vector> +getCgaLayoutBases(ttg::CTAEncodingAttr layout) { + std::vector> result; + auto ctx = layout.getContext(); + auto block = StringAttr::get(ctx, "block"); + const auto &basesMap = layout.getLinearLayout().getBases(); + auto it = basesMap.find(block); + assert(it != basesMap.end()); + return it->second; +} + // Helper to check if an MLIR type or attribute has a verifier method. 
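On the Python side the Gluon bindings now exchange the CGA layout as its raw block bases instead of the three legacy arrays. The correspondence, reusing the illustrative values from above and assuming buildCtaLayoutAttr and getCgaLayoutBases are inverses for them:

    // default layout (single CTA): cgaBases = []
    // 2x2 CGA, both dims split:    cgaBases = [[0, 1], [1, 0]]
    //   buildCtaLayoutAttr(ctx, cgaBases, /*rank=*/2) rebuilds the attribute,
    //   getCgaLayoutBases(attr) extracts the same vector back.
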
template static constexpr auto hasVerifier(AttrOrType t) @@ -178,14 +201,11 @@ std::vector> toStdVector(R &&range) { py::object layoutToGluon(Attribute layout) { static GluonLayouts layouts; if (auto blocked = dyn_cast(layout)) { - auto ctaLayout = blocked.getCTALayout(); + auto cgaBases = getCgaLayoutBases(blocked.getCTALayout()); return layouts.BlockedLayout(toStdVector(blocked.getSizePerThread()), toStdVector(blocked.getThreadsPerWarp()), toStdVector(blocked.getWarpsPerCTA()), - toStdVector(blocked.getOrder()), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + toStdVector(blocked.getOrder()), cgaBases); } else if (auto sliced = dyn_cast(layout)) { return layouts.SliceLayout(sliced.getDim(), layoutToGluon(sliced.getParent())); @@ -204,30 +224,24 @@ py::object layoutToGluon(Attribute layout) { return layouts.DotOperandLayout( dotOp.getOpIdx(), layoutToGluon(dotOp.getParent()), dotOp.getKWidth()); } else if (auto mma = dyn_cast(layout)) { - auto ctaLayout = mma.getCTALayout(); + auto cgaBases = getCgaLayoutBases(mma.getCTALayout()); return layouts.NVMMADistributedLayout( std::vector{mma.getVersionMajor(), mma.getVersionMinor()}, toStdVector(mma.getWarpsPerCTA()), toStdVector(mma.getInstrShape()), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + cgaBases); } else if (auto nvmma = dyn_cast(layout)) { auto ctaLayout = nvmma.getCTALayout(); - return layouts.NVMMASharedLayout( - nvmma.getSwizzlingByteWidth(), nvmma.getElementBitWidth(), - ctaLayout.getRank(), nvmma.getTransposed(), nvmma.getFp4Padded(), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + auto cgaBases = getCgaLayoutBases(ctaLayout); + return layouts.NVMMASharedLayout(nvmma.getSwizzlingByteWidth(), + nvmma.getElementBitWidth(), + ctaLayout.getRank(), nvmma.getTransposed(), + nvmma.getFp4Padded(), cgaBases); } else if (auto swizzled = dyn_cast(layout)) { - auto ctaLayout = swizzled.getCTALayout(); + auto cgaBases = getCgaLayoutBases(swizzled.getCTALayout()); return layouts.SwizzledSharedLayout( swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(), - toStdVector(swizzled.getOrder()), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + toStdVector(swizzled.getOrder()), cgaBases); } else if (auto sharedLl = dyn_cast(layout)) { const auto &ll = sharedLl.getLinearLayout(); auto ctx = layout.getContext(); @@ -241,24 +255,19 @@ py::object layoutToGluon(Attribute layout) { } else if (auto autoEnc = dyn_cast(layout)) { return layouts.CoalescedLayout(); } else if (auto amdMfma = dyn_cast(layout)) { - auto ctaLayout = amdMfma.getCTALayout(); + auto cgaBases = getCgaLayoutBases(amdMfma.getCTALayout()); return layouts.AMDMFMALayout( amdMfma.getVersion(), toStdVector(amdMfma.getInstrShape()), amdMfma.getIsTransposed(), toStdVector(amdMfma.getWarpsPerCTA()), amdMfma.getElementBitWidth(), toStdVector(amdMfma.getTilesPerWarp()), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + cgaBases); } else if (auto amdWmma = dyn_cast(layout)) { - auto ctaLayout = amdWmma.getCTALayout(); - return layouts.AMDWMMALayout(amdWmma.getVersion(), - amdWmma.getIsTransposed(), - toStdVector(amdWmma.getWarpsPerCTA()), - toStdVector(amdWmma.getInstrShape()), - 
toStdVector(amdWmma.getTilesPerWarp()), - toStdVector(ctaLayout.getCTAsPerCGA()), - toStdVector(ctaLayout.getCTASplitNum()), - toStdVector(ctaLayout.getCTAOrder())); + auto cgaBases = getCgaLayoutBases(amdWmma.getCTALayout()); + return layouts.AMDWMMALayout( + amdWmma.getVersion(), amdWmma.getIsTransposed(), + toStdVector(amdWmma.getWarpsPerCTA()), + toStdVector(amdWmma.getInstrShape()), + toStdVector(amdWmma.getTilesPerWarp()), cgaBases); } else if (auto paddedShared = dyn_cast(layout)) { auto *ctx = paddedShared.getContext(); @@ -334,24 +343,14 @@ void init_gluon_ir(py::module &&m) { /*mutableMemory=*/true, /*allocShape=*/allocShape); }) - .def("get_cta_layout", - [](GluonOpBuilder &self, std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder) -> Attribute { - auto ctx = self.getContext(); - return self.getChecked(ctx, ctasPerCga, - ctaSplitNum, ctaOrder); - }) .def("get_blocked_layout", [](GluonOpBuilder &self, std::vector &sizePerThread, std::vector &threadsPerWarp, std::vector &warpsPerCta, std::vector &order, - std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder) -> Attribute { + std::vector> &cgaBases) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + unsigned rank = order.size(); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return self.getChecked( ctx, sizePerThread, threadsPerWarp, warpsPerCta, order, ctaLayout); @@ -400,13 +399,11 @@ void init_gluon_ir(py::module &&m) { .def("get_mma_layout", [](GluonOpBuilder &self, std::vector &version, std::vector &warpsPerCta, - std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder, + std::vector> &cgaBases, std::vector &instrShape) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + unsigned rank = warpsPerCta.size(); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return self.getChecked( ctx, version[0], version[1], warpsPerCta, ctaLayout, instrShape); @@ -415,14 +412,12 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder &self, unsigned version, std::vector &warpsPerCta, std::vector &instrShape, bool transposed, - std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder, + std::vector> &cgaBases, std::vector &tilesPerWarp, unsigned elementBitWidth) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + unsigned rank = warpsPerCta.size(); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return ttg::AMDMfmaEncodingAttr::get( ctx, version, warpsPerCta, instrShape, transposed, ctaLayout, tilesPerWarp, elementBitWidth); @@ -431,13 +426,11 @@ void init_gluon_ir(py::module &&m) { [](GluonOpBuilder &self, unsigned version, bool transposed, std::vector &warpsPerCta, std::vector &tilesPerWarp, - std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder, + std::vector> &cgaBases, std::vector &instrShape) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + unsigned rank = warpsPerCta.size(); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return ttg::AMDWmmaEncodingAttr::get(ctx, version, transposed, warpsPerCta, tilesPerWarp, ctaLayout, instrShape); @@ -485,12 +478,10 @@ void init_gluon_ir(py::module &&m) { .def("get_nvmma_shared_layout", [](GluonOpBuilder &self, unsigned swizzleByteWidth, unsigned 
elementBitwidth, bool transposed, bool fp4Padded, - std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder) -> Attribute { + std::vector> &cgaBases, + unsigned rank) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return self.getChecked( ctx, swizzleByteWidth, transposed, elementBitwidth, fp4Padded, ctaLayout); @@ -506,12 +497,11 @@ void init_gluon_ir(py::module &&m) { }) .def("get_swizzled_shared_layout", [](GluonOpBuilder &self, int vec, int perPhase, int maxPhase, - std::vector &order, std::vector &ctasPerCga, - std::vector &ctaSplitNum, - std::vector &ctaOrder) -> Attribute { + std::vector &order, + std::vector> &cgaBases) -> Attribute { auto ctx = self.getContext(); - auto ctaLayout = self.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + unsigned rank = order.size(); + auto ctaLayout = buildCtaLayoutAttr(ctx, cgaBases, rank); return self.getChecked( ctx, vec, perPhase, maxPhase, order, ctaLayout); }) @@ -903,8 +893,7 @@ void init_gluon_ir(py::module &&m) { "compute_tmem_reg_layout", [](py::object elementTyObj, std::vector shape, py::object layoutObj, unsigned numWarps, const std::string &atomName, - std::vector ctasPerCga, std::vector ctaSplitNum, - std::vector ctaOrder) -> py::object { + std::vector> cgaBases) -> py::object { DialectRegistry registry; registry.insert(); @@ -922,12 +911,12 @@ void init_gluon_ir(py::module &&m) { auto allocShape = shape; auto ctx = builder.getContext(); + unsigned rank = shape.size(); auto memDescTy = builder.getChecked( shape, elementType, layoutAttr, ttng::TensorMemorySpaceAttr::get(ctx), /*mutableMemory=*/true, allocShape); - auto ctaLayoutAttr = builder.getChecked( - ctx, ctasPerCga, ctaSplitNum, ctaOrder); + auto ctaLayoutAttr = buildCtaLayoutAttr(ctx, cgaBases, rank); auto maybeAtom = llvm::StringSwitch>(atomName) @@ -957,6 +946,20 @@ void init_gluon_ir(py::module &&m) { return layoutToGluon(attr); }); + m.def( + "make_cga_layout", + [](std::vector ctasPerCga, std::vector ctaSplitNum, + std::vector ctaOrder) -> std::vector> { + DialectRegistry registry; + registry.insert(); + MLIRContext ctx(MLIRContext::Threading::DISABLED); + ctx.appendDialectRegistry(registry); + ctx.loadAllAvailableDialects(); + auto attr = ttg::CTAEncodingAttr::fromSplitParams( + &ctx, ctasPerCga, ctaSplitNum, ctaOrder); + return getCgaLayoutBases(attr); + }); + m.def("get_amd_mfma_scale_layout", [](unsigned opIdx, std::vector &shape, unsigned mfmaMDim, std::vector &tilesPerWarp, diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index 1eae27ae5c..19ed4e08f6 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -359,7 +359,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): @gluon.jit def kernel(input_desc, FAILURE: ttgl.constexpr, MEM_ACCESS_KIND: ttgl.constexpr): - acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1, cta_split_num=[1, 1]) + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1]) smemA = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout) @@ -518,7 +518,7 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]): @gluon.jit def kernel(input_desc, BUF_IDX: ttgl.constexpr, 
BAR_IDX: ttgl.constexpr): - acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1, cta_split_num=[1, 1]) + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1]) smemA = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout) @@ -581,7 +581,7 @@ def kernel(input_desc, FAILURE: ttgl.constexpr): num_buffers: ttgl.constexpr = 2 if FAILURE else 3 num_mma_stages: ttgl.constexpr = 2 - acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1, cta_split_num=[1, 1]) + acc_layout: ttgl.constexpr = blackwell.TensorMemoryLayout([XBLOCK, XBLOCK], col_stride=1) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[1, XBLOCK], threads_per_warp=[32, 1], warps_per_cta=[4, 1], order=[0, 1]) zero = ttgl.zeros([XBLOCK, XBLOCK], ttgl.float32, blocked_layout) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index eff5c3d73d..1f08824b13 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -75,7 +75,7 @@ def test_copy_kernel(layout, XBLOCK, device): def test_copy_kernel_multi_cta(): XBLOCK = 2048 layout = ttgl.BlockedLayout(size_per_thread=[8], threads_per_warp=[THREADS_PER_WARP], warps_per_cta=[8], order=[0], - ctas_per_cga=[2], cta_split_num=[2]) + cga_layout=[[1]]) inp = torch.randn(XBLOCK * 4 - 7, device="cuda") out = torch.empty_like(inp) @@ -233,18 +233,27 @@ def is_two_ctas(layout_a: ttgl.constexpr, layout_b: ttgl.constexpr) -> ttgl.cons if isinstance(layout_a, TensorMemoryLayout): return layout_a.two_ctas + # TODO Implement as a helper def has_cta_split(layout, cta_split_num): - if hasattr(layout, "cta_split_num"): - return layout.cta_split_num == cta_split_num - else: - # Super hacky - assert isinstance(layout, ttgl.SharedLinearLayout) - max_stride = [0, 0] - for b in itertools.chain(layout.offset_bases, layout.block_bases): - for i, bi in enumerate(b): - max_stride[i] = max(max_stride[i], bi) - basis = [max_stride[0], 0] if cta_split_num == [2, 1] else [0, max_stride[1]] - return len(layout.block_bases) == 1 and layout.block_bases[0] == basis + if hasattr(layout, "cga_layout"): + if not layout.cga_layout: + return cta_split_num == [1, 1] + rank = layout.rank + derived_split = [1] * rank + for basis in layout.cga_layout: + idx = next((i for i, v in enumerate(basis) if v != 0), None) + if idx is not None and idx < rank: + derived_split[idx] *= 2 + return derived_split == cta_split_num + + # Fallback for SharedLinearLayout + assert isinstance(layout, ttgl.SharedLinearLayout) + max_stride = [0, 0] + for b in itertools.chain(layout.offset_bases, layout.block_bases): + for i, bi in enumerate(b): + max_stride[i] = max(max_stride[i], bi) + basis = [max_stride[0], 0] if cta_split_num == [2, 1] else [0, max_stride[1]] + return len(layout.block_bases) == 1 and layout.block_bases[0] == basis return has_cta_split(layout_a, [2, 1]) and has_cta_split(layout_b, [1, 2]) @@ -253,7 +262,7 @@ def has_cta_split(layout, cta_split_num): def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout_a: ttgl.constexpr, block_layout_b: ttgl.constexpr, block_layout_c: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr, acc_dtype: ttgl.constexpr, - ASYNC: ttgl.constexpr, USE_TCGEN05: ttgl.constexpr): + 
ASYNC: ttgl.constexpr, USE_TCGEN05: ttgl.constexpr, mma_barrier_layout: ttgl.constexpr = None): a_offs_m = ttgl.arange(0, M)[:, None] a_offs_k = ttgl.arange(0, K)[None, :] b_offs_k = ttgl.arange(0, K)[:, None] @@ -273,16 +282,11 @@ def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexp fence_async_shared(cluster=two_ctas) if USE_TCGEN05: - tmem_shape: ttgl.constexpr = (min(M // mma_layout.cta_split_num[0], 128), N // mma_layout.cta_split_num[1]) - tmem_layout: ttgl.constexpr = TensorMemoryLayout(tmem_shape, col_stride=32 // acc_dtype.primitive_bitwidth, - cta_split_num=mma_layout.cta_split_num, two_ctas=two_ctas) - - # The layout of this mbarrier seems to be irrelevant. We might want to change the API to just acacept num_ctas - mma_barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], - mbarrier.MBarrierLayout(ctas_per_cga=(2 if two_ctas else 1))) + assert mma_barrier_layout is not None, "Expected an mbarrier layout for TCGen05 MMA execution" + mma_barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], mma_barrier_layout) mbarrier.init(mma_barrier, count=1) - acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], tmem_layout) + acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], mma_layout) tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False, mbarriers=[mma_barrier]) mbarrier.wait(mma_barrier, phase=0) @@ -291,11 +295,9 @@ def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexp tmem_reg_layout: ttgl.constexpr = get_tmem_reg_layout( acc_dtype, (M, N), - tmem_layout, + mma_layout, num_warps=ttgl.num_warps(), - ctas_per_cga=mma_layout.ctas_per_cga, - cta_split_num=mma_layout.cta_split_num, - cta_order=mma_layout.cta_order, + cga_layout=block_layout_c.cga_layout, ) acc = acc_tmem.load(tmem_reg_layout) else: @@ -469,27 +471,46 @@ def transpose_bases(bases): gl_acc_dtype = acc_dtype_map[acc_dtype] out_dtype = torch.float32 cta_order = [1, 0] - block_layout_a = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0], - ctas_per_cga=ctas_per_cga, cta_split_num=cta_split_a, cta_order=cta_order) + + # TODO Remove this function altogether + from triton._C.libtriton.gluon_ir import make_cga_layout + cga_layout_a = make_cga_layout(ctas_per_cga, cta_split_a, cta_order) + cga_layout_b = make_cga_layout(ctas_per_cga_b, cta_split_b, cta_order) + cga_layout_c = make_cga_layout(ctas_per_cga, ctas_per_cga, cta_order) + block_layout_a = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[0, 1], + cga_layout=cga_layout_a) block_layout_b = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0], - ctas_per_cga=ctas_per_cga_b, cta_split_num=cta_split_b, cta_order=cta_order) + cga_layout=cga_layout_b) if swizzling_a == 0: shared_layout_a = get_shared_swizzling_zero(M, K, transpose_a, ctas_per_cga, cta_split_a, cta_order) else: shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2, - transposed=transpose_a, ctas_per_cga=ctas_per_cga, - cta_split_num=cta_split_a, cta_order=cta_order) + transposed=transpose_a, cga_layout=cga_layout_a) if swizzling_b == 0: shared_layout_b = get_shared_swizzling_zero(K, N, transpose_b, ctas_per_cga_b, cta_split_b, cta_order) else: shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2, - transposed=transpose_b, ctas_per_cga=ctas_per_cga_b, - cta_split_num=cta_split_b, cta_order=cta_order) - mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], 
warps_per_cta=warps, instr_shape=instr_shape, - ctas_per_cga=ctas_per_cga, cta_split_num=ctas_per_cga, cta_order=cta_order) + transposed=transpose_b, cga_layout=cga_layout_b) + if use_tcgen05: + tmem_shape = (min(M // ctas_per_cga[0], 128), N // ctas_per_cga[1]) + mma_layout = TensorMemoryLayout(tmem_shape, col_stride=32 // torch.finfo(acc_dtype).bits, + cta_split_num=tuple(ctas_per_cga), two_ctas=two_ctas) + else: + mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=instr_shape, + cga_layout=cga_layout_c) block_layout_c = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0], - ctas_per_cga=ctas_per_cga, cta_split_num=ctas_per_cga, cta_order=cta_order) + cga_layout=cga_layout_c) + num_ctas = ctas_per_cga[0] * ctas_per_cga[1] + mma_barrier_layout = None + if use_tcgen05: + # The layout of this mbarrier seems to be irrelevant right now + # We might want to change the API here + barrier_cga_layout = [] + if two_ctas: + barrier_cga_layout.append([0]) + barrier_cga_layout.extend([2**i] for i in range(num_ctas // (2 if two_ctas else 1))) + mma_barrier_layout = mbarrier.MBarrierLayout(cga_layout=barrier_cga_layout) torch.manual_seed(0) def cast(x, dtype): @@ -523,8 +544,9 @@ def cast(x, dtype): gl_acc_dtype, False, use_tcgen05, + mma_barrier_layout, num_warps=warps[0] * warps[1], - num_ctas=ctas_per_cga[0] * ctas_per_cga[1], + num_ctas=num_ctas, ) except OutOfResources: # FIXME: Compute a priori diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index f2110aa200..204cb10e0c 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -388,7 +388,7 @@ def shared_memory_cast_kernel(): anchor_noinline(perm) layout_b: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, transposed=False, element_bitwidth=16, - rank=4, cta_order=[3, 2, 1, 0]) + rank=4) smem = ttgl.allocate_shared_memory(ttgl.float16, [32, 1, 4, 64], layout_b) smem.reshape((128, 64)) @@ -412,13 +412,13 @@ def test_shared_memory_cast(target): %c0_i32 = arith.constant 0 : i32 %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable> %2 = ttg.memdesc_trans %1 {order = array} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable> - tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> () + tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS[128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> () %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable> %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable> tt.return } - tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} { + tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False__NVMMALAS[128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} { tt.return } } @@ -959,10 +959,10 @@ 
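
The block layouts in the hunk above are now built from the temporary make_cga_layout binding (flagged "TODO Remove this function altogether") instead of the legacy triple. A minimal sketch of that conversion with hypothetical 2-CTA parameters; the exact basis values are whatever CTAEncodingAttr::fromSplitParams produces, the comments show the expected results based on the test conversions elsewhere in this diff.

from triton._C.libtriton.gluon_ir import make_cga_layout

# Hypothetical values: a 2-CTA CGA split along dim 0 for operand A and along
# dim 1 for operand B, in the style of the parameters used above.
cga_layout_a = make_cga_layout([2, 1], [2, 1], [1, 0])  # expected [[1, 0]]
cga_layout_b = make_cga_layout([1, 2], [1, 2], [1, 0])  # expected [[0, 1]]
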
def kernel(): module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { tt.func public @kernel() attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<32x32xi32, #shared, #smem, mutable> - tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> () + tt.call @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0__SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), cga_layout=__)_"(%0) : (!ttg.memdesc<32x32xi32, #shared, #smem, mutable>) -> () tt.return } - tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0_1_1_1_1_1_0_SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} { + tt.func private @"test_frontend.smem_and_layout_user__MDi32S32_32SLSSS_1_1_1_1_0__SSSLAS[32, 32]ASMD__(1,)cconstexpr_SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=(1, 0), cga_layout=__)_"(%arg0: !ttg.memdesc<32x32xi32, #shared, #smem, mutable>) attributes {noinline = false} { tt.return } } @@ -1200,10 +1200,10 @@ def test_reduce(target): %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> %cst_1 = arith.constant 2.000000e+00 : f32 %cst_2 = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> - %0 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> - %1 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> - %2 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cNone_(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> f32 - %3 = tt.call @"triton.language.standard.max__fp32S16SLSL0_B1_1B1_32B4_1B1_0B1_1B1_1B1_0BSLL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%0) : (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 + %0 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> + %1 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + %2 = tt.call @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cNone_(2,)cconstexpr_False__(3,)cNone"(%cst_0) : (tensor<16x16xf32, #blocked>) -> f32 + %3 = tt.call @"triton.language.standard.max__fp32S16SLSL0_B1_1_1_32_4_1_1_0_BSLL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%0) : (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> 
f32 %4 = ttg.convert_layout %1 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> %5:2 = "tt.reduce"(%cst_0, %cst_2) <{axis = 0 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): @@ -1220,7 +1220,7 @@ def test_reduce(target): tt.store %12, %9 : tensor<16x!tt.ptr, #ttg.slice<{dim = 0, parent = #blocked}>> tt.return } - tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> attributes {noinline = false} { + tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> attributes {noinline = false} { %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 @@ -1238,7 +1238,7 @@ def test_reduce(target): %1 = ub.poison : f32 tt.return %1 : f32 } - tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> attributes {noinline = false} { + tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cconstexpr_1__(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> attributes {noinline = false} { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = tt.call @triton.language.standard._sum_combine__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 @@ -1249,7 +1249,7 @@ def test_reduce(target): %1 = ub.poison : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> tt.return %1 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> } - tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1B1_32B4_1B1_0B1_1B1_1B1_0BL__(1,)cNone_(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> f32 attributes {noinline = false} { + tt.func private @"triton.language.standard.sum__fp32S16_16SLB1_1_1_32_4_1_1_0_BL__(1,)cNone_(2,)cconstexpr_False__(3,)cNone"(%arg0: tensor<16x16xf32, #blocked>) -> f32 attributes {noinline = false} { %0 = tt.reshape %arg0 : tensor<16x16xf32, #blocked> -> tensor<256xf32, #linear> %1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): @@ -1261,7 +1261,7 @@ def test_reduce(target): %2 = ub.poison : f32 tt.return %2 : f32 } - tt.func private @"triton.language.standard.max__fp32S16SLSL0_B1_1B1_32B4_1B1_0B1_1B1_1B1_0BSLL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%arg0: tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 attributes {noinline = false} { + tt.func private @"triton.language.standard.max__fp32S16SLSL0_B1_1_1_32_4_1_1_0_BSLL__(1,)cconstexpr_0__(2,)cconstexpr_False__(3,)cconstexpr_True__(4,)cconstexpr_False_"(%arg0: tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 attributes {noinline = false} { %0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): %2 = tt.call @triton.language.standard._elementwise_max__fp32_fp32__(%arg1, %arg2) : (f32, f32) -> f32 @@ -1366,7 +1366,7 @@ def test_tensor_permute(): 
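
The mangled specialization names in the expected IR above change because BlockedLayout.mangle (updated further down in this diff) no longer emits the CTA triple. For the blocked layout used in this test (reconstructed from the mangled name) the difference looks like this sketch:

from triton.experimental.gluon import language as ttgl

layout = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
# old scheme: "B1_1B1_32B4_1B1_0B1_1B1_1B1_0B"  (size/threads/warps/order/CTA triple)
# new scheme: "B1_1_1_32_4_1_1_0_B"             (size/threads/warps/order, empty cga_layout)
assert layout.mangle() == "B1_1_1_32_4_1_1_0_B"
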
def test_split_join(): # CHECK: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> # CHECK: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> - layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0], [1], [1], [0]) + layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0]) a = ttgl.full([128], 1, ttgl.int32, layout) b = ttgl.full([128], 2, ttgl.int32, layout) # CHECK: tt.join {{.*}} : tensor<128xi32, [[BLOCKED]]> -> tensor<128x2xi32, [[BLOCKED1]]> @@ -1464,7 +1464,7 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts: "layout, expected", [ ( - ttgl.BlockedLayout([1], [4], [4], [0], [1], [1], [0]), + ttgl.BlockedLayout([1], [4], [4], [0]), ttgl.DistributedLinearLayout( reg_bases=[], lane_bases=[[1], [2]], @@ -1474,7 +1474,7 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts: ), ), ( - ttgl.BlockedLayout([1], [4], [4], [0], [4], [2], [0]), + ttgl.BlockedLayout([1], [4], [4], [0], [[1], [0]]), ttgl.DistributedLinearLayout( reg_bases=[], lane_bases=[[1], [2]], @@ -1484,7 +1484,7 @@ def kernel(reg_type: ttgl.constexpr, shared_type: ttgl.constexpr, ref_conflicts: ), ), ( - ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [1, 2], [1, 2], [1, 0]), + ttgl.BlockedLayout([8, 1], [8, 4], [1, 4], [0, 1], [[0, 1]]), ttgl.DistributedLinearLayout( reg_bases=[[1, 0], [2, 0], [4, 0], [0, 16], [0, 32]], lane_bases=[[8, 0], [16, 0], [32, 0], [0, 1], [0, 2]], @@ -1790,20 +1790,17 @@ def amd_mfma_layout_kernel(): 8], transposed=True, warps_per_cta=[4, 1], tiles_per_warp=[2, 2])) - ttgl.full([128, 32], 0, ttgl.float32, - layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32, 8], transposed=True, # - warps_per_cta=[4, 1], tiles_per_warp=[1, 1], # - ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])) + ttgl.full([128, 32], 0, ttgl.float32, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32, + 8], transposed=True, + warps_per_cta=[4, 1], tiles_per_warp=[1, 1])) ttgl.full([128, 32], 0, ttgl.float64, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16, 16], transposed=True, # - warps_per_cta=[4, 1], element_bitwidth=64, tiles_per_warp=[1, 1], # - ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])) + warps_per_cta=[4, 1], element_bitwidth=64, tiles_per_warp=[1, 1])) ttgl.full([128, 32], 0, ttgl.int32, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16, 16], transposed=True, # - warps_per_cta=[4, 1], element_bitwidth=32, # - ctas_per_cga=[1, 1], cta_split_num=[1, 1], tiles_per_warp=[1, 1])) + warps_per_cta=[4, 1], element_bitwidth=32, tiles_per_warp=[1, 1])) @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) @@ -1842,8 +1839,7 @@ def add_int(a, b): @gluon.jit def infer_layout_for_amd_mfma_kernel(): layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32, 8], transposed=True, - warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[1, 0]) + warps_per_cta=[4, 1]) a = ttgl.full([128, 32], 1, ttgl.int32, layout) b = ttgl.reduce(a, 1, add_int) ttgl.static_assert(b.type.layout == ttgl.SliceLayout(1, layout)) diff --git a/python/test/gluon/test_lowerings.py b/python/test/gluon/test_lowerings.py index 724ef244cf..e10af3db0a 100644 --- a/python/test/gluon/test_lowerings.py +++ b/python/test/gluon/test_lowerings.py @@ -113,12 +113,10 @@ def _reduce_layouts(): # FIXME: Do not enable 
these tests until the SLPVectorizor problem with nvptx target has been resolved # SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 1, 4], [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2])), # SliceLayout(dim=0, parent=BlockedLayout([1, 4, 1], [1, 8, THREADS_PER_WARP // 8], [1, 4, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2])), - ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[1, 0], instr_shape=[16, 16, 16]), + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), + ttgl.BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 16, 16]), ttgl.amd.AMDMFMALayout(version=1, instr_shape=[32, 32, 8], transposed=True, warps_per_cta=[1, 4]), ttgl.amd.AMDMFMALayout(version=2, instr_shape=[32, 32, 8], transposed=True, warps_per_cta=[1, 4]), ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32, 8], transposed=True, warps_per_cta=[1, 4]), @@ -128,22 +126,17 @@ def _reduce_layouts(): ttgl.intel.IntelDPASLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, warps_per_cta=[4, 1], rep_cluster=[1, 1], threads_per_warp=32), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], instr_shape=[16, 8]), operand_index=1, k_width=8), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], instr_shape=[16, 32, 16]), operand_index=0, k_width=2), ttgl.SliceLayout( - dim=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], - cta_split_num=[1, 1, 1], cta_order=[2, 1, - 0], instr_shape=[1, 16, 8])), + dim=0, parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], instr_shape=[1, 16, 8])), ttgl.SliceLayout( dim=1, parent=ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], - cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], - instr_shape=[1, 16, 8]), operand_index=1, k_width=2)), + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], instr_shape=[1, 16, 8]), + operand_index=1, k_width=2)), ]) rets = [] @@ -228,10 +221,9 @@ def kernel(x_ptr, z_ptr, M: ttgl.constexpr, N: ttgl.constexpr, layout: ttgl.cons @pytest.mark.parametrize( "src_layout", _filter_layouts([ - ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), + 
ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]), + ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), ])) def test_store_layouts(M, src_layout, device): @@ -251,20 +243,15 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, layout: ttgl.constexpr): _1d_layouts = _filter_layouts([ - ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[1, 0], instr_shape=[16, 32, 16]), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), - operand_index=0, k_width=2), + ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]), + ttgl.BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 2], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]), operand_index=0, k_width=2), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 2], instr_shape=[16, 8]), + operand_index=0, k_width=2), ]) @@ -342,39 +329,27 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, src_layout: ttgl.constexpr, dst_layo _2d_layouts = _filter_layouts([ ttgl.BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1]), ttgl.BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[1, 0], instr_shape=[16, 32, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]), operand_index=0, k_width=2), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 32, 16]), + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16]), operand_index=0, k_width=1), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[1, 0], instr_shape=[16, 8]), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 8]), - operand_index=1, k_width=2), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 2], 
ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 8]), - operand_index=0, k_width=2), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 8]), - operand_index=0, k_width=8), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=1, k_width=2), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 2], instr_shape=[16, 8]), + operand_index=0, k_width=2), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=0, k_width=8), ttgl.SliceLayout( dim=1, parent=ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], - cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[16, 32, 16]), + parent=ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1, 1], instr_shape=[16, 32, 16]), operand_index=0, k_width=2)), ttgl.SliceLayout( dim=1, parent=ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], - cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16, 8]), + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], instr_shape=[1, 16, 8]), operand_index=1, k_width=2)), ]) @@ -467,54 +442,38 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl. _mma_pairs = [ # MMA v2.0 layouts [ - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), ], [ - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 8], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[8, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 8], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[8, 2], instr_shape=[16, 8]), ], # MMA v2.1 layouts [ - ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[1, 4], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[4, 1], instr_shape=[16, 8]), ], [ - ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[2, 8], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[8, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), + 
ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[2, 8], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[2, 1], warps_per_cta=[8, 2], instr_shape=[16, 8]), ], # MMA v3.0 layouts [ - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 32, 32]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 64, 32]), ], [ - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 32, 32]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[1, 4], instr_shape=[16, 32, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 64, 32]), ], [ - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[2, 8], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 32]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 32, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[2, 8], instr_shape=[16, 64, 32]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 2], instr_shape=[16, 32, 32]), ], [ - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 128, 16]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 128, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 64, 16]), ], # AMD MFMA v1 layouts [ @@ -665,37 +624,23 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl. 
_ld_st_dot_layouts = _filter_layouts([ - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 8]), - operand_index=0, k_width=4), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - operand_index=1, k_width=4), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - operand_index=0, k_width=2), - ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 8]), - operand_index=1, k_width=2), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=0, k_width=4), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=1, k_width=4), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=0, k_width=2), + ttgl.DotOperandLayout(parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1], instr_shape=[16, 8]), + operand_index=1, k_width=2), ]) _ld_st_mma_layouts = _filter_layouts([ - ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 8]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 128, 16]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 128, 16]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 16]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 128, 16]), - ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], - cta_order=[0, 1], instr_shape=[16, 64, 16]), + ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[1, 4], instr_shape=[16, 8]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 128, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 2], instr_shape=[16, 128, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 2], instr_shape=[16, 64, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 1], instr_shape=[16, 128, 16]), + ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[8, 4], instr_shape=[16, 64, 16]), ]) _ld_st_shared_layouts = _filter_layouts([ @@ -793,11 +738,10 @@ def _assert_close(actual, expected): _ld_st_3d_layouts = _filter_layouts([ - ttgl.BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), - ttgl.BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + ttgl.BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0]), + ttgl.BlockedLayout([1, 1, 
4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0]), ttgl.DotOperandLayout( - parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], ctas_per_cga=[1, 1, 1], - cta_split_num=[1, 1, 1], cta_order=[2, 1, 0], instr_shape=[1, 16, 8]), + parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[4, 1, 1], instr_shape=[1, 16, 8]), operand_index=0, k_width=1), ]) diff --git a/python/test/unit/intel/test_block_io.py b/python/test/unit/intel/test_block_io.py index 2f3a2f9fdb..f38506c500 100644 --- a/python/test/unit/intel/test_block_io.py +++ b/python/test/unit/intel/test_block_io.py @@ -53,18 +53,14 @@ def __str__(self): class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1]): + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): self.sz_per_thread = size_per_thread self.threads_per_warp = threads_per_warp self.warps_per_cta = warps_per_cta self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order def __str__(self): - return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" def warps_per_cta(layout): @@ -75,7 +71,7 @@ def warps_per_cta(layout): layouts = [ - BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0]), # DPAS layout DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16, warps_per_cta=[1, 4], rep_cluster=[1, 2]), @@ -110,8 +106,7 @@ def warps_per_cta(layout): parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1), # Slice layout - SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0], [1, 1, 1], [1, 1, 1], - [0, 1, 2])), + SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0])), ] @@ -136,7 +131,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, tran block_io = "\"column_major\"" if transpose else "\"row_major\"" strides = "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]" - + #breakpoint() + print(layout) if load_block_ptr: load_ops = f""" %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {strides}, [%c0_i32, %c0_i32] {{order = array}} : > diff --git a/python/test/unit/intel/test_core.py b/python/test/unit/intel/test_core.py index d3a6bf6c98..41073bb243 100644 --- a/python/test/unit/intel/test_core.py +++ b/python/test/unit/intel/test_core.py @@ -62,33 +62,26 @@ def __str__(self): class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1], - cta_split_num=[1, 1], cta_order=[0, 1]): + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): self.sz_per_thread = size_per_thread self.threads_per_warp = threads_per_warp self.warps_per_cta = warps_per_cta self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order def __str__(self): - return 
f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" class SwizzledSharedLayout: - def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): + def __init__(self, vec, per_phase, max_phase, order): self.vec = vec self.per_phase = per_phase self.max_phase = max_phase self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order def __str__(self): - return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>" class PaddedSharedLayout: @@ -172,17 +165,17 @@ def get_reduce_input(dtype_str, shape): scan_layouts = [ - BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]), + BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), + BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]), + BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]), + BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]), + BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0]), ] @@ -254,8 +247,8 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa layouts = [ - BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]), DpasLayout(repeatCount=8, systolic_depth=8, 
execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32, @@ -305,8 +298,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov store_range = "%7" if axis == 0 else "%1" warps = warps_per_cta(src_layout, [M, N]) num_warps = int(np.prod(warps)) - blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1]) - one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0]) + blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1]) + one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0]) expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1" other_axis = 1 - axis @@ -397,8 +390,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov layouts = [ - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]), ] @@ -443,8 +436,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path): layouts = [ - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]) ] @@ -532,10 +525,10 @@ def test_convert1d_bool(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp layouts = [ - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]) + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]), + BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0]), + BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]) ] @@ -611,8 +604,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli # TODO: backend should be tested separately layouts = [ - BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1]), + BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]), DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, warps_per_cta=[4, 1], rep_cluster=[1, 1]), DpasLayout(repeatCount=2, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, @@ -621,10 +614,10 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli intermediate_layouts 
= [ None, - SwizzledSharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]), - SwizzledSharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), - SwizzledSharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), - SwizzledSharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]), + SwizzledSharedLayout(1, 1, 1, [0, 1]), + SwizzledSharedLayout(1, 1, 1, [1, 0]), + SwizzledSharedLayout(4, 2, 4, [1, 0]), + SwizzledSharedLayout(2, 2, 4, [1, 0]), ] @@ -736,15 +729,15 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t layouts_3d = [ - BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), - BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0]), + BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0]), ] shared_layouts_3d = [ - SwizzledSharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), - SwizzledSharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]), - SwizzledSharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), - SwizzledSharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]), + SwizzledSharedLayout(1, 1, 1, [2, 1, 0]), + SwizzledSharedLayout(4, 2, 4, [1, 2, 0]), + SwizzledSharedLayout(8, 2, 4, [0, 2, 1]), + SwizzledSharedLayout(4, 2, 1, [2, 0, 1]), ] @@ -841,9 +834,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path: ] shared_layouts = [ - SwizzledSharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]), - SwizzledSharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]), - SwizzledSharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]), + SwizzledSharedLayout(4, 2, 4, [0, 1]), + SwizzledSharedLayout(8, 1, 8, [1, 0]), + SwizzledSharedLayout(16, 1, 16, [1, 0]), ] @@ -855,7 +848,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib num_repeats_N = triton.cdiv(N, N_tile_size) ir = f""" - #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}> + #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0]}}> #shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}> #smem = #ttg.shared_memory @@ -989,7 +982,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t ] shared_layouts = [ - SwizzledSharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]), + SwizzledSharedLayout(8, 1, 1, [1, 0]), ] diff --git a/python/triton/experimental/gluon/language/_layouts.py b/python/triton/experimental/gluon/language/_layouts.py index 0d72214b05..7f5a2c4002 100644 --- a/python/triton/experimental/gluon/language/_layouts.py +++ b/python/triton/experimental/gluon/language/_layouts.py @@ -1,22 +1,11 @@ from dataclasses import dataclass, field -from typing import List, Optional +from typing import List + from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type from triton.runtime.jit import constexpr_function import math -def _realize_cta_layout(layout, rank): - ctas_per_cga = layout.ctas_per_cga or [1] * rank - cta_split_num = layout.cta_split_num or [1] * rank - cta_order = layout.cta_order or list(reversed(range(rank))) - # Canonicalize CTA order to [n,n-1,...,0] if CTAsPerCGA is [1...1]. 
This matches logic in C++. - if all(num_cta == 1 for num_cta in ctas_per_cga): - cta_order = list(range(rank - 1, -1, -1)) - object.__setattr__(layout, "ctas_per_cga", ctas_per_cga) - object.__setattr__(layout, "cta_split_num", cta_split_num) - object.__setattr__(layout, "cta_order", cta_order) - - class DistributedLayout: """ Base class for distributed memory layouts in Gluon IR. @@ -28,7 +17,7 @@ def type(self): @property def rank(self): - return len(self.cta_order) + raise NotImplementedError("DistributedLayout subclasses must define rank") @dataclass(frozen=True) @@ -69,35 +58,25 @@ class BlockedLayout(DistributedLayout): threads_per_warp (List[int]): Number of threads per warp per dimension. warps_per_cta (List[int]): Number of warps per CTA per dimension. order (List[int]): The ordering of dimensions for partitioning. - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): Ordering for CTAs. + cga_layout (Optional[List[List[int]]]): Bases describing how CTAs tile each dimension. """ size_per_thread: List[int] threads_per_warp: List[int] warps_per_cta: List[int] order: List[int] - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("size_per_thread", _unwrap_if_constexpr(self.size_per_thread)) super().__setattr__("threads_per_warp", _unwrap_if_constexpr(self.threads_per_warp)) super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) super().__setattr__("order", _unwrap_if_constexpr(self.order)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) rank = len(self.size_per_thread) - _realize_cta_layout(self, rank) + object.__setattr__(self, "cga_layout", self.cga_layout) assert len(self.threads_per_warp) == rank assert len(self.warps_per_cta) == rank assert len(self.order) == rank - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank def _to_ir(self, builder): return builder.get_blocked_layout( @@ -105,9 +84,7 @@ def _to_ir(self, builder): self.threads_per_warp, self.warps_per_cta, self.order, - self.ctas_per_cga, - self.cta_split_num, - self.cta_order, + self.cga_layout, ) def mangle(self) -> str: @@ -121,21 +98,16 @@ def stringify(x): threads_per_warp = stringify(self.threads_per_warp) warps_per_cta = stringify(self.warps_per_cta) order = stringify(self.order) - ctas_per_cga = stringify(self.ctas_per_cga) - cta_split_num = stringify(self.cta_split_num) - cta_order = stringify(self.cta_order) - return f"B{size_per_thread}B{threads_per_warp}B{warps_per_cta}B{order}B{ctas_per_cga}B{cta_split_num}B{cta_order}B" + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"B{size_per_thread}_{threads_per_warp}_{warps_per_cta}_{order}_{cga_layout}B" def __hash__(self): - return hash(( - tuple(self.size_per_thread), - tuple(self.threads_per_warp), - tuple(self.warps_per_cta), - tuple(self.order), - tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None, - )) + return 
hash((tuple(self.size_per_thread), tuple(self.threads_per_warp), tuple(self.warps_per_cta), + tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout))) + + @property + def rank(self): + return len(self.order) @dataclass(frozen=True) @@ -170,6 +142,16 @@ def __hash__(self): def rank(self): return self.parent.rank - 1 + @property + def cga_layout(self): + parent_cga_layout = self.parent.cga_layout + if not parent_cga_layout: + return [] + + rank = self.parent.rank + assert 0 <= self.dim < rank + return [basis[:self.dim] + basis[self.dim + 1:] for basis in parent_cga_layout] + @dataclass(frozen=True) class DistributedLinearLayout(DistributedLayout): @@ -261,6 +243,25 @@ def __hash__(self): def rank(self): return self.parent.rank + @property + def cga_layout(self): + parent_cga_layout = _unwrap_if_constexpr(getattr(self.parent, "cga_layout", [])) or [] + if not parent_cga_layout: + return [] + + rank = self.parent.rank + assert all(len(basis) == rank for basis in parent_cga_layout) + + k_dim = rank - 1 if self.operand_index == 0 else rank - 2 + assert 0 <= k_dim < rank + + derived = [] + for basis in parent_cga_layout: + new_basis = list(basis) + new_basis[k_dim] = 0 + derived.append(new_basis) + return derived + @dataclass(frozen=True, eq=True) class NVMMADistributedLayout(DistributedLayout): @@ -271,43 +272,39 @@ class NVMMADistributedLayout(DistributedLayout): version (List[int]): Version identifier for the MMA instruction. warps_per_cta (List[int]): Number of warps per CTA. instr_shape (List[int]): Instruction shape for MMA. - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): CTA ordering. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. 
""" version: List[int] warps_per_cta: List[int] instr_shape: List[int] - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("version", _unwrap_if_constexpr(self.version)) super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) - rank = len(self.warps_per_cta) - _realize_cta_layout(self, rank) - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank + object.__setattr__(self, "cga_layout", self.cga_layout) def _to_ir(self, builder): - return builder.get_mma_layout(self.version, self.warps_per_cta, self.ctas_per_cga, self.cta_split_num, - self.cta_order, self.instr_shape) + return builder.get_mma_layout( + self.version, + self.warps_per_cta, + self.cga_layout, + self.instr_shape, + ) def mangle(self) -> str: - return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{self.ctas_per_cga}_{self.cta_split_num}_{self.cta_order}_MMA" + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{cga_layout}_MMA" def __hash__(self): - return hash((tuple(self.version), tuple(self.warps_per_cta), - tuple(self.instr_shape), tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None)) + return hash((tuple(self.version), tuple(self.warps_per_cta), tuple(self.instr_shape), + tuple(tuple(vec) for vec in self.cga_layout))) + + @property + def rank(self): + return len(self.warps_per_cta) class SharedLayout: @@ -321,11 +318,22 @@ def type(self): @constexpr_function -def _get_shape_per_cta(shape, cta_split_num): - shape_per_cta = shape - if cta_split_num is not None: - assert len(cta_split_num) == len(shape) - shape_per_cta = [shape_per_cta[dim] // cta_split_num[dim] for dim in range(len(shape_per_cta))] +def _get_shape_per_cta(shape, cga_layout): + if not cga_layout: + return shape + shape_per_cta = list(shape) + rank = len(cga_layout[0]) + cga_shape = [1] * rank + for basis in cga_layout: + assert len(basis) == rank + for i in range(rank): + cga_shape[i] = max(cga_shape[i], basis[i]) + # The shape is the largest stride * 2 + for i in range(rank): + cga_shape[i] *= 2 + for dim in range(rank): + assert shape_per_cta[dim] % cga_shape[dim] == 0, f"Shape {shape} is not divisible by CGA layout {cga_layout}" + shape_per_cta[dim] //= cga_shape[dim] return shape_per_cta @@ -340,36 +348,31 @@ class NVMMASharedLayout(SharedLayout): rank (int): Rank of the tensor. transposed (bool): Whether the layout is transposed. fp4_padded (bool): Whether FP4 padding is used. - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): CTA ordering. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. 
""" swizzle_byte_width: int element_bitwidth: int - rank: int + rank: int = 2 transposed: bool = False fp4_padded: bool = False - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("swizzle_byte_width", _unwrap_if_constexpr(self.swizzle_byte_width)) super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth)) - super().__setattr__("rank", _unwrap_if_constexpr(self.rank)) super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) super().__setattr__("fp4_padded", _unwrap_if_constexpr(self.fp4_padded)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) + + # TODO: Make rank optional and check that (rank or cga_layout) + cga_layout = self.cga_layout or [] + if cga_layout: + assert len(cga_layout[0]) == self.rank + + super().__setattr__("rank", _unwrap_if_constexpr(self.rank)) + super().__setattr__("cga_layout", _unwrap_if_constexpr(cga_layout)) assert self.element_bitwidth in [8, 16, 32, 64] assert self.swizzle_byte_width in [0, 32, 64, 128] - rank = self.rank - _realize_cta_layout(self, rank) - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank def _to_ir(self, builder): return builder.get_nvmma_shared_layout( @@ -377,22 +380,20 @@ def _to_ir(self, builder): self.element_bitwidth, self.transposed, self.fp4_padded, - self.ctas_per_cga, - self.cta_split_num, - self.cta_order, + self.cga_layout, + self.rank, ) @staticmethod @constexpr_function - def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas_per_cga=None, cta_split_num=None, - cta_order=None): + def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, cga_layout=None): """Returns an NVMMASharedLayout with default swizzling for a given shape. This picks the largest swizzle pattern compatible with the shape, which allows emitting the fewest TMA or MMA messages. 
""" packing_factor = 2 if fp4_padded else 1 - shape_per_cta = _get_shape_per_cta(block_shape, cta_split_num) + shape_per_cta = block_shape if cga_layout is None else _get_shape_per_cta(block_shape, cga_layout) rank = len(block_shape) if transposed: shape_per_cta = shape_per_cta[1:] + shape_per_cta[:1] @@ -419,19 +420,16 @@ def get_default_for(block_shape, dtype, transposed=False, fp4_padded=False, ctas rank=rank, transposed=transposed, fp4_padded=fp4_padded, - ctas_per_cga=ctas_per_cga, - cta_split_num=cta_split_num, - cta_order=cta_order, + cga_layout=cga_layout, ) def mangle(self) -> str: - return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_NVMMA" + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"NVMMA_{self.swizzle_byte_width}_{self.element_bitwidth}_{self.transposed}_{self.fp4_padded}_{cga_layout}_NVMMA" def __hash__(self): return hash((self.swizzle_byte_width, self.element_bitwidth, self.rank, self.transposed, self.fp4_padded, - tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None)) + tuple(tuple(vec) for vec in self.cga_layout) if self.cga_layout else None)) @dataclass(frozen=True, eq=True) @@ -444,32 +442,21 @@ class SwizzledSharedLayout(SharedLayout): per_phase (int): Elements per swizzle phase. max_phase (int): Maximum number of swizzle phases. order (List[int]): Dimension ordering for swizzling. - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): CTA ordering. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. 
""" vec: int per_phase: int max_phase: int order: List[int] - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("vec", _unwrap_if_constexpr(self.vec)) super().__setattr__("per_phase", _unwrap_if_constexpr(self.per_phase)) super().__setattr__("max_phase", _unwrap_if_constexpr(self.max_phase)) super().__setattr__("order", _unwrap_if_constexpr(self.order)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) - rank = len(self.order) - _realize_cta_layout(self, rank) - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank + object.__setattr__(self, "cga_layout", self.cga_layout) def _to_ir(self, builder): return builder.get_swizzled_shared_layout( @@ -477,9 +464,7 @@ def _to_ir(self, builder): self.per_phase, self.max_phase, self.order, - self.ctas_per_cga, - self.cta_split_num, - self.cta_order, + self.cga_layout, ) def mangle(self) -> str: @@ -489,13 +474,12 @@ def stringify(x): return "" return "_".join(map(str, x)) - return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_SSS" + cga_layout = "_".join("~".join(map(str, vec)) for vec in self.cga_layout) if self.cga_layout else "" + return f"SSS_{self.vec}_{self.per_phase}_{self.max_phase}_{stringify(self.order)}_{cga_layout}_SSS" def __hash__(self): - return hash((self.vec, self.per_phase, self.max_phase, - tuple(self.order), tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None)) + return hash( + (self.vec, self.per_phase, self.max_phase, tuple(self.order), tuple(tuple(vec) for vec in self.cga_layout))) @dataclass(frozen=True, eq=True) diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index cfbed04814..ec019cbe4a 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -18,8 +18,7 @@ def _is_int_list(value): return isinstance(value, Sequence) and all(isinstance(i, int) for i in value) -def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant, ctas_per_cga, cta_split_num, - cta_order): +def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant, cga_layout=None): _check(isinstance(instr_variant, str), lambda: "instr_variant must be a string") _check(instr_variant in ("32x32b", "16x64b", "16x128b", "16x256b", "16x32bx2", "32x32b_splitn"), lambda: f"unknown instr_variant: {instr_variant}") @@ -31,15 +30,14 @@ def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant rank = len(shape) _check(rank == 2, lambda: "expected a 2D tensor") - ctas_per_cga = list(ctas_per_cga) - cta_split_num = list(cta_split_num) - cta_order = list(cta_order) + if cga_layout is None: + cga_layout = [] splitn = instr_variant == "32x32b_splitn" atom_variant = "32x32b" if splitn else instr_variant - _check(len(ctas_per_cga) == rank, lambda: "ctas_per_cga rank mismatch") - _check(len(cta_split_num) == rank, lambda: 
"cta_split_num rank mismatch") - _check(len(cta_order) == rank, lambda: "cta_order rank mismatch") + if cga_layout: + for basis in cga_layout: + _check(len(basis) == rank, lambda: "cga_layout basis rank mismatch") layout_obj = compute_tmem_reg_layout( element_ty, @@ -47,9 +45,7 @@ def _compute_tmem_reg_layout(element_ty, shape, layout, num_warps, instr_variant layout, num_warps, atom_variant, - ctas_per_cga, - cta_split_num, - cta_order, + cga_layout, ) _check(layout_obj is not None, lambda: f"TMEM layout '{atom_variant}' unsupported for shape {shape} and num_warps {num_warps}") diff --git a/python/triton/experimental/gluon/language/amd/_layouts.py b/python/triton/experimental/gluon/language/amd/_layouts.py index 5ce3934f25..a3d616fea9 100644 --- a/python/triton/experimental/gluon/language/amd/_layouts.py +++ b/python/triton/experimental/gluon/language/amd/_layouts.py @@ -1,10 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional from triton.language.core import _unwrap_if_constexpr -from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout +from triton.experimental.gluon.language._layouts import DistributedLayout __all__ = [ "AMDMFMALayout", @@ -24,9 +24,7 @@ class AMDMFMALayout(DistributedLayout): warps_per_cta (List[int]): The warp layout in the block. element_bitwidth Optional(int): Bit width of the output element type. Supported values are 32 and 64. Defaults to 32. tiles_per_warp Optional(List[int]): The tile layout within a warp. Defaults to unit tile layout, i.e., single tile on all dimensions. - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): CTA ordering. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. 
Current supported versions: @@ -41,9 +39,7 @@ class AMDMFMALayout(DistributedLayout): warps_per_cta: List[int] element_bitwidth: Optional[int] = None tiles_per_warp: Optional[List[int]] = None - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("version", _unwrap_if_constexpr(self.version)) @@ -52,21 +48,25 @@ def __post_init__(self): super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) super().__setattr__("element_bitwidth", _unwrap_if_constexpr(self.element_bitwidth)) super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) if self.element_bitwidth is None: object.__setattr__(self, "element_bitwidth", 32) if self.tiles_per_warp is None: object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta)) + object.__setattr__(self, "cga_layout", self.cga_layout) self.verify() def _to_ir(self, builder): - return builder.get_amd_mfma_layout(self.version, self.warps_per_cta, self.instr_shape, self.transposed, - self.ctas_per_cga, self.cta_split_num, self.cta_order, self.tiles_per_warp, - self.element_bitwidth) + return builder.get_amd_mfma_layout( + self.version, + self.warps_per_cta, + self.instr_shape, + self.transposed, + self.cga_layout, + self.tiles_per_warp, + self.element_bitwidth, + ) def mangle(self) -> str: @@ -75,7 +75,8 @@ def stringify(x): return "" return "_".join(map(str, x)) - return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{self.element_bitwidth}_{stringify(self.tiles_per_warp)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA" + cga_layout = stringify(["~".join(map(str, vec)) for vec in self.cga_layout] if self.cga_layout else None) + return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{self.element_bitwidth}_{stringify(self.tiles_per_warp)}_{cga_layout}_MFMA" def verify(self): assert self.version >= 1 and self.version <= 4, "version must be in the [1, 4] range" @@ -85,10 +86,7 @@ def verify(self): assert self.element_bitwidth in [32, 64], "element bitwidth must be 32 or 64" rank = len(self.warps_per_cta) - _realize_cta_layout(self, rank) - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank + assert all(len(vec) == rank for vec in self.cga_layout), "cga_layout basis rank mismatch" def __hash__(self): return hash(( @@ -98,11 +96,13 @@ def __hash__(self): tuple(self.warps_per_cta), self.element_bitwidth if self.element_bitwidth else None, tuple(self.tiles_per_warp) if self.tiles_per_warp else None, - tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None, + tuple(tuple(vec) for vec in self.cga_layout), )) + @property + def rank(self): + return len(self.warps_per_cta) + @dataclass(frozen=True) class AMDWMMALayout(DistributedLayout): @@ -114,9 +114,7 @@ class AMDWMMALayout(DistributedLayout): transposed (bool): Indicates the result tensor is transposed. 
warps_per_cta (List[int]): Number of warps per CTA. instr_shape (Optional[List[int]]): Instruction shape (M, N, K). Defaults to (16, 16, 16). - ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping. - cta_split_num (Optional[List[int]]): Split factors for CTAs. - cta_order (Optional[List[int]]): CTA ordering. + cga_layout (Optional[List[List[int]]]): Bases describing CTA tiling. Current supported versions: @@ -129,17 +127,12 @@ class AMDWMMALayout(DistributedLayout): warps_per_cta: List[int] instr_shape: Optional[List[int]] = None tiles_per_warp: Optional[List[int]] = None - ctas_per_cga: Optional[List[int]] = None - cta_split_num: Optional[List[int]] = None - cta_order: Optional[List[int]] = None + cga_layout: List[List[int]] = field(default_factory=list) def __post_init__(self): super().__setattr__("version", _unwrap_if_constexpr(self.version)) super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed)) super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta)) - super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga)) - super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num)) - super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order)) if self.tiles_per_warp is None: tiles_per_warp = [1] * len(self.warps_per_cta) @@ -150,11 +143,18 @@ def __post_init__(self): instr_shape = _unwrap_if_constexpr(self.instr_shape) if self.instr_shape is not None else [16, 16, 16] super().__setattr__("instr_shape", _unwrap_if_constexpr(instr_shape)) + object.__setattr__(self, "cga_layout", self.cga_layout) self.verify() def _to_ir(self, builder): - return builder.get_amd_wmma_layout(self.version, self.transposed, self.warps_per_cta, self.tiles_per_warp, - self.ctas_per_cga, self.cta_split_num, self.cta_order, self.instr_shape) + return builder.get_amd_wmma_layout( + self.version, + self.transposed, + self.warps_per_cta, + self.tiles_per_warp, + self.cga_layout, + self.instr_shape, + ) def mangle(self) -> str: @@ -163,16 +163,14 @@ def stringify(x): return "" return "_".join(map(str, x)) - return f"WMMA_{self.version}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{stringify(self.instr_shape)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_WMMA" + cga_layout = stringify(["~".join(map(str, vec)) for vec in self.cga_layout] if self.cga_layout else None) + return f"WMMA_{self.version}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{stringify(self.instr_shape)}_{cga_layout}_WMMA" def verify(self): assert self.version >= 1 and self.version <= 3, "version must be in the [1, 3] range" rank = len(self.warps_per_cta) - _realize_cta_layout(self, rank) - assert len(self.ctas_per_cga) == rank - assert len(self.cta_split_num) == rank - assert len(self.cta_order) == rank + assert all(len(vec) == rank for vec in self.cga_layout), "cga_layout basis rank mismatch" def __hash__(self): return hash(( @@ -181,7 +179,9 @@ def __hash__(self): tuple(self.warps_per_cta), tuple(self.tiles_per_warp) if self.tiles_per_warp else None, tuple(self.instr_shape) if self.instr_shape else None, - tuple(self.ctas_per_cga) if self.ctas_per_cga else None, - tuple(self.cta_split_num) if self.cta_split_num else None, - tuple(self.cta_order) if self.cta_order else None, + tuple(tuple(vec) for vec in self.cga_layout), )) + + @property + def rank(self): + return len(self.warps_per_cta) diff --git 
a/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py b/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py index 44cdb2e1c7..f69d3005fb 100644 --- a/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py +++ b/python/triton/experimental/gluon/language/amd/gfx1250/mbarrier.py @@ -10,20 +10,11 @@ class MBarrierLayout(SwizzledSharedLayout): Layout for mbarrier synchronization. Args: - ctas_per_cga (int): CTAs per CGA grouping. Defaults to 1. - cta_split_num (int): CTA split factor. Defaults to 1. + cga_layout (List[List[int]]): CTA layout bases. Defaults to []. """ - def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1): - super().__init__( - vec=1, - per_phase=1, - max_phase=1, - order=[0], - ctas_per_cga=[ctas_per_cga], - cta_split_num=[cta_split_num], - cta_order=[0], - ) + def __init__(self, cga_layout=None): + super().__init__(vec=1, per_phase=1, max_phase=1, order=[0], cga_layout=cga_layout or []) @builtin diff --git a/python/triton/experimental/gluon/language/intel/_layouts.py b/python/triton/experimental/gluon/language/intel/_layouts.py index a07bb626e6..bf12fecb28 100644 --- a/python/triton/experimental/gluon/language/intel/_layouts.py +++ b/python/triton/experimental/gluon/language/intel/_layouts.py @@ -84,3 +84,7 @@ def __hash__(self): self.threads_per_warp, tuple(self.cta_order), )) + + @property + def rank(self): + return len(self.warps_per_cta) diff --git a/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py b/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py index d8c37c3af8..8f7ac34570 100644 --- a/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py +++ b/python/triton/experimental/gluon/language/nvidia/ampere/mbarrier.py @@ -9,20 +9,11 @@ class MBarrierLayout(SwizzledSharedLayout): Layout for mbarrier synchronization in Ampere and later architectures. Args: - ctas_per_cga (int): CTAs per CGA grouping. Defaults to 1. - cta_split_num (int): CTA split factor. Defaults to 1. + cga_layout (List[List[int]]): CTA layout bases. Defaults to []. 
""" - def __init__(self, ctas_per_cga: int = 1, cta_split_num: int = 1): - super().__init__( - vec=1, - per_phase=1, - max_phase=1, - order=[0], - ctas_per_cga=[ctas_per_cga], - cta_split_num=[cta_split_num], - cta_order=[0], - ) + def __init__(self, cga_layout=None): + super().__init__(vec=1, per_phase=1, max_phase=1, order=[0], cga_layout=cga_layout or []) @builtin diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index 1a9e239bd3..6d1b21c011 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -75,7 +75,7 @@ def mangle(self) -> str: return f"TL{block_str}{stride_str}{cta_split_str}{two_ctas_str}TL" def __hash__(self): - return hash((self.block, self.col_stride, self.cta_split_num)) + return hash((self.block, self.col_stride, self.cta_split_num, self.two_ctas)) @dataclass(frozen=True, eq=True) @@ -93,8 +93,8 @@ def __post_init__(self): assert self.cta_split_num is None or len(self.cta_split_num) == 2 def _to_ir(self, builder): - cta_split_num = self.cta_split_num or [1, 1] - return builder.get_tensor_memory_scales_layout(cta_split_num, ) + cta_split_num = list(self.cta_split_num) if self.cta_split_num else [1, 1] + return builder.get_tensor_memory_scales_layout(cta_split_num) def mangle(self) -> str: cta_split_str = f"CS{self.cta_split_num[0]}x{self.cta_split_num[1]}" if self.cta_split_num else "" @@ -111,9 +111,7 @@ def get_tmem_reg_layout( layout, num_warps, instr_variant="32x32b", - ctas_per_cga=(1, 1), - cta_split_num=(1, 1), - cta_order=(1, 0), + cga_layout=(), ): """ Returns a DistributedLinearLayout compatible with TMEM load/store instructions. @@ -124,9 +122,7 @@ def get_tmem_reg_layout( layout (TensorMemoryLayout): Tensor memory layout descriptor. num_warps (int): Number of warps participating in the operation. instr_variant (str): TMEM instruction variant (e.g. ``\"32x32b\"``). - ctas_per_cga (tuple[int, int]): CTA grouping along each dimension. - cta_split_num (tuple[int, int]): CTA split factors along each dimension. - cta_order (tuple[int, int]): CTA order. + cga_layout (Sequence[Sequence[int]]): CTA layout bases describing CTA distribution. 
""" def _unwrap(x): @@ -144,9 +140,7 @@ def _unwrap(x): _unwrap(layout), _unwrap(num_warps), _unwrap(instr_variant), - _unwrap(ctas_per_cga), - _unwrap(cta_split_num), - _unwrap(cta_order), + _unwrap(cga_layout), ) @@ -274,7 +268,7 @@ def slice(self, start, length, _semantic: GluonSemantic) -> None: (layout.block[0], min(layout.block[1], length)), layout.col_stride, layout.cta_split_num, - two_ctas=layout.two_ctas, + layout.two_ctas, ) ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) builder = _semantic.builder diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 15adbf6e44..52fce7b630 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -3,11 +3,12 @@ from dataclasses import dataclass import triton +from triton_kernels import target_info from triton_kernels.target_info import get_cdna_version from triton_kernels.tensor import FP4 import torch from .opt_flags_details import opt_flags_amd, opt_flags_nvidia, opt_flags_intel -from triton_kernels.tensor import bitwidth +from triton_kernels.tensor import bitwidth, get_layout @dataclass @@ -297,8 +298,12 @@ def make_default_opt_flags_nvidia( n_sms = torch.cuda.get_device_properties(0).multi_processor_count tiles_per_sm = grid_size_tma / n_sms supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) + requires_persistent = (get_layout(precision_config.act_scale) is not None or get_layout(precision_config.weight_scale) is not None) and target_info.has_native_mxfp() if constraints.get("is_persistent", None) is not None: is_persistent = constraints["is_persistent"] + elif requires_persistent: + assert supports_persistent, "persistent kernel required but not supported" + is_persistent = True else: has_simple_epilogue = precision_config.max_num_imprecise_acc is None is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4 diff --git a/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir index 4087fcf85c..690ff63cbb 100644 --- a/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir +++ b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir @@ -56,8 +56,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // Broadcast to all CTAs so we should just see 15 (0b1111) as the broadcast mask since we have 4 CTAs per CGA -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> -#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [2, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[0, 0], [0, 0]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CGALayout = [[0, 0], [0, 0]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: async_load_multicast_to_all_ctas @@ -74,14 +74,13 @@ module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // 8 CTAs, 2 multicast groups of 4 CTAs 
each. Each group is strided by 1 so the base mask should be 0b1010101 (85) and the non free mask is -7 (~0b110) -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: async_load_multicast_to_half_ctas tt.func public @async_load_multicast_to_half_ctas(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-7 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] @@ -96,15 +95,14 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000) -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}> -#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 4], [0, 0]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 4], [0, 0]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: async_load_multicast_group_of_2_strided_by_8 tt.func public @async_load_multicast_group_of_2_strided_by_8(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { // Skip the first cluster id because it's emitted for address calculation - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] @@ -119,8 +117,8 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha // ----- // 16 CTAs split into 16 multicast groups so we should not emit cluster load since we do not share any data -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}> 
-#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 16], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 4], [0, 8]]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 4], [0, 8]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: async_load_multi_cta_but_not_data_sharing @@ -139,14 +137,13 @@ module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.sha // Test with linear layout as src layout // 16 CTAs, 8 multicast groups of 2 CTAs each, each group is strided by 8 so the base mask should be 0b100000001 (257) and the non free mask is -9 (~0b1000) #linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[0, 0], [0, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = [[0, 4], [0, 8], [0, 16], [0, 0]], order = [1, 0]}> -#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 16], CTASplitNum = [1, 8], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0], CGALayout = [[0, 1], [0, 2], [0, 4], [0, 0]]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 16 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: async_load_multi_cta_linear_layout tt.func public @async_load_multi_cta_linear_layout(%arg0: tensor<32x32x!tt.ptr, #linear> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { // Skip the first cluster id because it's emitted for address calculation - // CHECK: llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[CTA_ID:.*]] = {{.*}}llvm.amdgcn.cluster.workgroup.id.x // CHECK: %[[NON_FREE_BITS:.*]] = llvm.mlir.constant(-9 : i32) : i32 // CHECK: %[[SHIFT_AMOUNT:.*]] = llvm.and %[[CTA_ID]], %[[NON_FREE_BITS]] diff --git a/test/Conversion/amd/cluster_load.mlir b/test/Conversion/amd/cluster_load.mlir index 5ca68513c6..1edfa58c68 100644 --- a/test/Conversion/amd/cluster_load.mlir +++ b/test/Conversion/amd/cluster_load.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 | FileCheck %s -// CTASplitNum == CTAsPerCGA so we should not emit cluster loads since there is no cross-CTA data sharing -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [8, 1], CTAOrder = [1, 0]}> +// CGA layout has no broadcasting so we should not emit cluster loads +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [2, 0], [4, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: load_multi_cta_but_no_broadcast tt.func public @load_multi_cta_but_no_broadcast(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { @@ -14,7 +14,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // 
----- // 8 CTAs, 2 multicast groups of 4 CTAs each. Each group is strided by 1 so the base mask should be 0b1010101 (85) and the non free mask is -7 (~0b110) -#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: cluster_load_b128 tt.func public @cluster_load_b128(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { @@ -33,7 +33,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // Note that we already check the correct multicast mask in previous tests, so we only check the cluster load instruction here -#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: cluster_load_b64 tt.func public @cluster_load_b64(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { @@ -47,7 +47,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // Note that we already check the correct multicast mask in previous tests, so we only check the cluster load instruction here -#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: cluster_load_b32 tt.func public @cluster_load_b32(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { @@ -61,7 +61,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- // Smaller vector size than 2 (32bit) should not produce cluster loads -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: not_cluster_load_for_b16 tt.func public @not_cluster_load_for_b16(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { 
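The attribute rewrites in these tests replace the legacy CTAsPerCGA/CTASplitNum/CTAOrder triple with explicit CGALayout bases: each basis corresponds to one bit of the block (CTA) id, an all-zero basis means that bit broadcasts the same data to several CTAs, and a non-zero basis places the CTA at that offset along the tensor dimensions. The sketch below is a hypothetical helper, not code from this change, that reproduces the conversions visible in the updated .mlir tests, assuming CTASplitNum divides CTAsPerCGA, all values are powers of two, and CTAOrder lists dimensions from fastest- to slowest-varying:

    import math

    def split_params_to_cga_layout(ctas_per_cga, cta_split_num, cta_order):
        # Hypothetical illustration of the old-to-new attribute mapping seen in
        # the updated tests; dimensions are visited in CTAOrder, fastest first.
        rank = len(ctas_per_cga)
        bases = []
        for d in cta_order:
            # Low block-id bits along dim d address distinct tiles ...
            for i in range(int(math.log2(cta_split_num[d]))):
                basis = [0] * rank
                basis[d] = 1 << i
                bases.append(basis)
            # ... and the remaining bits along dim d broadcast (zero basis).
            for _ in range(int(math.log2(ctas_per_cga[d] // cta_split_num[d]))):
                bases.append([0] * rank)
        return bases

    # Mappings matching the attribute rewrites in the surrounding tests:
    assert split_params_to_cga_layout([8, 1], [2, 1], [1, 0]) == [[1, 0], [0, 0], [0, 0]]
    assert split_params_to_cga_layout([1, 16], [1, 8], [1, 0]) == [[0, 1], [0, 2], [0, 4], [0, 0]]
    assert split_params_to_cga_layout([2, 2], [1, 1], [1, 0]) == [[0, 0], [0, 0]]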
@@ -75,7 +75,7 @@ module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // Check that we break sizePerThread > 4 (>128bit) into multiple cluster loads b128 // Note that we already check the correct multicast mask in previous tests, so we only check the cluster load instruction here -#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [8, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CGALayout = [[1, 0], [0, 0], [0, 0]]}> module attributes {"ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { // CHECK-LABEL: cluster_load_2_b128 tt.func public @cluster_load_2_b128(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}) { diff --git a/test/Conversion/amd/math-denorm-handling.mlir b/test/Conversion/amd/math-denorm-handling.mlir index c3ab9df370..a09cf197cb 100644 --- a/test/Conversion/amd/math-denorm-handling.mlir +++ b/test/Conversion/amd/math-denorm-handling.mlir @@ -64,22 +64,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) { - // LLVM_FTZ-LABEL: test_sqrt_rn_f32 - // LLVM_FTZ: llvm.amdgcn.rsq.f32 - // LLVM_FTZ: llvm.fmul - // LLVM_FTZ: llvm.fmul - // LLVM_FTZ: llvm.fneg - // LLVM_FTZ: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.fneg - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.fma - // LLVM_FTZ-NEXT: llvm.intr.is.fpclass - // LLVM_FTZ-NEXT: llvm.select - // - // LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32 - // LLVM_NO_FTZ: llvm.intr.sqrt + // COMMON-LABEL: test_sqrt_rn_f32 + // COMMON: llvm.intr.sqrt %0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked> tt.return } @@ -96,3 +82,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @test_divf_rn_f32(%arg0: tensor<64xf32, #blocked>, %arg1: tensor<64xf32, #blocked>) { + // COMMON-LABEL: test_divf_rn_f32 + // COMMON: llvm.fdiv + %0 = tt.precise_divf %arg0, %arg1 : tensor<64xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir index 6337fec57e..9f67f5cb66 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir @@ -200,3 +200,29 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [32, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [16, 0], [32, 0], [64, 0]], 
lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [4, 1], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_chained + tt.func @wmma_scaled_dot_fp8_chained(%arg0: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg2: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %scale0 = arith.constant dense<127> : tensor<128x4xi8, #linear> + %scale1 = arith.constant dense<127> : tensor<128x4xi8, #linear1> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %mm0 = tt.dot_scaled %arg0 scale %scale0, %arg2 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma> + // CHECK-NOT: rocdl.ds_swizzle + // CHECK-NOT: llvm.call_intrinsic "llvm.amdgcn.permlane16.swap" + %op0 = ttg.convert_layout %mm0 : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + %op1 = tt.fp_to_fp %op0, rounding = rtne : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> -> tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + // CHECK-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %mm1 = tt.dot_scaled %op1 scale %scale0, %arg3 scale %scale1, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear> * tensor<128x128xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>, tensor<128x4xi8, #linear1> -> tensor<128x128xf32, #mma> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<128x128x!tt.ptr, #mma> + tt.store %ptr0, %mm1 : tensor<128x128x!tt.ptr, #mma> + tt.return + } +} diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir index 5c9ca30a21..79b6506e36 100644 --- a/test/Conversion/intel/tritongpu_to_gen.mlir +++ b/test/Conversion/intel/tritongpu_to_gen.mlir @@ -1389,7 +1389,7 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}> module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index ed2cf3d44e..9b9bfc6580 100644 
--- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1624,7 +1624,7 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { // ----- -#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}> module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { // CHECK-LABEL: test_get_program_id tt.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) { @@ -1667,7 +1667,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- -#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}> +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}> module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} { tt.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) { %blockdimx = tt.get_num_programs x : i32 diff --git a/test/Conversion/tritongpu_to_llvm_blackwell.mlir b/test/Conversion/tritongpu_to_llvm_blackwell.mlir index 212e6b7606..48eceb7c38 100644 --- a/test/Conversion/tritongpu_to_llvm_blackwell.mlir +++ b/test/Conversion/tritongpu_to_llvm_blackwell.mlir @@ -67,10 +67,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} { // ----- -#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> -#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}> +#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CGALayout = [[0, 0]], instrShape = [16, 256, 32]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16, CGALayout = [[0, 0]]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 8 : i32} { // CHECK-LABEL: @tc_gen5_mma_multi_ctas @@ -266,9 +266,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // ----- -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> -#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [2], CTASplitNum = [1], CTAOrder = [0]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#shared1 = 
#ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16, CGALayout = [[0, 1]]}> +#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CGALayout = [[0]]}> #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttng.two-ctas" = true} { // CHECK-LABEL: @tc_gen5_mma_2ctas diff --git a/test/Triton/invalid.mlir b/test/Triton/invalid.mlir index 1aeb76eceb..3761245552 100644 --- a/test/Triton/invalid.mlir +++ b/test/Triton/invalid.mlir @@ -269,8 +269,8 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32>) { // ----- // Valid op with blocked encoding. -#blocked2 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> -#blocked3 = #ttg.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CGALayout = [[0, 1, 0], [0, 0, 1], [0, 0, 2]]}> +#blocked3 = #ttg.blocked<{sizePerThread = [2,1,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CGALayout = [[1, 0, 0], [0, 0, 1], [0, 0, 2]]}> module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked2>) { %b = tt.trans %arg0 {order = array} : tensor<16x32x64xf32, #blocked2> -> tensor<32x16x64xf32, #blocked3> @@ -283,8 +283,8 @@ tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked2>) { // Valid op with shared encoding. #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [3, 2, 1, 0]}> #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 2, 0, 3]}> -#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CTAsPerCGA = [1, 2], CTASplitNum = [2, 4], CTAOrder = [0, 1]}> -#shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32, CTAsPerCGA = [2, 1], CTASplitNum = [4, 2], CTAOrder = [1, 0]}> +#shared2 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32, CGALayout = [[1, 0], [0, 1], [0, 2]]}> +#shared3 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 32, CGALayout = [[0, 1], [1, 0], [2, 0]]}> #smem = #ttg.shared_memory module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared, #smem>, %arg1: !ttg.memdesc<16x32xf32, #shared2, #smem>) { @@ -297,8 +297,8 @@ tt.func public @fn(%arg0: !ttg.memdesc<2x4x8x16xf32, #shared, #smem>, %arg1: !tt // ----- // Invalid blocked encoding. 
-#blocked = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CTAsPerCGA = [1,2,4], CTASplitNum = [1,2,4], CTAOrder = [0,1,2]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CTAsPerCGA = [2,1,4], CTASplitNum = [2,1,4], CTAOrder = [1,0,2]}> +#blocked = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [2,4,4], warpsPerCTA = [2,4,8], order = [0,1,2], CGALayout = [[0, 1, 0], [0, 0, 1], [0, 0, 2]]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1,2,4], threadsPerWarp = [4,2,4], warpsPerCTA = [4,2,8], order = [1,0,2], CGALayout = [[1, 0, 0], [0, 0, 1], [0, 0, 2]]}> module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 8 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { tt.func public @fn(%arg0: tensor<16x32x64xf32, #blocked>) { // expected-error @+1 {{type}} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 901b087b9f..932c6cf8db 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -285,15 +285,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}> module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // LAYOUT_16x256{LITERAL}: #ttg.linear<{register = [[0, 1], [8, 0], [0, 8], [0, 16], [0, 32], [0, 64], [0, 128]], lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], warp = [[16, 0], [32, 0]], block = [[64, 0]]}> // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding - // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}> // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[0, 128\]\]}}, warp = {{\[\[16, 0\], \[32, 0\]\]}}, block = {{\[\[64, 0\]\]}}}> - // CHECK-LABEL: mmav5 + // CHECK-LABEL: mmav5_multi_ctas // CHECK-DAG: %[[TRUE:.+]] = arith.constant true // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem @@ -302,8 +301,8 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK: 
%[[R:.+]], %{{.*}} = ttng.tmem_load %[[ACC]][%[[MMA_TOK]]] : !ttg.memdesc<128x256xf32, #{{.*}}, #ttng.tensor_memory, mutable> -> tensor<128x256xf32 // CHECK: %[[CVT:.+]] = ttg.convert_layout %[[R]] : tensor<128x256xf32, #[[$L]]> -> tensor<128x256xf32, #[[$B]]> // CHECK: tt.return %[[CVT]] : tensor<128x256xf32 - tt.func public @mmav5_multi_ctas(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { - %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + tt.func public @mmav5_multi_ctas(%a: tensor<128x64xf16, #blocked>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> { + %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked> tt.return %d : tensor<128x256xf32, #blocked> @@ -313,16 +312,16 @@ module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = [[1, 0]]}> module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} { // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding - // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> + // CHECK-DAG: #[[$B:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}> // CHECK-DAG: #[[$L:.+]] = #ttg.linear<{register = {{\[\[0, 1\], \[0, 2\], \[0, 4\], \[0, 8\], \[0, 16\], \[0, 32\], \[0, 64\]\]}}, lane = {{\[\[1, 0\], \[2, 0\], \[4, 0\], \[8, 0\], \[16, 0\]\]}}, warp = {{\[\[32, 0\], \[0, 128\]\]}}, block = {{\[\[64, 0\]\]}}}> - // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> - // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> - // CHECK-LABEL: mmav5 + // CHECK-DAG: 
#ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[1, 0\]\]}}}> + // CHECK-DAG: #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[0, 1\]\]}}}> + // CHECK-LABEL: mmav5_2ctas + // CHECK-DAG: %[[TRUE:.+]] = arith.constant true + // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xf16, #{{.*}}>) -> !ttg.memdesc<128x64xf16, #{{.*}}, #smem + // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x256xf16, #{{.*}}>) -> !ttg.memdesc<64x256xf16, #{{.*}}, #smem diff --git a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir index 9a967924e4..01e92768d9 100644 --- a/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-pipeline-chained-dots.mlir @@ -160,3 +160,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return %6 : tensor<128x16xf32, #mma> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> +#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [8, 1], instrShape = [16, 16, 16], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: tt.func @chained_dots_with_load_bias_in_between + + // Similar to the previous test, but loads a bias tensor between the 2 dots + // We expect the unstreamable load to be kept after pipelining + + // CHECK: scf.for + // CHECK: tt.dot + // CHECK: ttg.async_copy_global_to_local + // CHECK: tt.dot + // CHECK: ttg.async_wait + // CHECK: ttg.local_load + // CHECK: tt.load + // CHECK: scf.yield + + tt.func @chained_dots_with_load_bias_in_between(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg2: i64 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32) -> tensor<256x64xf32, #mma> { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma> + %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %3 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %4 = tt.addptr %2, %3 : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %6 = tt.splat %arg3 : !tt.ptr -> tensor<256x64x!tt.ptr, #blocked> + %7 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<256x64xf32, #mma>) : i32 { + %8 = tt.load %4 : tensor<64x64x!tt.ptr, #blocked> + %9 = ttg.convert_layout %8 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %10 = tt.dot %arg1, %9, %cst : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma> + %11 = arith.muli %arg5, %c64_i32 : i32 + %12 = tt.splat %11 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %13 = 
arith.addi %12, %5 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> + %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %15 = tt.broadcast %14 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked> + %bias_ptr = tt.addptr %6, %15 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %bias = tt.load %bias_ptr : tensor<256x64x!tt.ptr, #blocked> + %bias_mma = ttg.convert_layout %bias : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #mma> + %bias_f32 = arith.extf %bias_mma : tensor<256x64xf16, #mma> to tensor<256x64xf32, #mma> + %dot_bias = arith.addf %10, %bias_f32 : tensor<256x64xf32, #mma> + %21 = arith.truncf %dot_bias : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma> + %22 = ttg.convert_layout %21 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %23 = tt.dot %22, %9, %arg6 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma> + scf.yield %23 : tensor<256x64xf32, #mma> + } + tt.return %7 : tensor<256x64xf32, #mma> + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 997354685f..5421fa8d19 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2500,11 +2500,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) - // FIXME: The optimal number of conversions should be 4. 
- // CHECK-COUNT-5: convert_layout + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-COUNT-4: convert_layout // CHECK-NOT: convert_layout - // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> // CHECK: } // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 1c126282f5..c5e6f61d6c 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -173,16 +173,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // ----- -#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [0, 1]}> -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> -#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0], CGALayout = [[1, 0]]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 2], order = [0, 1], CGALayout = [[0, 1]]}> +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[1, 0]]}> +#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = [[0, 1]]}> #smem = #ttg.shared_memory #tmem = #ttng.tensor_memory_encoding module attributes {"ttg.target" = "cuda:100", "ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - // CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> - // CHECK-DAG: #[[SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> - // CHECK-DAG: #[[SHARED_TRANS:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16, CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [0, 1]}> + // CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], 
warpsPerCTA = [2, 2], order = [1, 0], CGALayout = {{\[\[1, 0\]\]}}}> + // CHECK-DAG: #[[SHARED:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16, CGALayout = {{\[\[0, 1\]\]}}}> + // CHECK-DAG: #[[SHARED_TRANS:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16, CGALayout = {{\[\[1, 0\]\]}}}> // CHECK: %[[ALLOC:.*]] = ttg.local_alloc %arg0 : (tensor<128x64xf8E4M3FN, #[[BLOCKED]]>) -> !ttg.memdesc<128x64xf8E4M3FN, #[[SHARED_TRANS]], #smem> // CHECK: %[[TRANS:.*]] = ttg.memdesc_trans %[[ALLOC]] {order = array} : !ttg.memdesc<128x64xf8E4M3FN, #[[SHARED_TRANS]], #smem> -> !ttg.memdesc<64x128xf8E4M3FN, #[[SHARED]], #smem> // CHECK: ttng.tc_gen5_mma %arg1, %[[TRANS]] diff --git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 14151b6b0a..b7c28233ab 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt --split-input-file %s --verify-diagnostics -#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CGALayout = [[0, 0]]}> #smem = #ttg.shared_memory tt.func public @non_trivial_block(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) { %zero = arith.constant 0 : i32 diff --git a/test/TritonGPU/verify-blocked-layout.mlir b/test/TritonGPU/verify-blocked-layout.mlir index 0f523ab291..c6e0ad7d56 100644 --- a/test/TritonGPU/verify-blocked-layout.mlir +++ b/test/TritonGPU/verify-blocked-layout.mlir @@ -4,10 +4,7 @@ sizePerThread=[1, 1], threadsPerWarp=[16, 1], warpsPerCTA=[4, 1], - order=[0, 1], - CTAsPerCGA=[2, 1], - CTASplitNum=[1, 1], - CTAOrder=[0, 1] + order=[0, 1], CGALayout = [[0, 0]] }> module attributes { "ttg.num-warps" = 4 : i32, @@ -27,10 +24,7 @@ module attributes { sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 2], - order=[0, 1], - CTAsPerCGA=[2, 1], - CTASplitNum=[1, 1], - CTAOrder=[0, 1] + order=[0, 1], CGALayout = [[0, 0]] }> module attributes { "ttg.num-warps" = 4 : i32, @@ -70,10 +64,7 @@ module attributes { sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], - order=[0, 1], - CTAsPerCGA=[1, 2], - CTASplitNum=[1, 1], - CTAOrder=[0, 1] + order=[0, 1], CGALayout = [[0, 0]] }> module attributes { "ttg.num-warps" = 4 : i32, @@ -94,10 +85,7 @@ module attributes { sizePerThread=[1, 1], threadsPerWarp=[32, 1], warpsPerCTA=[4, 1], - order=[0, 1], - CTAsPerCGA=[1, 2], - CTASplitNum=[1, 1], - CTAOrder=[0, 1] + order=[0, 1], CGALayout = [[0, 0]] }> module attributes { "ttg.num-warps" = 4 : i32, diff --git a/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir b/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir index 417e06aaee..9c4b24e464 100644 --- a/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir +++ b/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir @@ -2,7 +2,7 @@ // COM: Tests reduction when threads_per_warp < num_warps. 
-#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0]}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: reduce_problem_size_64_threads_per_warp_32 tt.func @reduce_problem_size_64_threads_per_warp_32(%f : tensor<2048xi32, #blocked>) { diff --git a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir index a020c41d6f..3aecd10ee6 100644 --- a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir +++ b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir @@ -162,8 +162,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- -#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 2], CTASplitNum = [1, 2], CTAOrder = [1, 0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [2, 1], CTASplitNum = [2, 1], CTAOrder = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[0, 1]]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 64], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CGALayout = [[1, 0]]}> #tmem = #ttng.tensor_memory_encoding #tmem1 = #ttng.tensor_memory_encoding module attributes {"ttg.num-ctas" = 2 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32, ttg.shared = 65536 : i32} { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index 2dddeb898c..5d5165796c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -2264,73 +2264,6 @@ struct SqrtOpConversion } } -private: - bool ftz; -}; - -struct PreciseSqrtOpConversion - : ElementwiseOpConversionBase { - explicit PreciseSqrtOpConversion(LLVMTypeConverter &typeConverter, - ModuleAxisInfoAnalysis &axisInfoAnalysis, - bool ftz, PatternBenefit benefit) - : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit), - ftz(ftz) {} - - SmallVector createDestOps(triton::PreciseSqrtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - auto b = TritonLLVMOpBuilder(loc, rewriter); - // If the op is neither FP32 nor denorm flushing(ftz), it's directly lowered - // to LLVM::SqrtOp. - if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz) { - return {LLVM::SqrtOp::create(rewriter, loc, elemTy, operands[0], - adaptor.getAttributes().getValue())}; - } - - // On the AMDGPU backend, instructions legalized from LLVM::SqrtOp are - // designed to always preserve denorms, according to - // https://github.com/llvm/llvm-project/blob/3d6b2d49/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp#L5235-L5314. - // - // For f32 inputs with ftz enabled, we need to manually lower the op to - // bypass the scaling-up-and-down process while keeping other parts - // unchanged. To ensure IEEE-compliant results, we approximate `sqrt(x)` - // using `x * rsq(x)` and apply extra refinement iterations to correct the - // result. 
- StringRef funcName = "llvm.amdgcn.rsq.f32"; - - Type funcType = getFunctionType(elemTy, operands[0]); - LLVM::LLVMFuncOp funcOp = - appendOrGetExternFuncOp(rewriter, op, funcName, funcType); - - Value sqrtR = - LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult(); - - Value sqrtX = operands[0][0]; - Value sqrtS = b.fmul(f32_ty, sqrtX, sqrtR); - - // Refine the approximation with Newton iteration - Value sqrtH = b.fmul(f32_ty, sqrtR, b.f32_val(0.5f)); - Value sqrtE = b.fma(b.neg(f32_ty, sqrtH), sqrtS, b.f32_val(0.5f)); - sqrtH = b.fma(sqrtH, sqrtE, sqrtH); - sqrtS = b.fma(sqrtS, sqrtE, sqrtS); - Value sqrtD = b.fma(b.neg(f32_ty, sqrtS), sqrtS, sqrtX); - sqrtS = b.fma(sqrtD, sqrtH, sqrtS); - - // Handle +0/-0/+inf - // These flags come from - // https://github.com/llvm/llvm-project/blob/217e0f39/llvm/include/llvm/ADT/FloatingPointMode.h#L239-L265. - const unsigned fcPosInf = 0x0200; - const unsigned fcNegZero = 0x0020; - const unsigned fcPosZero = 0x0040; - const unsigned fcZero = fcNegZero | fcPosZero; - - Value isZeroOrPosInf = - LLVM::IsFPClass::create(rewriter, loc, i1_ty, sqrtX, fcPosInf | fcZero); - return {b.select(isZeroOrPosInf, sqrtX, sqrtS)}; - } - private: bool ftz; }; @@ -2382,6 +2315,8 @@ void populateElementwiseOpToLLVMPatterns( typeConverter, axisInfoAnalysis, benefit); patterns.add>( typeConverter, axisInfoAnalysis, benefit); + patterns.add>( + typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); @@ -2409,8 +2344,6 @@ void populateElementwiseOpToLLVMPatterns( patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); patterns.add(typeConverter, axisInfoAnalysis, ftz, benefit); - patterns.add(typeConverter, axisInfoAnalysis, ftz, - benefit); triton::populateElementwiseOpToLLVMPatterns( typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit); bool hwNanPropagationSupported = targetInfo.supportMaximumMinimum(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 3d1d50ba9d..2d7c4256f4 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1034,8 +1034,13 @@ struct AsyncCopyGlobalToLocalOpConversion zipLoadValues(rewriter, loc, vec, srcElems, srcPtrTy, maskElements, otherElems, otherTy, swizzledLaneOffsets); - Value threadPred = emitRedundantThreadPredicate(getFreeVariableMasks(srcTy), - rewriter, loc, targetInfo); + auto freeVarMasks = getFreeVariableMasks(srcTy); + // We load redundant data on different CTAs so each CTA has a copy in its + // shared memory; the multicast mask will be used by the hardware to + // efficiently broadcast to different CTAs. 
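+ // For example, with a 2-CTA CGA whose CGALayout is [[0, 0]], the "block" bit + // is a free variable (both CTAs map to the same tile); leaving it in the mask + // would predicate the copy down to a single CTA, so it is cleared here and + // every CTA fills its own shared-memory buffer.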
+ freeVarMasks[rewriter.getStringAttr("block")] = 0; + Value threadPred = + emitRedundantThreadPredicate(freeVarMasks, rewriter, loc, targetInfo); auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); auto emitGlobalLoadLds = diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index b6431afcd2..df44e439c2 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -806,7 +806,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { auto moduleOp = dotOp->getParentOfType(); int numWarps = ttg::lookupNumWarps(dotOp); - ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + ttg::CTAEncodingAttr ctaLayout = + ttg::getCTALayout(oldRetType.getEncoding()); int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); // Choose a suitable MFMA instruction for this scaled dot op. @@ -1063,7 +1064,8 @@ class ScaledBlockedToScaledMFMAF8F6F4 final MLIRContext *ctx = dotOp.getContext(); - ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + ttg::CTAEncodingAttr ctaLayout = + ttg::getCTALayout(oldRetType.getEncoding()); unsigned numWarps = ttg::lookupNumWarps(dotOp); if (numWarps == 1) return rewriter.notifyMatchFailure(dotOp, @@ -1279,7 +1281,8 @@ class ScaledBlockedToScaledWMMAF8F6F4 final MLIRContext *ctx = dotOp.getContext(); - ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + ttg::CTAEncodingAttr ctaLayout = + ttg::getCTALayout(oldRetType.getEncoding()); unsigned numWarps = ttg::lookupNumWarps(dotOp); constexpr unsigned mDim = 16; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 7afac143a8..0c59a3d71a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -361,6 +361,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW atomicRmwOp == RMWOp::FADD) { return rewriter.notifyMatchFailure(op, "RMW FADD does not support bf16"); } + if (isaFamily == ISAFamily::RDNA4 && checkType.isF64() && + atomicRmwOp == RMWOp::FADD) { + return rewriter.notifyMatchFailure(op, "RMW FADD does not support F64"); + } LDBG("RMW FADD supported 16-bit type"); auto vecSize = getVectorSize(ptr, axisAnalysisPass); @@ -624,7 +628,8 @@ struct TritonAMDGPUConvertToBufferOpsPass triton::AMD::ISAFamily isaFamily = triton::AMD::deduceISAFamily(archGenerationName); if (this->allowBufferAtomics && - (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily)) + (ISAFamily::CDNA3 == isaFamily || ISAFamily::CDNA4 == isaFamily || + ISAFamily::RDNA4 == isaFamily)) patterns.add( context, assumptions, axisInfoAnalysis, solver, isaFamily, this->analyzeSmallTensorOfst); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp index ea54600c24..2b0966c16f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/LowerLoops.cpp @@ -718,7 +718,7 @@ void updateSchedule(scf::ForOp &forOp, const LoadToInfoMap &loadToInfo, useAsyncCopy, axisInfoAnalysis); scheduleStreamOps(loadToStreamOps, schedule, clusters); - for (auto [l, _] : loadToInfo) { + for (auto [l, _] : loadToStreamOps) { schedule.erase(l); l->erase(); } diff --git 
a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDotOperands.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDotOperands.cpp index fcfa358b41..32a05282dd 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeDotOperands.cpp @@ -9,7 +9,6 @@ #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Tools/LayoutUtils.h" #include "triton/Tools/LinearLayout.h" diff --git a/third_party/amd/python/test/test_convert_op_permlane_swap.py b/third_party/amd/python/test/test_convert_op_permlane_swap.py index 835f39f0b5..0ac77af66e 100644 --- a/third_party/amd/python/test/test_convert_op_permlane_swap.py +++ b/third_party/amd/python/test/test_convert_op_permlane_swap.py @@ -25,20 +25,17 @@ def __str__(self): class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): self.sz_per_thread = size_per_thread self.threads_per_warp = threads_per_warp self.warps_per_cta = warps_per_cta self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order def __str__(self): - return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" -src_layouts = [BlockedLayout([1, 1], [1, 64], [1, 1], [0, 1], [1, 1], [1, 1], [0, 1])] +src_layouts = [BlockedLayout([1, 1], [1, 64], [1, 1], [0, 1])] dst_layouts = [ LinearLayout([[0, 32]], [[0, 1], [0, 2], [0, 4], [0, 8], [0, 16], [1, 0]], [], []), diff --git a/third_party/amd/python/test/test_extract_slice_concat_op.py b/third_party/amd/python/test/test_extract_slice_concat_op.py index 0f30b8e683..82d0f29584 100644 --- a/third_party/amd/python/test/test_extract_slice_concat_op.py +++ b/third_party/amd/python/test/test_extract_slice_concat_op.py @@ -6,8 +6,6 @@ from triton._internal_testing import is_hip -num_ctas_list = [1] - GPU_DIALECT = "ttg" if is_hip(): @@ -30,17 +28,14 @@ def __str__(self): class BlockedLayout: - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order): self.sz_per_thread = size_per_thread self.threads_per_warp = threads_per_warp self.warps_per_cta = warps_per_cta self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order def __str__(self): - return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>" # 
----------------------- @@ -50,26 +45,25 @@ def __str__(self): regs2x2 = [[1, 0], [0, 1]] lanes8x8 = [[2, 0], [4, 0], [8, 0], [0, 2], [0, 4], [0, 8]] warps2x2 = [[16, 0], [0, 16]] -cta_layout = [[1, 1], [1, 1], [0, 1]] redundant_ll = LinearLayout([[0, 0]] + regs2x2, lanes8x8, warps2x2, block=[]) non_redundant_ll = LinearLayout(regs2x2, lanes8x8, warps2x2, block=[]) # list of pairs defining ExtractSliceOp input and output layouts extract_layout = [ - (BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], *cta_layout), ) * 2, - (BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], *cta_layout), ) * 2, - (BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], *cta_layout), ) * 2, - (BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], *cta_layout), ) * 2, - (BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], *cta_layout), ) * 2, + (BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]), ) * 2, + (BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0]), ) * 2, + (BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1]), ) * 2, + (BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]), ) * 2, + (BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1]), ) * 2, (redundant_ll, non_redundant_ll), (non_redundant_ll, redundant_ll), ] blocked_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], *cta_layout), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], *cta_layout), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], *cta_layout), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], *cta_layout), - BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], *cta_layout), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1]), ] @@ -140,8 +134,7 @@ def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_t # test concat op # ----------------------- -cta_layout = [[1, 1], [1, 1], [0, 1]] -blocked_32x32 = BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1], *cta_layout) +blocked_32x32 = BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1]) broadcasted_32x32 = LinearLayout(register=[[0, 0], [1, 0], [0, 1]], lane=[[2, 0], [4, 0], [8, 0], [0, 2], [0, 4], [0, 8]], warp=[[16, 0], [0, 16]], block=[]) @@ -173,7 +166,7 @@ def test_concat_op(dtype, M, N, M_tile_size, N_tile_size, src_layout, dst_layout pytest.skip("concat op is AMD specific instruction.") ir = f""" - #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[16, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}> + #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[16, 4], warpsPerCTA=[4, 1], order=[1, 0]}}> #src_layout = {src_layout} #dst_layout = {dst_layout} diff --git a/third_party/amd/python/test/test_gluon_gfx1250.py b/third_party/amd/python/test/test_gluon_gfx1250.py index 1cb3f550e6..b69cd529af 100644 --- a/third_party/amd/python/test/test_gluon_gfx1250.py +++ b/third_party/amd/python/test/test_gluon_gfx1250.py @@ -1437,19 +1437,19 @@ def cluster_load_and_write_back_kernel(a_ptr, out_ptr, M, N, BLOCK_M: ttgl.const @pytest.mark.parametrize("blocked_layout", [ ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 2], order=[1, 0], - ctas_per_cga=[1, 2], cta_split_num=[1, 2]), + cga_layout=[[0, 1]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[2, 2], order=[1, 0], - ctas_per_cga=[2, 1], cta_split_num=[2, 1]), + cga_layout=[[1, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 1], threads_per_warp=[4, 
8], warps_per_cta=[4, 1], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[1, 4]), + cga_layout=[[1, 0], [2, 0], [0, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 2], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[2, 2]), + cga_layout=[[0, 1], [0, 0], [1, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[4, 8], warps_per_cta=[2, 2], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[1, 4]), + cga_layout=[[1, 0], [2, 0], [0, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 4], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[2, 2]), + cga_layout=[[0, 1], [0, 0], [1, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 16], threads_per_warp=[4, 8], warps_per_cta=[2, 2], order=[1, 0], - ctas_per_cga=[2, 8], cta_split_num=[1, 8]), + cga_layout=[[0, 1], [0, 2], [0, 4], [0, 0]]), ]) @pytest.mark.parametrize("dtype", [ # Test from 1 byte -> 8 bytes dtypes @@ -1460,7 +1460,7 @@ def test_runtime_cluster_load(blocked_layout, dtype): N = 128 BLOCK_M = 64 BLOCK_N = 64 - num_ctas = blocked_layout.ctas_per_cga[0] * blocked_layout.ctas_per_cga[1] + num_ctas = 2**len(blocked_layout.cga_layout) if dtype == torch.float8_e4m3fn: # range from min normal (0 00001 00) to max normal (0 11110 11) @@ -1567,28 +1567,26 @@ def test_runtime_async_copy(M, N, vec_size, shared_layout, dtype): @pytest.mark.parametrize("blocked_layout", [ + ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[1, 1], cta_split_num=[1, 1]), + cga_layout=[[0, 1]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[1, 2], cta_split_num=[1, 2]), + cga_layout=[[1, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[2, 1], cta_split_num=[2, 1]), + cga_layout=[[0, 1], [0, 2], [0, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[1, 4]), + cga_layout=[[0, 1], [0, 0], [1, 0], [0, 0]]), ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[4, 4], cta_split_num=[2, 2]), - ttgl.BlockedLayout(size_per_thread=[1, 8], threads_per_warp=[4, 8], warps_per_cta=[1, 1], order=[1, 0], - ctas_per_cga=[2, 8], cta_split_num=[1, 8]), + cga_layout=[[0, 1], [0, 2], [0, 4], [0, 0]]), ]) def test_runtime_async_copy_layouts_multi_cta(blocked_layout): M = 1024 N = 1024 BLOCK_M = 128 BLOCK_N = 128 - num_ctas = blocked_layout.ctas_per_cga[0] * blocked_layout.ctas_per_cga[1] + num_ctas = 2**len(blocked_layout.cga_layout) - shared_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0], blocked_layout.ctas_per_cga, - blocked_layout.cta_split_num) + shared_layout = ttgl.SwizzledSharedLayout(1, 1, 1, [1, 0], blocked_layout.cga_layout) a = torch.rand((M, N), dtype=torch.float32) out = torch.empty_like(a) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 95256135d5..1ae60949c5 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ 
b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -310,7 +310,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", let parameters = ( ins ArrayRefParameter<"unsigned">:$warpsPerCTA, - "CTALayoutAttr":$CTALayout, + "CTAEncodingAttr":$CTALayout, ArrayRefParameter<"unsigned">:$instrShape, "unsigned":$numBlocks, ArrayRefParameter<"unsigned">:$order, diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 67c6fd890b..166d8919b4 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -146,22 +146,12 @@ DpasEncodingAttr::getRepOrderForOperand(OpIdx opIdx) const { return getOrderForDotOperand(unsigned(opIdx), rank, /*kMajor*/ true); } -SmallVector DpasEncodingAttr::getCTASplitNum() const { +CTAEncodingAttr DpasEncodingAttr::getCTALayout() const { size_t rank = getWarpsPerCTA().size(); - SmallVector res(rank, 1); - return res; -} - -SmallVector DpasEncodingAttr::getCTAOrder() const { - size_t rank = getWarpsPerCTA().size(); - auto res = llvm::to_vector(llvm::reverse(llvm::seq(rank))); - return res; -} - -SmallVector DpasEncodingAttr::getCTAsPerCGA() const { - size_t rank = getWarpsPerCTA().size(); - SmallVector res(rank, 1); - return res; + SmallVector CTAsPerCGA(rank, 1); + auto CTAOrder = llvm::to_vector(llvm::reverse(llvm::seq(rank))); + return CTAEncodingAttr::fromSplitParams(getContext(), CTAsPerCGA, CTAsPerCGA, + CTAOrder); } SmallVector @@ -441,16 +431,8 @@ LinearLayout WarpEncodingAttr::toLinearLayout(ArrayRef shape) const { llvm::report_fatal_error("NYI. WarpEncodingAttr::toLinearLayout"); } -SmallVector WarpEncodingAttr::getCTAsPerCGA() const { - llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAsPerCGA"); -} - -SmallVector WarpEncodingAttr::getCTAOrder() const { - llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAOrder"); -} - -SmallVector WarpEncodingAttr::getCTASplitNum() const { - llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTASplitNum"); +CTAEncodingAttr WarpEncodingAttr::getCTALayout() const { + llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTALayout"); } Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) { @@ -506,16 +488,16 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const { //===----------------------------------------------------------------------===// namespace { -std::optional getCTALayoutOrError( +std::optional getCTALayoutOrError( AsmParser &parser, std::optional> CTAsPerCGA, std::optional> CTASplitNum, std::optional> CTAOrder, unsigned rank) { if (CTAsPerCGA && CTASplitNum && CTAOrder) { - return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum, - *CTAOrder); + return CTAEncodingAttr::fromSplitParams(parser.getContext(), *CTAsPerCGA, + *CTASplitNum, *CTAOrder); } if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) { - return CTALayoutAttr::getDefault(parser.getContext(), rank); + return CTAEncodingAttr::getDefault(parser.getContext(), rank); } parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder " "must all be present or all be absent"); @@ -524,8 +506,8 @@ std::optional getCTALayoutOrError( // Print the CTALayout if it's not equal to the default. 
void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer, - CTALayoutAttr layout, unsigned rank) { - if (layout != CTALayoutAttr::getDefault(context, rank)) { + CTAEncodingAttr layout, unsigned rank) { + if (layout != CTAEncodingAttr::getDefault(context, rank)) { printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]" << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]" << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]"; @@ -536,7 +518,7 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer, LogicalResult Subgroup2DBlockEncodingAttr::verify( function_ref emitError, - ArrayRef warpsPerCTA, CTALayoutAttr CTALayout, + ArrayRef warpsPerCTA, CTAEncodingAttr CTALayout, ArrayRef instrShape, unsigned numBlocks, ArrayRef order, unsigned kWidth, unsigned threadsPerWarp) { if (instrShape.size() != 2) { @@ -621,7 +603,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { } } - std::optional CTALayout = getCTALayoutOrError( + std::optional CTALayout = getCTALayoutOrError( parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size()); if (!CTALayout.has_value()) return {}; @@ -898,8 +880,10 @@ struct TritonIntelGPUInferLayoutInterface // Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA // should be like the other fields in blocked encoding, but I'm not sure how // to handle CTASplitNum. - if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || - !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { + if (!all_of(src.getCTALayout().getCTAsPerCGA(), + [](int32_t x) { return x == 1; }) || + !all_of(src.getCTALayout().getCTASplitNum(), + [](int32_t x) { return x == 1; })) { return failure(); } @@ -1074,7 +1058,7 @@ struct TritonIntelGPUInferLayoutInterface auto dstOrder = inversePermutation(dstInvOrder); // CTALayout can be all 1's because we bailed on multi-CTA layouts above. - auto CTALayout = CTALayoutAttr::get( + auto CTALayout = CTAEncodingAttr::fromSplitParams( src.getContext(), /*CTAsPerCGA=*/SmallVector(dstShape.size(), 1), /*CTASplitNum=*/SmallVector(dstShape.size(), 1), diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp index dea15a8f28..f9f3299d31 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp @@ -24,10 +24,10 @@ namespace { // for register layouts, and input dims [offset] for shared layouts. // - cgaLayout: Arrangement of multiple blocks, i.e. input dims [block]. // -// Note that this is inconsistent with the type name CTALayoutAttr. That type +// Note that this is inconsistent with the type name CTAEncodingAttr. That type // is equivalent to our cgaLayout. // -// IMO the name CTALayoutAttr is wrong. If we tried to be consistent anyway, +// IMO the name CTAEncodingAttr is wrong. If we tried to be consistent anyway, // then we'd have to rename ctaLayout to "warpLayout". I think that's more // confusing than being inconsistent about "cgaLayout", especially when we have // to consider the size of the warpLayout (surely that's not the "warpSize"). @@ -57,8 +57,8 @@ LinearLayout identityND(StringAttr inDimName, ArrayRef shape, // the CTAsPerCGA CTAs (i.e. blocks) in the CGA (i.e. groups). 
// // See the nomenclature note at the top of the file for an explanation of why -// this is called makeCgaLayout when it accepts a CTALayoutAttr. -LinearLayout makeCgaLayout(CTALayoutAttr layout) { +// this is called makeCgaLayout when it accepts a CTAEncodingAttr. +LinearLayout makeCgaLayout(CTAEncodingAttr layout) { MLIRContext *ctx = layout.getContext(); StringAttr kBlock = S("block"); @@ -464,7 +464,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef shape, Attribute layout, LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]); return combineCtaCgaWithShape(std::move(tileLayout), - CTALayoutAttr::getDefault(ctx, rank), shape); + CTAEncodingAttr::getDefault(ctx, rank), shape); } LinearLayout dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout, diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 9062aef7ad..c45655d1e6 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1076,7 +1076,7 @@ struct PrefetchOpConversion identityStandardND(S("warp"), warpsPerCTA, order); return combineCtaCgaWithShape(std::move(ctaLayout), - CTALayoutAttr::getDefault(ctx, rank), + CTAEncodingAttr::getDefault(ctx, rank), tensorShape); } }; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h index f42149be2c..4773dc724d 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVMBase.h @@ -21,7 +21,6 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::SharedMemoryObject; using ::mlir::triton::gpu::BlockedEncodingAttr; -using ::mlir::triton::gpu::CTALayoutAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; using ::mlir::triton::gpu::intel::DpasEncodingAttr; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp index 94292367cb..fcbcff5bde 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeDotOperands.cpp @@ -16,7 +16,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" -#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index 2aa5224bfd..b8e863221c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -349,7 +349,7 @@ struct DpasOperandPattern final : OpRewritePattern { 1, dpasEncoding.getWarpsPerCTA()[0]}; constexpr std::array order{0, 1, 2, 3, 4, 5, 6}; - CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank); auto encoding = rewriter.getAttr( sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); @@ -407,7 +407,7 @@ struct DpasOperandPattern 
final : OpRewritePattern { dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; constexpr std::array order{0, 1, 2, 3, 4}; - CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank); auto encoding = rewriter.getAttr( sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); @@ -440,7 +440,7 @@ struct DpasOperandPattern final : OpRewritePattern { dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; constexpr std::array order{0, 1, 2, 3}; - CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank); auto encoding = rewriter.getAttr( sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); @@ -483,7 +483,7 @@ struct DpasOperandPattern final : OpRewritePattern { std::array warpsPerCTA{dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; constexpr std::array order{0, 1}; - CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + CTAEncodingAttr ctaLayout = CTAEncodingAttr::getDefault(getContext(), rank); auto parentEncoding = rewriter.getAttr( sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp index 2e1eee2adf..30d409c1ef 100644 --- a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp +++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp @@ -31,11 +31,7 @@ class LinearLayoutConversionsTest : public ::testing::Test { // TODO: could put the getOrderForDotOperand in the builder? auto layout = Subgroup2DBlockEncodingAttr::get( - &ctx, warpsPerCTA, - CTALayoutAttr::get( - &ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout? 
- dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()), - instrShape, numBlocks, + &ctx, warpsPerCTA, dpasLayout.getCTALayout(), instrShape, numBlocks, getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth, dpasLayout.getThreadsPerWarp()); return layout; diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp index fd68da78ee..f5fda09481 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSCodePartition.cpp @@ -540,9 +540,7 @@ static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { triton::gpu::SharedMemorySpaceAttr::get(funcOp.getContext()); Location loc = funcOp.getLoc(); auto context = funcOp.getContext(); - auto barrierCTALayout = - ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierCTALayout = ttg::CTAEncodingAttr::getDefault(context, 1); auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( context, 1, 1, 1, {0}, barrierCTALayout); Type barrierMemDescType = ttg::MemDescType::get( diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp index 31c62dc24f..2ae84ed4af 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerToken.cpp @@ -113,9 +113,7 @@ void lowerTokenOperations(Operation *parentOp, int numCTAs, Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(context); - auto barrierCTALayout = - ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierCTALayout = ttg::CTAEncodingAttr::getDefault(context, 1); auto barrierEncoding = ttg::SwizzledSharedEncodingAttr::get( context, 1, 1, 1, {0}, barrierCTALayout); Type barrierMemDescType = ttg::MemDescType::get( diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h index 60b0d380f6..aa8b682719 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAHelpers.h @@ -192,7 +192,7 @@ class DotOpMmaSmemLoader : public DotOpMmaMemLoader { auto kOffset = str_attr("offset"); // Any CTALayout, it's not really used within getCoreMatrixLinearLayout - auto CTALayout = triton::gpu::CTALayoutAttr::getDefault(ctx, 2); + auto CTALayout = triton::gpu::CTAEncodingAttr::getDefault(ctx, 2); for (bool fp4Padded : (bitwidth == 4 ? 
SmallVector({false, true}) : SmallVector({false}))) { diff --git a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp index f314e335a5..11a21f9f50 100644 --- a/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp +++ b/third_party/proton/Dialect/lib/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/AMDPatternProtonGPUOpToLLVM.cpp @@ -2,8 +2,10 @@ #include "Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/TargetInfo.h" #include "Conversion/ProtonGPUToLLVM/Utility.h" #include "Dialect/ProtonGPU/IR/Dialect.h" +#include "amd/lib/TritonAMDGPUToLLVM/Utility.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/IR/PatternMatch.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -37,7 +39,8 @@ struct CircularStoreOpConversion // TODO(crobeck): see what buffer ops performance looks like here for // global mem (address space 1) compared to predicated ops to shared // memory - llvm::report_fatal_error("unimplemented"); + mlir::LLVM::AMD::llStore(rewriter, loc, dataPack.ptr, dataPack.record, + dataPack.isWriter); } else if (addrSpace == 3) { targetInfo.getTritonTargetInfo().storeDShared( rewriter, loc, dataPack.ptr, std::nullopt, dataPack.record, diff --git a/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp b/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp index bf1a93c64a..127b4c2944 100644 --- a/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp +++ b/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp @@ -282,9 +282,7 @@ class ConvertProtonToProtonGPUPass Value segment; Value buffer; if (bufferType == gpu::BufferType::SHARED) { - auto ctaLayout = triton::gpu::CTALayoutAttr::get( - context, /*CTAsPerCGA=*/{1}, - /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto ctaLayout = triton::gpu::CTAEncodingAttr::getDefault(context, 1); auto encoding = triton::gpu::SwizzledSharedEncodingAttr::get( context, 1, 1, 1, {0}, ctaLayout); Attribute sharedMemorySpace = diff --git a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp index 19c0f8de1e..28ad28d92a 100644 --- a/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/RocTracer/RoctracerProfiler.cpp @@ -350,19 +350,22 @@ void RoctracerProfiler::RoctracerProfilerPimpl::activityCallback( // data on stop maxCorrelationId = std::max(maxCorrelationId, record->correlation_id); - // TODO(Keren): Roctracer doesn't support cuda graph yet. + bool hasCorrelation = + correlation.corrIdToExternId.contain(record->correlation_id); auto externId = - correlation.corrIdToExternId.contain(record->correlation_id) + hasCorrelation ? 
            ? correlation.corrIdToExternId.at(record->correlation_id).first
            : Scope::DummyScopeId;
     auto isAPI = correlation.apiExternIds.contain(externId);
     bool isGraph = pImpl->CorrIdToIsHipGraph.contain(record->correlation_id);
-    processActivity(correlation.corrIdToExternId, correlation.apiExternIds,
-                    externId, dataSet, record, isAPI, isGraph);
-    // Track correlation ids from the same stream and erase those <
-    // correlationId
-    correlation.corrIdToExternId.erase(record->correlation_id);
-    correlation.apiExternIds.erase(externId);
+    if (hasCorrelation) {
+      processActivity(correlation.corrIdToExternId, correlation.apiExternIds,
+                      externId, dataSet, record, isAPI, isGraph);
+      // Track correlation ids from the same stream and erase those <
+      // correlationId
+    } else {
+      correlation.apiExternIds.erase(externId);
+    }
     roctracer::getNextRecord(record, &record);
   }
   correlation.complete(maxCorrelationId);
diff --git a/third_party/proton/test/test_instrumentation.py b/third_party/proton/test/test_instrumentation.py
index 388a587baf..271b0ff835 100644
--- a/third_party/proton/test/test_instrumentation.py
+++ b/third_party/proton/test/test_instrumentation.py
@@ -15,7 +15,6 @@
     is_cuda,
     is_hip,
    is_hip_cdna2,
-    is_hip_cdna4,
     supports_tma,
     supports_ws,
 )
@@ -644,7 +643,6 @@ def foo(x, y, size: tl.constexpr):
     assert trace_events[-1]["args"]["call_stack"][-2] == "test"
 
 
-@pytest.mark.skipif(is_hip_cdna4(), reason="nondeterministic failure")
 def test_globaltime(tmp_path: pathlib.Path):
     temp_file = tmp_path / "test_globaltime.chrome_trace"
     mode = proton.mode.Default(
@@ -760,7 +758,6 @@ def session_kernel_time(session_name: str) -> Tuple[int, int]:
     assert session1_loop_time / session0_loop_time < loop_threshold, "Loop kernel overhead too high"
 
 
-@pytest.mark.skipif(is_hip(), reason="not implemented yet")
 def test_gmem_buffer(tmp_path: pathlib.Path):
 
     @triton.jit
diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py
index 489064c3e7..9071c394a8 100644
--- a/third_party/proton/test/test_profile.py
+++ b/third_party/proton/test/test_profile.py
@@ -80,7 +80,6 @@ def foo(x, y):
     assert data[0]["children"][1]["frame"]["name"] == "test2"
 
 
-@pytest.mark.skipif(is_hip(), reason="Currently broken after updating to ROCm 7")
 def test_cudagraph(tmp_path: pathlib.Path, device: str):
     if is_xpu():
         pytest.skip("xpu doesn't support cudagraph; FIXME: double check")
diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp
index cab4ad898b..1d9abd223f 100644
--- a/unittest/Dialect/TritonGPU/DialectTest.cpp
+++ b/unittest/Dialect/TritonGPU/DialectTest.cpp
@@ -35,9 +35,10 @@ createDistributedEncodings(MLIRContext &ctx) {
   // Define a tensor shape
   auto rank = 2;
   SmallVector<SmallVector<unsigned>> orders = {{0, 1}, {1, 0}};
-  SmallVector<triton::gpu::CTALayoutAttr> ctaLayouts = {
-      triton::gpu::CTALayoutAttr::getDefault(&ctx, rank),
-      triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}),
+  SmallVector<triton::gpu::CTAEncodingAttr> ctaLayouts = {
+      triton::gpu::CTAEncodingAttr::getDefault(&ctx, rank),
+      triton::gpu::CTAEncodingAttr::fromSplitParams(&ctx, {4, 2}, {2, 2},
+                                                    {1, 0}),
   };
 
   std::vector distributedEncodings;
@@ -478,8 +479,8 @@ class AMDLayoutTest : public ::testing::Test {
 public:
   AMDLayoutTest() {
     ctx.getOrLoadDialect();
-    ctaLayout =
-        triton::gpu::CTALayoutAttr::get(&ctx, ctaPerCGA, ctaSplit, ctaOrder);
+    ctaLayout = triton::gpu::CTAEncodingAttr::fromSplitParams(
+        &ctx, ctaPerCGA, ctaSplit, ctaOrder);
     f16Ty = Float16Type::get(&ctx);
   }
 
@@ -493,7 +494,7 @@ class AMDLayoutTest : public ::testing::Test {
   const SmallVector ctaPerCGA{1, 1, 1};
   const SmallVector ctaSplit{1, 1, 1};
   const SmallVector ctaOrder{2, 1, 0};
-  triton::gpu::CTALayoutAttr ctaLayout;
+  triton::gpu::CTAEncodingAttr ctaLayout;
   Type f16Ty;
 };
 
@@ -582,14 +583,17 @@ TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) {
     // SliceEncoding is not well-defined for CGAs
     if (!isa(distributedEncoding)) {
       auto baseEncoding = cast(distributedEncoding);
-      ASSERT_EQ(baseEncoding.getCTASplitNum(),
-                linearEncoding.getCTASplitNum());
-      ASSERT_EQ(baseEncoding.getCTAsPerCGA(), baseEncoding.getCTAsPerCGA());
+      auto baseCTALayout = baseEncoding.getCTALayout();
+      auto linearCTALayout = linearEncoding.getCTALayout();
+      ASSERT_EQ(baseCTALayout.getCTASplitNum(),
+                linearCTALayout.getCTASplitNum());
+      ASSERT_EQ(baseCTALayout.getCTAsPerCGA(),
+                linearCTALayout.getCTAsPerCGA());
       // If we are not using CGAs, the order is meaningless
       auto useCGA =
-          baseEncoding.getCTAsPerCGA() != SmallVector(rank, 1);
+          baseCTALayout.getCTAsPerCGA() != SmallVector(rank, 1);
       if (useCGA && !is_dot_op_with_block_parent(distributedEncoding)) {
-        ASSERT_EQ(baseEncoding.getCTAOrder(), linearEncoding.getCTAOrder());
+        ASSERT_EQ(baseCTALayout.getCTAOrder(), linearCTALayout.getCTAOrder());
       }
     }
   }
diff --git a/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp b/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp
index 31d595faf5..4eac53067d 100644
--- a/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp
+++ b/unittest/Dialect/TritonGPU/DumpLayoutTest.cpp
@@ -18,7 +18,8 @@ class DumpLayoutTest : public ::testing::Test {
                              ArrayRef cSplit, ArrayRef ord,
                              ArrayRef cOrd) {
     return BlockedEncodingAttr::get(
-        &ctx, spt, tpw, wpb, ord, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        &ctx, spt, tpw, wpb, ord,
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   SwizzledSharedEncodingAttr shared(unsigned vec, unsigned perPhase,
@@ -28,7 +29,7 @@ class DumpLayoutTest : public ::testing::Test {
                                     ArrayRef cOrd) {
     return SwizzledSharedEncodingAttr::get(
         &ctx, vec, perPhase, maxPhase, ord,
-        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   void assertSameStr(const std::string &refStr, const std::string &output) {
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index 7e46792429..62633327af 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -33,7 +33,8 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                              ArrayRef cSplit, ArrayRef ord,
                              ArrayRef cOrd) {
     return BlockedEncodingAttr::get(
-        &ctx, spt, tpw, wpb, ord, CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        &ctx, spt, tpw, wpb, ord,
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin,
@@ -43,13 +44,13 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                             ArrayRef cOrd) {
     return NvidiaMmaEncodingAttr::get(
         &ctx, versionMaj, versionMin, wbp,
-        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd), instrShape);
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd), instrShape);
   }
 
   NvidiaMmaEncodingAttr mma(unsigned versionMaj, unsigned versionMin,
                             ArrayRef instrShape, ArrayRef numWarps) {
-    auto ctaLayout = CTALayoutAttr::getDefault(&ctx, numWarps.size());
+    auto ctaLayout = CTAEncodingAttr::getDefault(&ctx, numWarps.size());
     return NvidiaMmaEncodingAttr::get(&ctx, versionMaj, versionMin, numWarps,
                                       std::move(ctaLayout), instrShape);
   }
@@ -67,7 +68,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
     SmallVector cOrd(warps.size());
     std::iota(cOrd.begin(), cOrd.end(), 0);
 
-    auto ctaLayout = CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd);
+    auto ctaLayout = CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd);
     return AMDMfmaEncodingAttr::get(&ctx, version, warps, instrShape,
                                     isTransposed, ctaLayout, tilesPerWarp,
                                     elementBitWidth);
@@ -86,9 +87,9 @@ class LinearLayoutConversionsTest : public ::testing::Test {
     SmallVector cSplit(warps.size(), 1u);
     SmallVector cOrd(warps.size());
     std::iota(cOrd.begin(), cOrd.end(), 0);
-    return AMDWmmaEncodingAttr::get(&ctx, version, transposed, warps,
-                                    CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd),
-                                    instrShape);
+    return AMDWmmaEncodingAttr::get(
+        &ctx, version, transposed, warps,
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd), instrShape);
   }
 
   DotOperandEncodingAttr wmmaDotOp(AMDWmmaEncodingAttr wmma, unsigned opIdx,
@@ -107,7 +108,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                                     ArrayRef cOrd) {
     return SwizzledSharedEncodingAttr::get(
         &ctx, vec, perPhase, maxPhase, ord,
-        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   NVMMASharedEncodingAttr
@@ -117,7 +118,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
       ArrayRef cOrd, bool fp4Padded = false) {
     return NVMMASharedEncodingAttr::get(
         &ctx, swizzleSizeInBytes, transposed, elementBitWidth, fp4Padded,
-        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   AMDRotatingSharedEncodingAttr
@@ -126,7 +127,7 @@ class LinearLayoutConversionsTest : public ::testing::Test {
                                 ArrayRef ord, ArrayRef cOrd) {
     return AMDRotatingSharedEncodingAttr::get(
         &ctx, vec, perPhase, maxPhase, ord,
-        CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd));
+        CTAEncodingAttr::fromSplitParams(&ctx, cpg, cSplit, cOrd));
   }
 
   TensorMemoryEncodingAttr tmem(unsigned blockM, unsigned blockN,
@@ -3328,7 +3329,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{128, 2}, /*opIdx=*/0, /*warpsPerCTA=*/{1, 1},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout({{S("register"), {{0, 1}, {16, 0}, {32, 0}, {64, 0}}},
                      {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
                      {S("warp"), {}},
@@ -3339,7 +3341,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{128, 2}, /*opIdx=*/1, /*warpsPerCTA=*/{1, 1},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {8, 0}, {16, 0}, {32, 0}, {64, 0}}},
       {S("lane"), {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3351,7 +3354,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{128, 4}, /*opIdx=*/0, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout({{S("register"), {{0, 1}, {0, 2}, {32, 0}, {64, 0}}},
                      {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
                      {S("warp"), {{0, 0}, {16, 0}}},
@@ -3362,7 +3366,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 4}, /*opIdx=*/1, /*warpsPerCTA=*/{1, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {0, 2}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3374,7 +3379,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{128, 8}, /*opIdx=*/0, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {32, 0}, {64, 0}}},
                      {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3386,7 +3392,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{128, 8}, /*opIdx=*/1, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {16, 0}, {32, 0}, {64, 0}}},
       {S("lane"), {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3398,7 +3405,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 2}, /*opIdx=*/0, /*warpsPerCTA=*/{1, 1},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3410,7 +3418,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 2}, /*opIdx=*/1, /*warpsPerCTA=*/{1, 1},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {8, 0}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3422,7 +3431,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 4}, /*opIdx=*/0, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {0, 2}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3434,7 +3444,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 4}, /*opIdx=*/1, /*warpsPerCTA=*/{1, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {0, 2}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{0, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3446,7 +3457,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 8}, /*opIdx=*/0, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"), {{0, 1}, {0, 2}, {0, 4}, {32, 0}, {64, 0}, {128, 0}}},
       {S("lane"), {{8, 0}, {0, 0}, {1, 0}, {2, 0}, {4, 0}}},
@@ -3458,7 +3470,8 @@ TEST_F(LinearLayoutConversionsTest, SM120DotScaledScaleLayout) {
 
   layout = getSM120DotScaledScaleLayout(
       &ctx, /*shape=*/{256, 8}, /*opIdx=*/1, /*warpsPerCTA=*/{2, 2},
-      /*ctaLayout=*/CTALayoutAttr::get(&ctx, {1, 1}, {1, 1}, {1, 0}));
+      /*ctaLayout=*/
+      CTAEncodingAttr::fromSplitParams(&ctx, {1, 1}, {1, 1}, {1, 0}));
   ll = LinearLayout(
       {{S("register"),
        {{0, 1}, {0, 2}, {0, 4}, {16, 0}, {32, 0}, {64, 0}, {128, 0}}},
diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp
index 240daade5d..0ec72915d9 100644
--- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp
+++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp
@@ -73,8 +73,8 @@ class BankConflictTest : public ::testing::Test {
       splitStorage.assign(spt.size(), 1);
     if (cOrder.empty())
       cOrderStorage.assign(order.begin(), order.end());
 
-    auto cta = mlir::triton::gpu::CTALayoutAttr::get(
+    auto cta = mlir::triton::gpu::CTAEncodingAttr::fromSplitParams(
         &ctx, cpgStorage.empty() ? cpg : ArrayRef(cpgStorage),
         splitStorage.empty() ? split : ArrayRef(splitStorage),
         cOrderStorage.empty() ? cOrder : ArrayRef(cOrderStorage));
@@ -85,8 +85,8 @@ class BankConflictTest : public ::testing::Test {
   mlir::triton::gpu::NvidiaMmaEncodingAttr mma(ArrayRef version,
                                                ArrayRef warpsPerCTA,
                                                ArrayRef instrShape) {
-    auto cta =
-        mlir::triton::gpu::CTALayoutAttr::getDefault(&ctx, warpsPerCTA.size());
+    auto cta = mlir::triton::gpu::CTAEncodingAttr::getDefault(
+        &ctx, warpsPerCTA.size());
     return mlir::triton::gpu::NvidiaMmaEncodingAttr::get(
         &ctx, version[0], version[1], warpsPerCTA, cta, instrShape);
   }
@@ -96,7 +96,8 @@ class BankConflictTest : public ::testing::Test {
                                                  bool transposed = false) {
     SmallVector cpg(rank, 1), split(rank, 1), order(rank);
     std::iota(order.begin(), order.end(), 0);
-    auto cta = mlir::triton::gpu::CTALayoutAttr::get(&ctx, cpg, split, order);
+    auto cta = mlir::triton::gpu::CTAEncodingAttr::fromSplitParams(
+        &ctx, cpg, split, order);
     return mlir::triton::gpu::NVMMASharedEncodingAttr::get(
         &ctx, swizzle, transposed, bitwidth,
         /*fp4Padded=*/false, cta);
@@ -302,13 +303,13 @@ TEST_F(BankConflictTest, bankConflicts) {
       {blocked({1}, {32}, {4}, {0}),
        mlir::triton::gpu::SwizzledSharedEncodingAttr::get(
            &ctx, 1, 1, 1, {0},
-           mlir::triton::gpu::CTALayoutAttr::getDefault(&ctx, 1)),
+           mlir::triton::gpu::CTAEncodingAttr::getDefault(&ctx, 1)),
        {32},
        32},
      {blocked({1}, {32}, {4}, {0}),
        mlir::triton::gpu::SwizzledSharedEncodingAttr::get(
            &ctx, 1, 1, 1, {0},
-           mlir::triton::gpu::CTALayoutAttr::getDefault(&ctx, 1)),
+           mlir::triton::gpu::CTAEncodingAttr::getDefault(&ctx, 1)),
        {32},
        16},
      {mmaV3,