diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 2fec340d45..8f475b3e29 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -6,23 +6,61 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" //===----------------------------------------------------------------------===// -// TritonGPU Attribute Definitions +// Traits and Interfaces //===----------------------------------------------------------------------===// -def TritonGPU_AttrTrait : AttrInterface<"TritonGPU_AttrTrait"> { - let cppNamespace = "::mlir::triton::gpu"; +def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; +def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; + +def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; + let description = [{ + Common trait for all TTGIR layouts. + }]; let methods = [ + InterfaceMethod<"Get the shape of the CTAs per CGA.", + "SmallVector", + "getCTAsPerCGA", (ins), [{}], [{ + return llvm::to_vector($_attr.getCTALayout().getCTAsPerCGA()); + }]>, + InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", + "SmallVector", + "getCTAOrder", (ins), [{}], [{ + return llvm::to_vector($_attr.getCTALayout().getCTAOrder()); + }]>, + InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", + "SmallVector", + "getCTASplitNum", (ins), [{}], [{ + return llvm::to_vector($_attr.getCTALayout().getCTASplitNum()); + }]>, + InterfaceMethod<"Get the rank of the layout.", "unsigned", "getRank", (ins), [{}], [{ + return $_attr.getCTAOrder().size(); + }]> ]; } +def DeclareLayoutEncodingMethods : DeclareAttrInterfaceMethods< + LayoutEncodingTrait, ["getCTAsPerCGA", "getCTAOrder", "getCTASplitNum"]>; -def MemDescViewTrait : NativeOpTrait<"MemDescViewTrait">; +def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { + let cppNamespace = "::mlir::triton::gpu"; -def LocalLoadTrait : NativeOpTrait<"LocalLoadTrait">; + let description = [{ + Common trait describing shared memory. + }]; + let methods = [ + InterfaceMethod<"Return the default alignment for the layout.", + "int32_t", "getAlignment", (ins), [{}], [{ return 16; }]>, + ]; +} +def DeclareSharedEncodingMethods : DeclareAttrInterfaceMethods< + SharedEncodingTrait, ["getAlignment"]>; + +//===----------------------------------------------------------------------===// +// Base Attribute +//===----------------------------------------------------------------------===// -class TritonGPU_Attr traits = [], - Dialect dialect = TritonGPU_Dialect, - string baseCppClass = "::mlir::Attribute"> - : AttrDef { +class TritonGPU_Attr traits = [], Dialect dialect = TritonGPU_Dialect> + : AttrDef { let description = [{ TritonGPU tensors differ from usual tensors in that they contain a _layout_ attribute which determines @@ -123,51 +161,17 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to CTAOrder.push_back(i); return get(context, CTAsPerCGA, CTASplitNum, CTAOrder); } - unsigned getRank() const { - return getCTAOrder().size(); - } + unsigned getRank() const { return getCTAOrder().size(); } }]; let genVerifyDecl = 1; let skipDefaultBuilders = 1; } - -def LayoutEncodingTrait : AttrInterface<"LayoutEncodingTrait"> { - let cppNamespace = "::mlir::triton::gpu"; - let description = [{ - Common trait for all TTGIR layouts. 
- }]; - let methods = [ - InterfaceMethod<"Get the shape of the CTAs per CGA.", - "SmallVector", - "getCTAsPerCGA">, - InterfaceMethod<"Get the order of the CTAs per CGA. The fastest-changing axis first", - "SmallVector", - "getCTAOrder">, - InterfaceMethod<"Each CTA processes 1/CTASplitNum of the tensor.", - "SmallVector", - "getCTASplitNum">, - ]; -} - //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// -def SharedEncodingTrait : AttrInterface<"SharedEncodingTrait"> { - let cppNamespace = "::mlir::triton::gpu"; - - let description = [{ - Common trait describing shared memory. - }]; - let methods = [ - InterfaceMethod<"Return the default alignment for the layout.", - "int32_t", - "getAlignment">, - ]; -} - def SwizzledSharedEncodingAttr : TritonGPU_Attr<"SwizzledSharedEncoding", "swizzled_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { @@ -359,13 +363,6 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at }]>, ]; - let extraClassDeclaration = extraBaseClassDeclaration # [{ - unsigned getRank() const { return getCTAOrder().size(); } - int32_t getAlignment() const; - SmallVector getCTAsPerCGA() const; - SmallVector getCTAOrder() const; - SmallVector getCTASplitNum() const; - }]; let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; } @@ -433,9 +430,6 @@ attributes too, for example, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - unsigned getRank() const { return getOrder().size(); } - int32_t getAlignment() const { return 16; } - unsigned getMinInterval() const { return *llvm::min_element(getIntervals()); } @@ -443,17 +437,12 @@ attributes too, for example, // Returns the total number of elements including padding given the input // tensor shape. int64_t getPaddedSize(ArrayRef shape) const; - - SmallVector getCTAsPerCGA() const; - SmallVector getCTAOrder() const; - SmallVector getCTASplitNum() const; }]; let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; } -def NVMMASharedEncodingAttr : - TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [SharedEncodingTrait, LayoutEncodingTrait]> { +def NVMMASharedEncodingAttr : TritonGPU_Attr<"NVMMASharedEncoding", "nvmma_shared_encoding", [DeclareSharedEncodingMethods, LayoutEncodingTrait]> { let mnemonic = "nvmma_shared"; let description = [{ @@ -513,11 +502,6 @@ def NVMMASharedEncodingAttr : ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ - unsigned getRank() const { return getCTAOrder().size(); } - int32_t getAlignment() const; - SmallVector getCTAsPerCGA() const; - SmallVector getCTAOrder() const; - SmallVector getCTASplitNum() const; int getPerPhase() const; int getMaxPhase() const; int getVec() const; @@ -619,13 +603,6 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. columns*rows-1): "CTALayoutAttr":$CTALayout ); - let extraClassDeclaration = extraBaseClassDeclaration # [{ - unsigned getRank() const { return getCTAOrder().size(); } - int32_t getAlignment() const; - SmallVector getCTAsPerCGA() const; - SmallVector getCTAOrder() const; - SmallVector getCTASplitNum() const; - }]; let hasCustomAssemblyFormat = 1; } @@ -633,6 +610,7 @@ Swizzling examples (matrix is filled with numbers 0, 1, 2, .. 
columns*rows-1): //===----------------------------------------------------------------------===// // Distributed Layout Encoding //===----------------------------------------------------------------------===// + def DistributedEncodingTrait : AttrInterface<"DistributedEncodingTrait"> { let cppNamespace = "::mlir::triton::gpu"; @@ -719,12 +697,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; code extraDistributedDeclaration = extraBaseClassDeclaration # [{ - unsigned getRank() const { return getCTAOrder().size(); } // Implemented in subclasses SmallVector getRepOrder() const; - SmallVector getCTAsPerCGA() const; - SmallVector getCTAOrder() const; - SmallVector getCTASplitNum() const; LinearLayout toLinearLayout(ArrayRef shape) const; }]; @@ -739,7 +713,7 @@ def LinearLayoutParam : AttrOrTypeParameter<"LinearLayout", let cppAccessorType = "const LinearLayout &"; } -def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { +def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding", [DeclareLayoutEncodingMethods]> { let mnemonic = "linear"; let description = [{ @@ -1376,7 +1350,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is: let hasCustomAssemblyFormat = 1; } -def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { +def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding", [DeclareLayoutEncodingMethods]> { let mnemonic = "slice"; let description = [{ @@ -1419,9 +1393,10 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding", "slice_encoding"> { }]; let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; } -def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding"> { +def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding", "dot_operand_encoding", [DeclareLayoutEncodingMethods]> { let mnemonic = "dot_op"; let description = [{ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index ccf2b6ef3d..8bb0a59f7b 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -27,11 +27,13 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" // TritonNvidiaGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" namespace mlir::triton::nvidia_gpu::impl { @@ -61,13 +63,19 @@ struct TMemAllocation { TMemAllocation getTmemAllocSizes(gpu::MemDescType memDescType); -Attribute getTmemCompatibleLayout(unsigned M, unsigned N, - RankedTensorType oltType, unsigned numWarps); +gpu::DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N, + RankedTensorType oltType, + unsigned numWarps); +gpu::DistributedEncodingTrait +getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, + gpu::MemDescType memType, int numWarps); +SmallVector +getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, + gpu::MemDescType memType); bool isDistributedLayoutTMemCompatible(Operation *op, RankedTensorType tensorType, gpu::MemDescType memType); - bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType, gpu::MemDescType memType, int numWarps); 
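The TritonNvidiaGPU/IR/Dialect.h hunk above tightens getTmemCompatibleLayout to return gpu::DistributedEncodingTrait and adds getTmemLoadLayoutSplitLongM and getTmemCompatibleLayouts. A minimal sketch, assuming only the declarations in that hunk, of how a caller might pick a register layout that tmem_load/tmem_store accept; pickTmemCompatibleType and the namespace aliases are illustrative and not part of the patch:

#include <cassert>
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;

// Illustrative helper: keep the tensor's encoding if it is already TMEM
// compatible, otherwise rebuild the tensor type with the first candidate
// layout reported by getTmemCompatibleLayouts().
static mlir::RankedTensorType
pickTmemCompatibleType(mlir::Operation *op, mlir::RankedTensorType tensorType,
                       ttg::MemDescType memType) {
  if (ttng::isDistributedLayoutTMemCompatible(op, tensorType, memType))
    return tensorType;
  llvm::SmallVector<ttg::DistributedEncodingTrait> layouts =
      ttng::getTmemCompatibleLayouts(op, tensorType, memType);
  assert(!layouts.empty() && "no TMEM-compatible layout for this memdesc");
  // A DistributedEncodingTrait is an attribute interface, so it can be used
  // directly as the tensor's encoding attribute.
  return mlir::RankedTensorType::get(tensorType.getShape(),
                                     tensorType.getElementType(),
                                     layouts.front());
}

Returning the full candidate list rather than a single boolean is also what allows the new verifyTMEMOperand check (in TritonNvidiaGPU/IR/Ops.cpp further down) to attach "potential TMEM layout" notes when a tmem_load/tmem_store operand layout is rejected.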
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 4ab7fb8cae..2f8be4f168 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -17,6 +17,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LayoutUtility.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonGPU/IR/Types.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -254,12 +255,9 @@ CTALayoutAttr getCTALayout(Attribute layout) { } SmallVector getCTAsPerCGA(Attribute layout) { - ArrayRef ref; if (auto ttgLayout = mlir::dyn_cast(layout)) return ttgLayout.getCTAsPerCGA(); - else - llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); - return SmallVector(ref.begin(), ref.end()); + llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA"); } SmallVector getCTASplitNum(Attribute layout) { @@ -581,235 +579,11 @@ static void maybePrintCTALayout(mlir::MLIRContext *context, //===----------------------------------------------------------------------===// // Attribute methods //===----------------------------------------------------------------------===// -#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" +#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc" #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc" -// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. -// But we need to have a consistent interface with e.g. SliceEncodingAttr, which -// computes some of these fields. -SmallVector BlockedEncodingAttr::getRepOrder() const { - return SmallVector(getOrder()); -} -SmallVector BlockedEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector BlockedEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector BlockedEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - -template -SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { - size_t rank = shape.size(); - unsigned dim = getDim(); - SmallVector retShape(rank + 1); - for (unsigned d = 0; d < rank + 1; ++d) { - if (d < dim) - retShape[d] = shape[d]; - else if (d == dim) - retShape[d] = 1; - else - retShape[d] = shape[d - 1]; - } - return retShape; -} -template SmallVector -SliceEncodingAttr::paddedShape(ArrayRef shape) const; -template SmallVector -SliceEncodingAttr::paddedShape(ArrayRef shape) const; -SmallVector SliceEncodingAttr::getRepOrder() const { - auto parentRepOrder = getParent().getRepOrder(); - return eraseOrder(parentRepOrder, getDim()); -} -SmallVector SliceEncodingAttr::getCTASplitNum() const { - SmallVector res = ::getCTASplitNum(getParent()); - res.erase(res.begin() + getDim()); - return res; -} -SmallVector SliceEncodingAttr::getCTAOrder() const { - auto parentCTAOrder = ::getCTAOrder(getParent()); - return eraseOrder(parentCTAOrder, getDim()); -} -SmallVector SliceEncodingAttr::getCTAsPerCGA() const { - auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); - if (parentCTAsPerCGA[getDim()] == 1) { - parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); - return parentCTAsPerCGA; - } - /* For getCTAsPerCGA of a slice layout, we have two choices: - * (1) Return CTAsPerCGA of its parent. 
This is not a perfect solution - * because the rank of the returned CTAsPerCGA does not match the rank of - * tensorShape. - * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a - * perfect solution because the product of the returned CTAsPerCGA might not - * match numCTAs. - * To avoid introducing inconsistencies to the shape and - * layout system, the usage of directly getting CTAsPerCGA of a slice layout - * in which the sliced dim is not 1 is banned. You should always consider - * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) - * in the branch where layout is an instance of SliceEncodingAttr. This is - * inconvenient but safe. - */ - llvm::report_fatal_error( - "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); -} - -// Wmma encoding - -int32_t SwizzledSharedEncodingAttr::getAlignment() const { return 16; } - -SmallVector SwizzledSharedEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector SwizzledSharedEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector SwizzledSharedEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - -SmallVector PaddedSharedEncodingAttr::getCTAsPerCGA() const { - return llvm::to_vector(getCTALayout().getCTAsPerCGA()); -} -SmallVector PaddedSharedEncodingAttr::getCTAOrder() const { - return llvm::to_vector(getCTALayout().getCTAOrder()); -} -SmallVector PaddedSharedEncodingAttr::getCTASplitNum() const { - return llvm::to_vector(getCTALayout().getCTASplitNum()); -} - -int32_t AMDRotatingSharedEncodingAttr::getAlignment() const { return 16; } - -SmallVector AMDRotatingSharedEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector AMDRotatingSharedEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector AMDRotatingSharedEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - -SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { - return ::getCTAsPerCGA(getParent()); -} -SmallVector DotOperandEncodingAttr::getCTAOrder() const { - return ::getCTAOrder(getParent()); -} -SmallVector DotOperandEncodingAttr::getCTASplitNum() const { - SmallVector res = ::getCTASplitNum(getParent()); - auto rank = res.size(); - assert(rank == 2 || rank == 3 && "Invalid dotLayout"); - - // Do not split CTA in K dimension - auto kDim = getOpIdx() == 0 ? 
rank - 1 : rank - 2; - res[kDim] = 1; - return res; -} - -LogicalResult DotOperandEncodingAttr::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - unsigned opIdx, Attribute parent, unsigned kWidth) { - if (opIdx != 0 && opIdx != 1) { - return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: " - << opIdx; - } - if (!parent) { - return emitError() << "ttg.dot_op parent parameter cannot be null"; - } - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() << "ttg.dot_op kWidth parameter can only be " - "non-zero for Ampere or Hopper MMA parent"; - if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) - return emitError() << "ttg.dot_op kWidth parameter is mandatory for " - "Ampere or Hopper MMA parent"; - if (opIdx != 0 && parentAttr.isHopper()) - return emitError() - << "ttg.dot_op opIdx parameter must be 0 for " - "Hopper MMA parent, since Hopper WGMMA only allows first " - "operand to be in registers"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 16 && parentAttr.getVersion() == 1 || - kWidth != 8 && kWidth != 16 && parentAttr.getVersion() == 2) - return emitError() << "ttg.dot_op kWidth parameter must be 16 for " - "gfx11 and 4/8/16 for gfx12 (including packed " - "cases for `scaled_dot`)"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth == 0) - return emitError() << "ttg.dot_op kWidth parameter is mandatory for " - "MFMA parent"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - int opsPerChannel = parentAttr.getOpsPerChannel(); - if (opIdx == 0) { - // operand A - if (opsPerChannel == 1) { - if (kWidth != opsPerChannel) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; - } else { - if (kWidth != opsPerChannel / 2) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; - } - - unsigned repeatCount = parentAttr.getRepeatCount(); - unsigned systolicDepth = parentAttr.getSystolicDepth(); - unsigned threadsPerWarp = parentAttr.getThreadsPerWarp(); - // OpsPerChannel: 4 is for i8 type. 2 is for f16/bf16 type. 1 is for - // float32 type. 2 i8 elements are packed into i16. The number of packed - // elements per row for A operand is: 8, 16, 16. - unsigned numPackedElemPerRowForA = - opsPerChannel == 1 ? systolicDepth : systolicDepth * 2; - if (repeatCount * numPackedElemPerRowForA < threadsPerWarp) - return emitError() - << "The DPAS encoding implies an invalid layout for A " - "operand. The non-uniform matrix A could not be " - "referred in kernel with threadsPerWarp: " - << threadsPerWarp - << ". numPackedElemPerRowForA:" << numPackedElemPerRowForA - << ". 
RC:" << repeatCount << ", systolicDepth:" << systolicDepth - << ", opsPerChan:" << opsPerChannel; - } else { - // operand B - if (kWidth != parentAttr.getOpsPerChannel()) - return emitError() << "ttg.dot_op kWidth parameter must match the " - "parent's opsPerChannel"; - } - - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0) - return emitError() << "ttg.dot_op kWidth parameter is not supported " - "when the parent is a warp layout"; - return success(); - } - - if (auto parentAttr = mlir::dyn_cast(parent)) { - if (kWidth != 0) - return emitError() << "ttg.dot_op kWidth parameter is not supported " - "when the parent is a blocked layout"; - return success(); - } - - return emitError() << "ttg.dot_op unexpected parent layout: " << parent; -} - //===----------------------------------------------------------------------===// // Blocked Encoding //===----------------------------------------------------------------------===// @@ -956,6 +730,17 @@ LinearEncodingAttr::verify(function_ref emitError, return success(); } +// If we only had BlockedEncodingAttr, we could simply return ArrayRefs here. +// But we need to have a consistent interface with e.g. SliceEncodingAttr, which +// computes some of these fields. +SmallVector BlockedEncodingAttr::getRepOrder() const { + return SmallVector(getOrder()); +} + +//===----------------------------------------------------------------------===// +// Linear Encoding +//===----------------------------------------------------------------------===// + void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { // We don't use the default implementation as it's a bit too verbose // This prints in the following format that is shape agnostic, in the sense @@ -1059,9 +844,9 @@ Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { std::move(linearLayout)); } -SmallVector basesPerDimImpl(const LinearLayout::BasesT &namedBases, - StringAttr dimName, size_t rank, - bool skipBroadcast = true) { +static SmallVector +basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName, + size_t rank, bool skipBroadcast = true) { const auto &bases = namedBases.find(dimName)->second; if (bases.empty()) { @@ -1133,6 +918,7 @@ SmallVector LinearEncodingAttr::getRepOrder() const { // the same shape as the tensor that uses it return getOrder(); } + SmallVector LinearEncodingAttr::getCTAsPerCGA() const { // CTAs are split into an identity part (SplitNum) and a broadcast part return basesPerDim(StringAttr::get(getContext(), "block"), @@ -1156,6 +942,7 @@ SmallVector LinearEncodingAttr::getThreadsPerWarp() const { SmallVector LinearEncodingAttr::getThreadOrder() const { return orderPerDim(StringAttr::get(getContext(), "lane"), getOrder()); } + SmallVector LinearEncodingAttr::getSizePerThread() const { auto rank = getOrder().size(); auto ll = getLinearLayout(); @@ -1606,6 +1393,81 @@ void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const { << "parent = " << getParent() << "}>"; } +LogicalResult +SliceEncodingAttr::verify(function_ref emitError, + unsigned dim, DistributedEncodingTrait parent) { + if (mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) + return success(); + unsigned rank = cast(parent).getRank(); + if (rank <= 1) + return emitError() << "parent layout must have at least rank >= 2"; + if (dim >= rank) { + return emitError() << "slice dim=" << dim + << " must be less than the parent rank=" << rank; + } + return success(); +} + +SmallVector SliceEncodingAttr::getRepOrder() const { + 
auto parentRepOrder = getParent().getRepOrder(); + return eraseOrder(parentRepOrder, getDim()); +} + +SmallVector SliceEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + res.erase(res.begin() + getDim()); + return res; +} + +SmallVector SliceEncodingAttr::getCTAOrder() const { + auto parentCTAOrder = ::getCTAOrder(getParent()); + return eraseOrder(parentCTAOrder, getDim()); +} + +SmallVector SliceEncodingAttr::getCTAsPerCGA() const { + auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent()); + if (parentCTAsPerCGA[getDim()] == 1) { + parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim()); + return parentCTAsPerCGA; + } + /* For getCTAsPerCGA of a slice layout, we have two choices: + * (1) Return CTAsPerCGA of its parent. This is not a perfect solution + * because the rank of the returned CTAsPerCGA does not match the rank of + * tensorShape. + * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a + * perfect solution because the product of the returned CTAsPerCGA might not + * match numCTAs. + * To avoid introducing inconsistencies to the shape and + * layout system, the usage of directly getting CTAsPerCGA of a slice layout + * in which the sliced dim is not 1 is banned. You should always consider + * slice layout as a special case and use getCTAsPerCGA(layout.getParent()) + * in the branch where layout is an instance of SliceEncodingAttr. This is + * inconvenient but safe. + */ + llvm::report_fatal_error( + "getCTAsPerCGA for SliceEncodingAttr is not well-defined"); +} + +template +SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const { + size_t rank = shape.size(); + unsigned dim = getDim(); + SmallVector retShape(rank + 1); + for (unsigned d = 0; d < rank + 1; ++d) { + if (d < dim) + retShape[d] = shape[d]; + else if (d == dim) + retShape[d] = 1; + else + retShape[d] = shape[d - 1]; + } + return retShape; +} +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; +template SmallVector +SliceEncodingAttr::paddedShape(ArrayRef shape) const; + //===----------------------------------------------------------------------===// // Helper shared encoding functions //===----------------------------------------------------------------------===// @@ -1931,16 +1793,6 @@ int32_t NVMMASharedEncodingAttr::getAlignment() const { return 128 * getMaxPhase(); } -SmallVector NVMMASharedEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector NVMMASharedEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector NVMMASharedEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - //===----------------------------------------------------------------------===// // AMDRotatingShared encoding //===----------------------------------------------------------------------===// @@ -1965,16 +1817,6 @@ void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const { //===----------------------------------------------------------------------===// // TODO: there is a lot of common code with MmaEncoding here -SmallVector AMDMfmaEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector AMDMfmaEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector AMDMfmaEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - bool 
AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const { return !llvm::any_of(getTilesPerWarp(), [](int x) { return x != 1; }); } @@ -2099,16 +1941,6 @@ AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const { return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true); } -SmallVector AMDWmmaEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector AMDWmmaEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector AMDWmmaEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} - SmallVector AMDWmmaEncodingAttr::getElemsPerInstrForOperands(int kDim, int opIdx) const { if (opIdx == 0) @@ -2202,15 +2034,6 @@ bool NvidiaMmaEncodingAttr::isHopper() const { return getVersionMajor() == 3; } SmallVector NvidiaMmaEncodingAttr::getRepOrder() const { return getMatrixOrder(getRank(), /*rowMajor*/ true); } -SmallVector NvidiaMmaEncodingAttr::getCTAsPerCGA() const { - return SmallVector(getCTALayout().getCTAsPerCGA()); -} -SmallVector NvidiaMmaEncodingAttr::getCTAOrder() const { - return SmallVector(getCTALayout().getCTAOrder()); -} -SmallVector NvidiaMmaEncodingAttr::getCTASplitNum() const { - return SmallVector(getCTALayout().getCTASplitNum()); -} SmallVector NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const { @@ -2262,6 +2085,7 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef shape, int bitwidth, //===----------------------------------------------------------------------===// // DotOperand Encoding //===----------------------------------------------------------------------===// + SmallVector DotOperandEncodingAttr::getRepOrder() const { if (auto mma = mlir::dyn_cast(getParent())) { return mma.getRepOrderForOperand(getOpIdx()); @@ -2273,6 +2097,125 @@ SmallVector DotOperandEncodingAttr::getRepOrder() const { return {}; } +SmallVector DotOperandEncodingAttr::getCTAsPerCGA() const { + return ::getCTAsPerCGA(getParent()); +} + +SmallVector DotOperandEncodingAttr::getCTAOrder() const { + return ::getCTAOrder(getParent()); +} + +SmallVector DotOperandEncodingAttr::getCTASplitNum() const { + SmallVector res = ::getCTASplitNum(getParent()); + auto rank = res.size(); + assert(rank == 2 || rank == 3 && "Invalid dotLayout"); + + // Do not split CTA in K dimension + auto kDim = getOpIdx() == 0 ? 
rank - 1 : rank - 2; + res[kDim] = 1; + return res; +} + +LogicalResult DotOperandEncodingAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + unsigned opIdx, Attribute parent, unsigned kWidth) { + if (opIdx != 0 && opIdx != 1) { + return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: " + << opIdx; + } + if (!parent) { + return emitError() << "ttg.dot_op parent parameter cannot be null"; + } + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter can only be " + "non-zero for Ampere or Hopper MMA parent"; + if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper())) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "Ampere or Hopper MMA parent"; + if (opIdx != 0 && parentAttr.isHopper()) + return emitError() + << "ttg.dot_op opIdx parameter must be 0 for " + "Hopper MMA parent, since Hopper WGMMA only allows first " + "operand to be in registers"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 16 && parentAttr.getVersion() == 1 || + kWidth != 4 && kWidth != 8 && kWidth != 16 && + parentAttr.getVersion() == 2) + return emitError() << "ttg.dot_op kWidth parameter must be 16 for " + "gfx11 and 4/8/16 for gfx12 (including packed " + "cases for `scaled_dot`)"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth == 0) + return emitError() << "ttg.dot_op kWidth parameter is mandatory for " + "MFMA parent"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + int opsPerChannel = parentAttr.getOpsPerChannel(); + if (opIdx == 0) { + // operand A + if (opsPerChannel == 1) { + if (kWidth != opsPerChannel) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } else { + if (kWidth != opsPerChannel / 2) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } + + unsigned repeatCount = parentAttr.getRepeatCount(); + unsigned systolicDepth = parentAttr.getSystolicDepth(); + unsigned threadsPerWarp = parentAttr.getThreadsPerWarp(); + // OpsPerChannel: 4 is for i8 type. 2 is for f16/bf16 type. 1 is for + // float32 type. 2 i8 elements are packed into i16. The number of packed + // elements per row for A operand is: 8, 16, 16. + unsigned numPackedElemPerRowForA = + opsPerChannel == 1 ? systolicDepth : systolicDepth * 2; + if (repeatCount * numPackedElemPerRowForA < threadsPerWarp) + return emitError() + << "The DPAS encoding implies an invalid layout for A " + "operand. The non-uniform matrix A could not be " + "referred in kernel with threadsPerWarp: " + << threadsPerWarp + << ". numPackedElemPerRowForA:" << numPackedElemPerRowForA + << ". 
RC:" << repeatCount << ", systolicDepth:" << systolicDepth + << ", opsPerChan:" << opsPerChannel; + } else { + // operand B + if (kWidth != parentAttr.getOpsPerChannel()) + return emitError() << "ttg.dot_op kWidth parameter must match the " + "parent's opsPerChannel"; + } + + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a warp layout"; + return success(); + } + + if (auto parentAttr = mlir::dyn_cast(parent)) { + if (kWidth != 0) + return emitError() << "ttg.dot_op kWidth parameter is not supported " + "when the parent is a blocked layout"; + return success(); + } + + return emitError() << "ttg.dot_op unexpected parent layout: " << parent; +} + //===----------------------------------------------------------------------===// // ASM Interface (i.e.: alias) //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index f0c23c931b..68ebf56547 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -630,6 +630,17 @@ OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) { return loadSrc; } +int32_t LocalAllocOp::getAlignmentOrDefault() { + auto align = getAlignment(); + if (align) { + return *align; + } + + auto ty = getType(); + auto enc = dyn_cast(ty.getEncoding()); + return enc ? enc.getAlignment() : 16; +} + LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy, ShapedType dstTy) { if (srcTy.getElementType() != dstTy.getElementType()) { @@ -660,10 +671,27 @@ LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy) { return verifyMemoryOpTypes(op, cast(src.getType()), dstTy); } +static LogicalResult verifySharedMemoryRank(Operation *op, + RankedTensorType type, + MemDescType memdesc, + StringRef regName) { + auto enc = dyn_cast(memdesc.getEncoding()); + if (!enc) + return op->emitOpError("expected memdesc to have a shared memory encoding"); + if (type.getRank() != enc.getRank()) { + return op->emitOpError(regName) + << " has rank " << type.getRank() + << " but memdesc encoding has rank " << enc.getRank(); + } + return success(); +} + LogicalResult LocalAllocOp::verify() { if (!isa(getType().getMemorySpace())) return emitOpError("should create a buffer of shared memory"); - + if (getSrc() && failed(verifySharedMemoryRank(*this, getSrc().getType(), + getType(), "source"))) + return failure(); return verifyAllocOp(*this, getSrc(), getType()); } @@ -671,11 +699,17 @@ LogicalResult LocalAllocOp::verify() { LogicalResult LocalStoreOp::verify() { if (!getDst().getType().getMutableMemory()) return emitOpError("Cannot store into immutable memory"); + if (failed(verifySharedMemoryRank(*this, getSrc().getType(), + getDst().getType(), "source"))) + return failure(); return verifyMemoryOpTypes(*this, getSrc().getType(), getDst().getType()); } // LocalLoadOp LogicalResult LocalLoadOp::verify() { + if (failed(verifySharedMemoryRank(*this, getType(), getSrc().getType(), + "result"))) + return failure(); return verifyMemoryOpTypes(*this, getSrc().getType(), getType()); } @@ -807,19 +841,6 @@ LogicalResult MemDescSubsliceOp::verify() { return success(); } -// -- LocalAllocOp -- - -int32_t LocalAllocOp::getAlignmentOrDefault() { - auto align = getAlignment(); - if (align) { - return *align; - } - - auto ty = getType(); - auto enc = dyn_cast(ty.getEncoding()); - return enc ? 
enc.getAlignment() : 16; -} - // -- WarpSpecializeOp -- RegionRange WarpSpecializeOp::getPartitionRegions() { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index 06bae6cffc..793fa09acc 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -23,6 +23,7 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Tools/Sys/GetEnv.hpp" #include @@ -99,9 +100,9 @@ TMemAllocation getTmemAllocSizes(MemDescType memDescType) { return TMemAllocation(numColumn, numRows); } -Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N, - RankedTensorType oldType, - unsigned numWarps) { +DistributedEncodingTrait getTmemLoadStoreLayout32x32b(unsigned M, unsigned N, + RankedTensorType oldType, + unsigned numWarps) { assert(numWarps == 4 || numWarps == 8); auto shape = getShapePerCTA(oldType); assert(shape.size() == 2); @@ -150,8 +151,9 @@ Attribute getTmemLoadStoreLayout32x32b(unsigned M, unsigned N, warpsPerCTA, order, ctaLayout); } -Attribute getTmemCompatibleLayout(unsigned M, unsigned N, - RankedTensorType oldType, unsigned numWarps) { +DistributedEncodingTrait getTmemCompatibleLayout(unsigned M, unsigned N, + RankedTensorType oldType, + unsigned numWarps) { bool prefer16x256 = triton::tools::getBoolEnv("TRITON_PREFER_TMEM_16x256_LAYOUT"); if (prefer16x256) { @@ -164,63 +166,85 @@ Attribute getTmemCompatibleLayout(unsigned M, unsigned N, return getTmemLoadStoreLayout32x32b(M, N, oldType, numWarps); } -bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType, - MemDescType memType, int numWarps) { +DistributedEncodingTrait +getTmemLoadLayoutSplitLongM(RankedTensorType tensorType, MemDescType memType, + int numWarps) { auto tmemEnc = dyn_cast( memType.getEncoding()); if (!tmemEnc || tmemEnc.getBlockM() != 128) - return false; + return {}; int M = tmemEnc.getBlockM(); int N = tmemEnc.getBlockN(); auto llEncoding = dyn_cast(tensorType.getEncoding()); if (!llEncoding) - return false; + return {}; auto CTALayout = getCTALayout(tensorType.getEncoding()); auto shapePerCTA = mlir::triton::gpu::getShapePerCTA(tensorType); if (numWarps != 8) - return false; + return {}; LinearLayout llLayout = - getTmemLoadLayoutSplitLongM(M, N, tensorType, numWarps); - return llEncoding.getLinearLayout() == llLayout; + gpu::getTmemLoadLayoutSplitLongM(M, N, tensorType, numWarps); + return LinearEncodingAttr::get(tensorType.getContext(), llLayout); } -// Verify if the distributed layout can be mapped onto tensor memory. 
-bool isDistributedLayoutTMemCompatible(Operation *op, - RankedTensorType tensorType, - MemDescType memType) { +bool isDistributedLayoutSplitMTmemLoadStore(RankedTensorType tensorType, + MemDescType memType, int numWarps) { + auto layout = getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps); + if (!layout) + return false; + return areLayoutsEquivalent( + tensorType.getShape(), cast(layout), + cast(tensorType.getEncoding())); +} + +SmallVector +getTmemCompatibleLayouts(Operation *op, RankedTensorType tensorType, + MemDescType memType) { int numWarps = lookupNumWarps(op); assert(numWarps % 4 == 0); + if (isa( memType.getEncoding())) { - return tensorType.getEncoding() == - triton::gpu::LinearEncodingAttr::get( - tensorType.getContext(), - getScaleTMEMStoreLinearLayout(tensorType, numWarps)); + return {triton::gpu::LinearEncodingAttr::get( + tensorType.getContext(), + getScaleTMEMStoreLinearLayout(tensorType, numWarps))}; } + + SmallVector layouts; auto attr = cast(memType.getEncoding()); int blockM = attr.getBlockM(); int blockN = attr.getBlockN(); - if (isDistributedLayoutSplitMTmemLoadStore(tensorType, memType, numWarps)) - return true; - auto ll16x256 = - getTmemLoadStoreLayout16x256(blockM, blockN, tensorType, numWarps); - auto enc = - cast(tensorType.getEncoding()); - if (ll16x256.has_value() && - areLayoutsEquivalent( - tensorType.getShape(), - LinearEncodingAttr::get(tensorType.getContext(), ll16x256.value()), - enc)) - return true; - auto layout = cast( - nvidia_gpu::getTmemLoadStoreLayout32x32b(blockM, blockN, tensorType, - numWarps)); + if (DistributedEncodingTrait splitMLayout = + getTmemLoadLayoutSplitLongM(tensorType, memType, numWarps)) + layouts.push_back(splitMLayout); + + if (auto ll16x256 = + getTmemLoadStoreLayout16x256(blockM, blockN, tensorType, numWarps)) { + layouts.push_back( + LinearEncodingAttr::get(tensorType.getContext(), ll16x256.value())); + } + + layouts.push_back(nvidia_gpu::getTmemLoadStoreLayout32x32b( + blockM, blockN, tensorType, numWarps)); + // TODO: Add support for more layout compatible with tmem load/store. There // will only be a discret set of layout possible due to the limiations of // tmem_load/store. - return areLayoutsEquivalent(tensorType.getShape(), layout, enc); + return layouts; +} + +// Verify if the distributed layout can be mapped onto tensor memory. 
+bool isDistributedLayoutTMemCompatible(Operation *op, + RankedTensorType tensorType, + gpu::MemDescType memType) { + SmallVector layouts = + getTmemCompatibleLayouts(op, tensorType, memType); + auto enc = cast(tensorType.getEncoding()); + return llvm::any_of(layouts, [&](DistributedEncodingTrait layout) { + return areLayoutsEquivalent(tensorType.getShape(), layout, enc); + }); } LogicalResult impl::verifyMMAv5Op(Operation *op) { diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 20c6edbe49..0f68507f19 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -22,7 +22,10 @@ */ #include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/Support/LLVM.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOpInterfaces.cpp.inc" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" @@ -495,6 +498,35 @@ void TCGen5MMAScaledOp::build(OpBuilder &builder, OperationState &state, } // -- TMEMStoreOp -- +static LogicalResult verifyTMEMOperand(Operation *op, RankedTensorType type, + MemDescType memdesc, StringRef regName) { + if (type.getRank() != 2) + return op->emitOpError(regName) << " must be a 2D tensor"; + if (type.getEncoding()) { + auto enc = dyn_cast(type.getEncoding()); + if (!enc) { + return op->emitOpError(regName) + << " does not have an distributed encoding"; + } + SmallVector layouts = + getTmemCompatibleLayouts(op, type, memdesc); + if (layouts.empty()) { + return op->emitOpError(regName) + << " does not have any TMEM compatible layouts"; + } + if (llvm::none_of(layouts, [&](DistributedEncodingTrait layout) { + return areLayoutsEquivalent(type.getShape(), layout, enc); + })) { + InFlightDiagnostic diag = op->emitOpError(regName) + << " layout is not TMEM compatible"; + for (Attribute layout : layouts) + diag.attachNote() << "potential TMEM layout: " << layout; + return diag; + } + } + return success(); +} + LogicalResult TMEMStoreOp::verify() { if (!isa( getDst().getType().getMemorySpace())) @@ -505,6 +537,9 @@ LogicalResult TMEMStoreOp::verify() { if (!getDst().getType().getMutableMemory()) { return emitOpError("Cannot store into an immutable alloc"); } + if (failed(verifyTMEMOperand(*this, getSrc().getType(), getDst().getType(), + "source"))) + return failure(); return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), getDst().getType()); } @@ -517,6 +552,8 @@ LogicalResult TMEMLoadOp::verify() { if (!isa( getSrc().getType().getEncoding())) return emitOpError("should use tensor memory encoding."); + if (failed(verifyTMEMOperand(*this, getType(), getSrc().getType(), "result"))) + return failure(); return triton::gpu::verifyMemoryOpTypes(*this, getSrc().getType(), getType()); } @@ -527,6 +564,9 @@ LogicalResult TMEMAllocOp::verify() { if (!isa( getType().getEncoding())) return emitOpError("should use tensor memory encoding"); + if (getSrc() && + failed(verifyTMEMOperand(*this, getSrc().getType(), getType(), "source"))) + return failure(); return triton::gpu::verifyAllocOp(*this, getSrc(), getType()); } diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 031d9fcb06..6879402fc7 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -269,7 +269,7 @@ def shared_memory_index_kernel(XBLOCK: ttgl.constexpr, layout: 
ttgl.constexpr, s @pytest.mark.parametrize("target", ALL_TARGETS) def test_shared_memory_index(target): layout = ttgl.BlockedLayout(size_per_thread=[1], threads_per_warp=[32], warps_per_cta=[4], order=[0]) - smem_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2) + smem_layout = ttgl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[0]) mod = run_parser( shared_memory_index_kernel, *make_args(256, layout, smem_layout, num_warps=4), @@ -278,7 +278,7 @@ def test_shared_memory_index(target): expecttest.assert_expected_inline( anonymize_ir(mod.str_nodebug()), """\ #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { tt.func public @shared_memory_index_kernel() attributes {noinline = false} { diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 790cb9f33a..617c1252ff 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -729,7 +729,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // ----- -#layout = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#layout = #ttg.swizzled_shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0]}> #smem = #ttg.shared_memory // CHECK-LABEL: @warp_specialize_isolated_regions diff --git a/test/TritonGPU/canonicalize.mlir b/test/TritonGPU/canonicalize.mlir index ebcc412fda..10fba51712 100644 --- a/test/TritonGPU/canonicalize.mlir +++ b/test/TritonGPU/canonicalize.mlir @@ -108,9 +108,9 @@ tt.func @test_canonicalize_convert_histogram(%arg0: tensor<256xi32, #blocked1>, // CHECK-NEXT: gpu.barrier // CHECK-NEXT: tt.return %[[V]] -#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> -#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> #smem = #ttg.shared_memory module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32, "ttg.compute-capability" = 80} { tt.func @test_canonicalize_convert_local_load(%arg0: !ttg.async.token) -> tensor<256xi32, #blocked1> { diff --git a/test/TritonGPU/inline.mlir b/test/TritonGPU/inline.mlir index 5ad5dfb2de..e623b13ee0 100644 --- a/test/TritonGPU/inline.mlir +++ b/test/TritonGPU/inline.mlir @@ -1,7 +1,7 @@ // RUN: triton-opt %s -inline | FileCheck %s #smem = #ttg.shared_memory -#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> // CHECK-LABEL: @inline_in_warp_specialize tt.func public @inline_in_warp_specialize(%arg0: !ttg.memdesc<1xi32, #shared, #smem, mutable>) { diff --git 
a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index 4f8ce0ec2f..a779f95145 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -431,3 +431,13 @@ tt.func @memdesc_reinterpret(%arg0: !ttg.memdesc<1xi64, #shared, #ttg.shared_mem %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<1xi64, #shared, #ttg.shared_memory> -> !ttg.memdesc<1xi32, #shared, #ttng.tensor_memory> tt.return } + +// ----- + +// expected-error @below {{parent layout must have at least rank >= 2}} +#slice = #ttg.slice<{dim = 0, parent = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>}> + +// ----- + +// expected-error @below {{slice dim=2 must be less than the parent rank=2}} +#slice = #ttg.slice<{dim = 2, parent = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}>}> diff --git a/test/TritonGPU/load-mma-specialization.mlir b/test/TritonGPU/load-mma-specialization.mlir index bda1b3324c..16d97d08a8 100644 --- a/test/TritonGPU/load-mma-specialization.mlir +++ b/test/TritonGPU/load-mma-specialization.mlir @@ -10,6 +10,7 @@ #shared_trans = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}> #nvmma_smem = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}> #smem = #ttg.shared_memory +#scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}> // CHECK-DAG: [[ACC_TMEM:#.*]] = #ttng.tensor_memory_encoding #acc_tmem = #ttng.tensor_memory_encoding @@ -775,8 +776,8 @@ tt.func @matmul_scaled_rhs_scales_tma( %BLOCK_K = arith.constant 64 : i32 %zero = arith.constant dense<0.0> : tensor<128x128xf32, #acc_layout> - %a_scales_const = arith.constant dense<127> : tensor<128x8xi8, #oper_layout> - %a_scales_tmem = ttng.tmem_alloc %a_scales_const : (tensor<128x8xi8, #oper_layout>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory> + %a_scales_const = arith.constant dense<127> : tensor<128x8xi8, #scales> + %a_scales_tmem = ttng.tmem_alloc %a_scales_const : (tensor<128x8xi8, #scales>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory> // CHECK-COUNT-2: ttg.local_alloc : () -> !ttg.memdesc<3x1xi64, // CHECK-NOT: ttg.local_alloc : () -> !ttg.memdesc<3x1xi64, @@ -790,7 +791,7 @@ tt.func @matmul_scaled_rhs_scales_tma( // CHECK-COUNT-3: async_tma_copy_global_to_local {{.*}} {ttg.partition = 2 : i32} %a_reg = tt.descriptor_load %a_desc[%off_m, %off_k] : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #oper_layout> %b_reg = tt.descriptor_load %b_desc[%off_n, %off_k] : !tt.tensordesc> -> tensor<128x64xf8E4M3FN, #oper_layout> - %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc>> -> tensor<128x8xi8, #oper_layout> + %b_scales_reg = tt.descriptor_load %b_scale_desc[%off_m, %c0_i32] : !tt.tensordesc>> -> tensor<128x8xi8, #scales> %a_sh = ttg.local_alloc %a_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem> %b_sh_raw = ttg.local_alloc %b_reg : (tensor<128x64xf8E4M3FN, #oper_layout>) -> !ttg.memdesc<128x64xf8E4M3FN, #nvmma_smem, #smem> @@ -799,7 +800,7 @@ tt.func @matmul_scaled_rhs_scales_tma( // CHECK-NEXT: wait_barrier {{.*}} {ttg.partition = 1 : i32} - %b_scales_tmem = ttng.tmem_alloc %b_scales_reg : (tensor<128x8xi8, #oper_layout>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, 
#ttng.tensor_memory>
+    %b_scales_tmem = ttng.tmem_alloc %b_scales_reg : (tensor<128x8xi8, #scales>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
     %c_tmem, %c_tok = ttng.tmem_alloc %acc : (tensor<128x128xf32, #acc_layout>) -> (!ttg.memdesc<128x128xf32, #acc_tmem, #ttng.tensor_memory, mutable>, !ttg.async.token)
@@ -1128,9 +1129,10 @@ tt.func @load_scale_mma_user(
     %scales_reg = ttg.local_load %scales_shared : !ttg.memdesc<8x128xi8, #shared, #smem> -> tensor<8x128xi8, #oper_layout>
     // CHECK-NEXT: [[SCALES_TRANS:%.*]] = tt.trans [[SCALES_REG]] {{.*}}partition = 0
     %scales_T = tt.trans %scales_reg {order = array<i32: 1, 0>} : tensor<8x128xi8, #oper_layout> -> tensor<128x8xi8, #oper_layout_trans>
+    %scales_cvt = ttg.convert_layout %scales_T : tensor<128x8xi8, #oper_layout_trans> -> tensor<128x8xi8, #scales>
     // CHECK-NEXT: wait_barrier [[SCALES_TMEM_BAR:%.*]], %arg{{[0-9]+}} {{.*}}partition = 0
     // CHECK-NEXT: tmem_store [[SCALES_TRANS]], [[SCALES_TMEM:%.*]], %true {{.*}}partition = 0
-    %scales_tmem = ttng.tmem_alloc %scales_T : (tensor<128x8xi8, #oper_layout_trans>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
+    %scales_tmem = ttng.tmem_alloc %scales_cvt : (tensor<128x8xi8, #scales>) -> !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory>
     // CHECK-NEXT: arrive_barrier [[SCALES_READY_BAR:%.*]], 1 {{.*}}partition = 0
     // CHECK: wait_barrier [[USER_DONE:%.*]], %arg{{[0-9]+}}, %true {{.*}}partition = 1
diff --git a/test/TritonGPU/optimize-partition-warps.mlir b/test/TritonGPU/optimize-partition-warps.mlir
index e97e6cf071..2659f497d1 100644
--- a/test/TritonGPU/optimize-partition-warps.mlir
+++ b/test/TritonGPU/optimize-partition-warps.mlir
@@ -37,24 +37,24 @@ tt.func @no_tensor_computations(%arg0: i32) {
 
 // CHECK-LABEL: @small_tensor_computation
 tt.func @small_tensor_computation(%arg0: i32) {
-  %alloc = ttg.local_alloc : () -> !ttg.memdesc<128xi32, #shared, #smem, mutable>
+  %alloc = ttg.local_alloc : () -> !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>
   ttg.warp_specialize(%arg0, %alloc)
   default {
     ttg.warp_yield
   }
   // CHECK: partition0({{.*}}) num_warps(1)
-  partition0(%arg1: i32, %arg2: !ttg.memdesc<128xi32, #shared, #smem, mutable>) num_warps(8) {
+  partition0(%arg1: i32, %arg2: !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>) num_warps(8) {
    %0 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked8>
-    ttg.local_store %0, %arg2 : tensor<128xi32, #blocked8> -> !ttg.memdesc<128xi32, #shared, #smem, mutable>
+    ttg.local_store %0, %arg2 : tensor<128xi32, #blocked8> -> !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>
    ttg.warp_return
  }
  // CHECK: partition1({{.*}}) num_warps(1)
-  partition1(%arg1: i32, %arg2: !ttg.memdesc<128xi32, #shared, #smem, mutable>) num_warps(4) {
+  partition1(%arg1: i32, %arg2: !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>) num_warps(4) {
    %0 = tt.splat %arg1 : i32 -> tensor<128xi32, #blocked4>
    %1 = ttg.convert_layout %0 : tensor<128xi32, #blocked4> -> tensor<128xi32, #blocked4_broadcast>
-    ttg.local_store %1, %arg2 : tensor<128xi32, #blocked4_broadcast> -> !ttg.memdesc<128xi32, #shared, #smem, mutable>
+    ttg.local_store %1, %arg2 : tensor<128xi32, #blocked4_broadcast> -> !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>
    ttg.warp_return
-  } : (i32, !ttg.memdesc<128xi32, #shared, #smem, mutable>) -> ()
+  } : (i32, !ttg.memdesc<128xi32, #shared_1d, #smem, mutable>) -> ()
  tt.return
 }
 
diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir
index a792a63b40..b2a205f21c 100644
--- a/test/TritonGPU/pipeline-assign-latencies.mlir
+++ b/test/TritonGPU/pipeline-assign-latencies.mlir
@@ -841,9 +841,9 @@ tt.func @tc_gen5_mma_scaled(%lb : index, %ub : index, %step : index,
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}>
+#scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
-#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>
+#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #tmem = #ttng.tensor_memory_encoding
 #tmem_scales = #ttng.tensor_memory_scales_encoding<>
 #smem = #ttg.shared_memory
@@ -853,8 +853,8 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
 tt.func @tc_gen5_mma_scaled_tmem_scales(%lb : index, %ub : index, %step : index,
                   %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
                   %B_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
-                  %A_sc_ptr: tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
-                  %B_sc_ptr: tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
+                  %A_sc_ptr: tensor<128x8x!tt.ptr<i8>, #scales> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
+                  %B_sc_ptr: tensor<128x8x!tt.ptr<i8>, #scales> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
                   %acc_init : tensor<128x128xf32, #blocked1>) -> () {
   %true = arith.constant true
   %acc_tm = ttng.tmem_alloc %acc_init : (tensor<128x128xf32, #blocked1>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
@@ -864,17 +864,17 @@ tt.func @tc_gen5_mma_scaled_tmem_scales(%lb : index, %ub : index, %step : index,
     %B = tt.load %B_ptr : tensor<128x128x!tt.ptr<f16>, #blocked1>
     %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
 
-    %A_sc = tt.load %A_sc_ptr : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>
-    %A_sc_sh = ttg.local_alloc %A_sc : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
+    %A_sc = tt.load %A_sc_ptr : tensor<128x8x!tt.ptr<i8>, #scales>
+    %A_sc_sh = ttg.local_alloc %A_sc : (tensor<128x8xi8, #scales>) -> !ttg.memdesc<128x8xi8, #shared1, #smem>
 
-    %B_sc = tt.load %B_sc_ptr : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>
-    %B_sc_tm = ttng.tmem_alloc %B_sc : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #tmem_scales, #ttng.tensor_memory>
+    %B_sc = tt.load %B_sc_ptr : tensor<128x8x!tt.ptr<i8>, #scales>
+    %B_sc_tm = ttng.tmem_alloc %B_sc : (tensor<128x8xi8, #scales>) -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
 
     // CHECK: ttng.tc_gen5_mma_scaled {{.*}}
     // CHECK-NOT: tt.latency
     // CHECK-NOT: tt.self_latency
     ttng.tmem_store %acc_init, %acc_tm, %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
-    ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_tm, %true, %true lhs = e5m2 rhs = e5m2 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #tmem_scales, #ttng.tensor_memory>
+    ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_sc_sh, %B_sc_tm, %true, %true lhs = e5m2 rhs = e5m2 : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory>
     %acc_res = ttng.tmem_load %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
     "use"(%acc_res) : (tensor<128x128xf32, #blocked1>) -> ()
   }
diff --git a/test/TritonGPU/pipeline-schedule-loop.mlir b/test/TritonGPU/pipeline-schedule-loop.mlir
index d1dc51df30..33326dbc53 100644
--- a/test/TritonGPU/pipeline-schedule-loop.mlir
+++ b/test/TritonGPU/pipeline-schedule-loop.mlir
@@ -376,9 +376,9 @@ tt.func @tc_gen5_mma(%lb : index, %ub : index, %step : index,
     // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm[], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
     // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
+    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked1>
     // CHECK: "use"{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+    "use"(%c) : (tensor<128x128xf32, #blocked1>) -> ()
   }
   tt.return
 }
@@ -409,8 +409,8 @@ tt.func @tc_gen5_mma_if_user(%lb : index, %ub : index, %step : index,
     // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm[], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
     scf.if %cnd {
-      %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
-      "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+      %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked1>
+      "use"(%c) : (tensor<128x128xf32, #blocked1>) -> ()
     }
     // CHECK: scf.if
     // CHECK: tmem_load
@@ -449,9 +449,9 @@ tt.func @tc_gen5_mma_scaled(%lb : index, %ub : index, %step : index,
     // CHECK: ttng.tc_gen5_mma_scaled {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm[], %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>
     // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
+    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked1>
     // CHECK: "use"{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+    "use"(%c) : (tensor<128x128xf32, #blocked1>) -> ()
   }
   tt.return
 }
@@ -533,13 +533,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @two_dots
-  tt.func public @two_dots(%arg0: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr<f32>, #blocked1> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked1> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg4: i32) {
+  tt.func public @two_dots(%arg0: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr<f32>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg4: i32) {
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
@@ -555,21 +554,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
      // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
      %5 = ttg.local_alloc %4 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
      // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
-      %6 = tt.load %arg2 : tensor<128x128x!tt.ptr<f32>, #blocked1>
+      %6 = tt.load %arg2 : tensor<128x128x!tt.ptr<f32>, #blocked>
      // CHECK: ttng.tmem_store {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
-      %store_tok0 = ttng.tmem_store %6, %0[%tok0], %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %store_tok0 = ttng.tmem_store %6, %0[%tok0], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
      %mma_tok0 = ttng.tc_gen5_mma %3, %5, %0[%store_tok0], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
-      %7, %load_tok0 = ttng.tmem_load %0[%mma_tok0] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+      %7, %load_tok0 = ttng.tmem_load %0[%mma_tok0] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
      // CHECK: ttng.tmem_store {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
-      %store_tok1 = ttng.tmem_store %7, %1[%tok1], %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %store_tok1 = ttng.tmem_store %7, %1[%tok1], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32, tt.self_latency = 1 : i32}
      %mma_tok1 = ttng.tc_gen5_mma %3, %5, %1[%store_tok1], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32}
-      %8, %load_tok1 = ttng.tmem_load %1[%mma_tok1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+      %8, %load_tok1 = ttng.tmem_load %1[%mma_tok1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
      // CHECK: tt.store {{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32}
-      tt.store %arg3, %8 : tensor<128x128x!tt.ptr<f32>, #blocked1>
+      tt.store %arg3, %8 : tensor<128x128x!tt.ptr<f32>, #blocked>
      scf.yield %load_tok0, %load_tok1 : !ttg.async.token, !ttg.async.token
    }
    tt.return
@@ -578,31 +577,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
 // CHECK-LABEL: @tc_gen5_mma
 tt.func @tc_gen5_mma(%lb : index, %ub : index, %step : index,
-                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
-                  %B: tensor<128x128xf16, #blocked1>,
+                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
+                  %B: tensor<128x128xf16, #blocked>,
                   %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>) -> () {
   %true = arith.constant true
   scf.for %iv = %lb to %ub step %step : index {
     // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
-    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked1>
+    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked>
     // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
     // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
     // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
     %mma_tok = ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm[], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
     // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
+    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
     // CHECK: "use"{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+    "use"(%c) : (tensor<128x128xf32, #blocked>) -> ()
   }
   tt.return
 }
@@ -610,31 +608,30 @@ tt.func @tc_gen5_mma(%lb : index, %ub : index, %step : index,
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
 // CHECK-LABEL: @tc_gen5_mma_if_user
 tt.func @tc_gen5_mma_if_user(%lb : index, %ub : index, %step : index,
-                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
-                  %B: tensor<128x128xf16, #blocked1>,
+                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
+                  %B: tensor<128x128xf16, #blocked>,
                   %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>,
                   %cnd: i1) -> () {
   %true = arith.constant true
  scf.for %iv = %lb to %ub step %step : index {
    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
-    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked1>
+    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked>
    // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
    // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
    // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
    %mma_tok = ttng.tc_gen5_mma %A_sh, %B_sh, %acc_tm[], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
    scf.if %cnd {
-      %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
-      "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+      %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
+      "use"(%c) : (tensor<128x128xf32, #blocked>) -> ()
    }
    // CHECK: scf.if
    // CHECK: tmem_load
@@ -648,8 +645,7 @@ tt.func @tc_gen5_mma_if_user(%lb : index, %ub : index, %step : index,
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [4, 3, 2, 1, 0]}>
 #tmem = #ttng.tensor_memory_encoding
@@ -657,25 +653,25 @@ tt.func @tc_gen5_mma_if_user(%lb : index, %ub : index, %step : index,
 module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
 // CHECK-LABEL: @tc_gen5_mma_scaled
 tt.func @tc_gen5_mma_scaled(%lb : index, %ub : index, %step : index,
-                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked1> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
-                  %B: tensor<128x128xf16, #blocked1>,
+                  %A_ptr: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32},
+                  %B: tensor<128x128xf16, #blocked>,
                   %A_sc_sh: !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>,
                   %B_sc_sh: !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>,
                   %acc_tm : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>) -> () {
   %true = arith.constant true
  scf.for %iv = %lb to %ub step %step : index {
    // CHECK: tt.load {{.*}} {loop.cluster = 3 : i32, loop.stage = 0 : i32}
-    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked1>
+    %A = tt.load %A_ptr {tt.latency = 2 : i32} : tensor<128x128x!tt.ptr<f16>, #blocked>
    // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %A_sh = ttg.local_alloc %A : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
    // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32}
-    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked1>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
+    %B_sh = ttg.local_alloc %B : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>
    // CHECK: ttng.tc_gen5_mma_scaled {{.*}} {loop.cluster = 1 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
    %mma_tok = ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm[], %A_sc_sh, %B_sc_sh, %true, %true lhs = e5m2 rhs = e5m2 {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #ttg.shared_memory>
    // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #tmem>
+    %c, %load_tok = ttng.tmem_load %acc_tm[%mma_tok] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory> -> tensor<128x128xf32, #blocked>
    // CHECK: "use"{{.*}} {loop.cluster = 0 : i32, loop.stage = 3 : i32}
-    "use"(%c) : (tensor<128x128xf32, #tmem>) -> ()
+    "use"(%c) : (tensor<128x128xf32, #blocked>) -> ()
  }
  tt.return
 }
@@ -757,13 +753,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
 #shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #smem = #ttg.shared_memory
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: @two_dots
-  tt.func public @two_dots(%arg0: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr<f32>, #blocked1> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked1> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg4: i32) {
+  tt.func public @two_dots(%arg0: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg1: tensor<128x128x!tt.ptr<f16>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg2: tensor<128x128x!tt.ptr<f32>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg3: tensor<128x128x!tt.ptr<f32>, #blocked> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}, %arg4: i32) {
    %true = arith.constant true
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
@@ -779,21 +774,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
      // CHECK: ttg.local_alloc {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
      %5 = ttg.local_alloc %4 : (tensor<128x128xf16, #blocked>) -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable>
      // CHECK: tt.load {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
-      %6 = tt.load %arg2 : tensor<128x128x!tt.ptr<f32>, #blocked1>
+      %6 = tt.load %arg2 : tensor<128x128x!tt.ptr<f32>, #blocked>
      // CHECK: ttng.tmem_store {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32}
-      %store_tok0 = ttng.tmem_store %6, %0[%tok0], %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %store_tok0 = ttng.tmem_store %6, %0[%tok0], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 2 : i32, loop.stage = 2 : i32, tt.self_latency = 1 : i32}
      %mma_tok0 = ttng.tc_gen5_mma %3, %5, %0[%store_tok0], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
-      %7, %load_tok0 = ttng.tmem_load %0[%mma_tok0] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+      %7, %load_tok0 = ttng.tmem_load %0[%mma_tok0] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
      // CHECK: ttng.tmem_store {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32}
-      %store_tok1 = ttng.tmem_store %7, %1[%tok1], %true : tensor<128x128xf32, #blocked1> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
+      %store_tok1 = ttng.tmem_store %7, %1[%tok1], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tc_gen5_mma {{.*}} {loop.cluster = 1 : i32, loop.stage = 3 : i32, tt.self_latency = 1 : i32}
      %mma_tok1 = ttng.tc_gen5_mma %3, %5, %1[%store_tok1], %true, %true {tt.latency = 1 : i32, tt.self_latency = 1 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf16, #shared, #smem, mutable>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
      // CHECK: ttng.tmem_load {{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32}
-      %8, %load_tok1 = ttng.tmem_load %1[%mma_tok1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked1>
+      %8, %load_tok1 = ttng.tmem_load %1[%mma_tok1] : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>
      // CHECK: tt.store {{.*}} {loop.cluster = 0 : i32, loop.stage = 4 : i32}
-      tt.store %arg3, %8 : tensor<128x128x!tt.ptr<f32>, #blocked1>
+      tt.store %arg3, %8 : tensor<128x128x!tt.ptr<f32>, #blocked>
      scf.yield %load_tok0, %load_tok1 : !ttg.async.token, !ttg.async.token
    }
    tt.return
diff --git a/test/TritonNvidiaGPU/canonicalize.mlir b/test/TritonNvidiaGPU/canonicalize.mlir
index 3479c3851e..825858b3e4 100644
--- a/test/TritonNvidiaGPU/canonicalize.mlir
+++ b/test/TritonNvidiaGPU/canonicalize.mlir
@@ -1,6 +1,6 @@
 // RUN: triton-opt %s -canonicalize | FileCheck %s
 
-#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0], [0, 0]], block = []}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0], [64, 0]], block = []}>
 #tmem_scales = #ttng.tensor_memory_scales_encoding<>
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
diff --git a/test/TritonNvidiaGPU/invalid.mlir b/test/TritonNvidiaGPU/invalid.mlir
index 7185b67809..d6b2ac4672 100644
--- a/test/TritonNvidiaGPU/invalid.mlir
+++ b/test/TritonNvidiaGPU/invalid.mlir
@@ -24,7 +24,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 
 // -----
 
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+
 #tmem = #ttng.tensor_memory_encoding
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @alloc_tensor_memory() {
@@ -40,12 +41,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
 // -----
 
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
-#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#scales = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
 #tmem = #ttng.tensor_memory_scales_encoding<>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 65536 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
 tt.func public @alloc_tensor_memory(%arg: !ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>) {
-  %cst = arith.constant dense<0> : tensor<128x4xi8, #blocked>
-  %0 = ttng.tmem_alloc %cst : (tensor<128x4xi8, #blocked>) -> !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>
+  %cst = arith.constant dense<0> : tensor<128x4xi8, #scales>
+  %0 = ttng.tmem_alloc %cst : (tensor<128x4xi8, #scales>) -> !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>
   // expected-error @+1 {{Cannot copy into an immutable alloc}}
   ttng.tmem_copy %arg, %0, : (!ttg.memdesc<1x512xi8, #shared1, #ttg.shared_memory, mutable>, !ttg.memdesc<128x4xi8, #tmem, #ttng.tensor_memory>) -> ()
   tt.return
diff --git a/test/TritonNvidiaGPU/membar.mlir b/test/TritonNvidiaGPU/membar.mlir
index 67125519a8..aed4fd2250 100644
--- a/test/TritonNvidiaGPU/membar.mlir
+++ b/test/TritonNvidiaGPU/membar.mlir
@@ -82,7 +82,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 
 // -----
 
-#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
@@ -103,7 +103,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // -----
 
-#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#shared0 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index 254ea42b47..85ebbc81f8 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -10,7 +10,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
 
 //===----------------------------------------------------------------------===//
 def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding",
-                                           [MmaEncodingTrait], TritonIntelGPU_Dialect> {
+                                           [MmaEncodingTrait, DeclareLayoutEncodingMethods], TritonIntelGPU_Dialect> {
   let mnemonic = "dpas";
 
   let description = [{
@@ -254,7 +254,7 @@ The semantic of this `tt.dot` includes GEMM tiling configuration as:
 
 //===----------------------------------------------------------------------===//
 def WarpEncodingAttr : DistributedEncoding<"WarpEncoding", "intel_warp_encoding",
-                                           [], TritonIntelGPU_Dialect> {
+                                           [DeclareLayoutEncodingMethods], TritonIntelGPU_Dialect> {
   let mnemonic = "warp";
 
   let description = [{
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index d4f245cd4f..c42461abff 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -635,18 +635,6 @@ SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getRepOrder() const {
   return getMatrixOrder(getRank(), /*rowMajor*/ true);
 }
 
-SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAsPerCGA() const {
-  return SmallVector<unsigned>(getCTALayout().getCTAsPerCGA());
-}
-
-SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTAOrder() const {
-  return SmallVector<unsigned>(getCTALayout().getCTAOrder());
-}
-
-SmallVector<unsigned> Subgroup2DBlockEncodingAttr::getCTASplitNum() const {
-  return SmallVector<unsigned>(getCTALayout().getCTASplitNum());
-}
-
 SmallVector<unsigned>
 Subgroup2DBlockEncodingAttr::getRepOrderForOperand(int opIdx) const {
   return getOrderForDotOperand(opIdx, getRank(), /*kContig*/ true);
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
index d0805cfb78..2c44ed44ec 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp
@@ -479,16 +479,15 @@ static int getContextualMaxNReg(Operation *op) {
   return maxnreg;
 }
 
-static void lowerStoreToTensorMemory(Location loc, Operation *op, Value src,
-                                     Value dest, Value llSrc, Value pred,
-                                     Value tmemBase,
+static void lowerStoreToTensorMemory(Location loc, Operation *op,
+                                     TypedValue<RankedTensorType> src,
+                                     TypedValue<MemDescType> dest, Value llSrc,
+                                     Value pred, Value tmemBase,
                                      ConversionPatternRewriter &rewriter) {
   auto b = TritonLLVMOpBuilder(loc, rewriter);
   SmallVector<Value> srcValues = unpackLLElements(loc, llSrc, rewriter);
   srcValues = packToI32(srcValues, loc, rewriter);
-  auto dstType = cast<MemDescType>(dest.getType());
-  auto info = getTMemRuntimeInfo(op, cast<RankedTensorType>(src.getType()),
-                                 cast<MemDescType>(dest.getType()));
+  auto info = getTMemRuntimeInfo(op, src.getType(), dest.getType());
   const TMemMessageTraits message =
       selectTMemMessage(info, getContextualMaxNReg(op));
   int regIdx = 0;