diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 896bf7316a..28b5ac824f 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -1,6 +1,8 @@ #ifndef TRITON_IR_UTILITY_H_ #define TRITON_IR_UTILITY_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include #include @@ -10,6 +12,14 @@ namespace mlir { // Bitwidth of pointers constexpr int kPtrBitWidth = 64; +// Returns the bit width of a type, treating pointer-like types as 64-bit. +// This handles LLVM dialect pointer types. +inline int getIntOrFloatOrPtrBitWidth(Type type) { + if (isa(type)) + return kPtrBitWidth; + return type.getIntOrFloatBitWidth(); +} + template SmallVector convertType(ArrayRef in) { SmallVector out; for (const auto &i : in) diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index bedee8e604..fed4ded91f 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, ArrayRef tilesPerWarp, ArrayRef warpsPerCTA); -LinearLayout chooseScaledWmmaScaleLayout( - MLIRContext *ctx, int dotOperandIdx, - const std::vector> &dotOperandWarpBasis, - ArrayRef dotOperandShape); +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef warpsPerCTA, + ArrayRef dotOperandShape); LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, ArrayRef shape, int opIdx, diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index f3f56e1764..5a9ded6bf7 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -184,6 +184,13 @@ getLastUseOfPipelinedOp(ArrayRef ops, scf::ForOp forOp, // Clean up attributes passing over schedules across stages in pipelining void removePipeliningAttributes(ModuleOp moduleOp); + +// For LoadOp, DescriptorLoad, and DescriptorGather ops, determine if +// they should be pipelined. 
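// Roughly: a load is worth pipelining when it can be turned into an async
// or TMA copy -- it must admit a shared encoding compatible with all of its
// dot-operand users and, when filterSmall is set, copy at least 4
// contiguous bytes per thread (the cp.async minimum). Descriptor-based
// loads are always considered beneficial.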
+bool isPipeliningBeneficial(Operation *op, + triton::ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool filterSmall = true); + } // namespace triton } // namespace mlir diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 313dbd41b6..983b6645b8 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -152,7 +152,8 @@ class AllocationAnalysis { auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); numElems = product(shapePerCTA); } - int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8; + int64_t bytes = + numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8; auto alignment = alloc.getAlignmentOrDefault(); allocation->addBuffer(alloc, bytes, diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 336667d129..89b2a35ff1 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { auto lhsInfo = operands[0]->getValue(); auto rhsInfo = operands[1]->getValue(); auto rank = lhsInfo.getRank(); + assert(isa(op.getType()) || + rank == 1 && "Expected ranked tensor or scalar"); assert(operands.size() == 2 && "Expected two operands"); + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + AxisInfo::DimVectorT constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT contiguity(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); for (auto d = 0; d < rank; ++d) { - if (constantValue.has_value()) { - contiguity.push_back(1); - constancy.push_back( - std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); - divisibility.push_back( - highestPowOf2Divisor(constantValue.value())); - } else { - contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); - constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); - divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); - } + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); } return AxisInfo(contiguity, divisibility, constancy, constantValue); } @@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) { - return 1; + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); } - virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) { return {}; @@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { } }; +class UnrealizedConversionCastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl< + mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(mlir::UnrealizedConversionCastOp op, + ArrayRef *> operands) override { + auto tensorType = dyn_cast(op.getResultTypes()[0]); + if (tensorType && + tensorType.getRank() != operands[0]->getValue().getRank()) { + // Do not propagate AxisInfo with incorrect rank. This can cause a crash + // in future visitor applications. 
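// For instance (hypothetical shapes): a cast whose operand carries a
// rank-2 lattice but whose result is a rank-1 tensor would let a later
// visitor index dimension 1 of a one-element DimVector.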
+ return AxisInfo::getPessimisticValueState(op->getResult(0)); + } + return operands[0]->getValue(); + } +}; + class MakeRangeOpAxisInfoVisitor final : public AxisInfoVisitorImpl { public: @@ -254,7 +276,7 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl { getAxisInfo(ub::PoisonOp op, ArrayRef *> operands) override { unsigned rank = 1; - if (auto shape = dyn_cast(op.getType())) + if (auto shape = dyn_cast(op.getType())) rank = shape.getRank(); // Poison values are never accessed, thus assume optimistic values. @@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return gcd(lhs.getDivisibility(dim), rhsDivisibility); } - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return std::max(lhsContiguity, rhsContiguity); } - int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, - const AxisInfo &rhs, int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { auto lhsDivisibility = lhs.getDivisibility(dim); @@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs) override { - if (lhs.getConstantValue().has_value() && - rhs.getConstantValue().has_value()) - return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + auto lhsConst = lhs.getConstantValue(); + auto rhsConst = rhs.getConstantValue(); + if (lhsConst.has_value() && rhsConst.has_value()) + return {lhsConst.value() * rhsConst.value()}; + if ((lhsConst.has_value() && lhsConst.value() == 0) || + (rhsConst.has_value() && rhsConst.value() == 0)) + return 0; return {}; } }; @@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { auto resTy = dyn_cast(op.getType()); + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); if (!resTy) - return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + return constancy; auto shape = resTy.getShape(); - // Case 1: both lhs and rhs are constants. - auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - // Case 2: lhs contiguous, rhs constant. + // Case: lhs contiguous, rhs constant. // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), @@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); auto resTy = dyn_cast(op.getType()); if (!resTy) - return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); - auto shape = resTy.getShape(); - // lhs % 1 = 0 - return rhs.getConstantValue().has_value() && - rhs.getConstantValue().value() == 1 - ? 
shape[dim] - : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + return constancy; + // Case: lhs % 1 = 0 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return resTy.getDimSize(dim); + return constancy; } std::optional getConstantValue(OpTy op, const AxisInfo &lhs, @@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { int64_t constHint = 1; if (lhsInfo.getConstantValue().has_value() && rhsInfo.getConstantValue().has_value()) { - constHint = lhsInfo.getConstancy(d); + constHint = shape[d]; constantValue = compare(getPredicate(op), lhsInfo.getConstantValue().value(), rhsInfo.getConstantValue().value()) @@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { rhsInfo.getConstantValue().has_value() && lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) constantValue = lhsInfo.getConstantValue(); + + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + } } return AxisInfo(contiguity, divisibility, constancy, constantValue); @@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { using BinaryOpVisitorImpl::BinaryOpVisitorImpl; private: - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return multiplyDivisor(lhsDivisibility, 1ll << shift); } - int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, - const AxisInfo &rhs, int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { return std::max(1, lhsDivisibility / (int64_t(1) << shift)); } - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { constantValue = {std::min(lhsInfo.getConstantValue().value(), rhsInfo.getConstantValue().value())}; } + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + AxisInfo::DimVectorT constancy = + resTy ? 
to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), - /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), - /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/divisibility, + /*knownConstancy=*/constancy, /*constantValue=*/constantValue); } else { AxisInfo::DimVectorT contiguity, divisibility, constancy; @@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver, // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is // in the process of a PartialConversion, where UnrealizedConversionCast // may exist + visitors.append(); visitors.append, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, - CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); visitors.append(); visitors.append(); @@ -1214,6 +1227,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) { return rhs; if (rhs.getRank() == 0) return lhs; + assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks"); DimVectorT contiguity; DimVectorT divisibility; DimVectorT constancy; @@ -1384,7 +1398,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, callee.setArgAttr(index, attrName, attr); }; auto axisInfo = axisInfoMap->lookup(value); - assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + // Only scalar arguments are supported. Do not forward multi-dimensional + // AxisInfo to the callee. + if (axisInfo.getRank() != 1) + continue; setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); setAttrFn("tt.constancy", axisInfo.getConstancy(0)); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 24f7a444a7..73c913b176 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -26,8 +26,6 @@ struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { const TargetInfoBase &targetInfo; - // Set benefit to 2 so that this pattern applies before other convert-layout - // conversions. TODO(jlebar): Eventually we want this to be the only pattern. explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit = 1) @@ -277,8 +275,7 @@ struct ConvertLayoutOpConversion StringAttr kReg = str_attr("register"); StringAttr kLane = str_attr("lane"); auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); - int bitwidth = - elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth; + int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy); auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth); auto &[pReg, pLane, mixedTranspositions, nPack] = factors; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 37f20547c2..e1d03bc1ca 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion auto ty = getTypeConverter()->convertType(getElementType(result)); // Pack return elements into 32-bits. - unsigned bitWidth = ty.isIntOrFloat() ? 
ty.getIntOrFloatBitWidth() : 64; + unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty); unsigned numElemsPerReg = std::min(std::max(32 / bitWidth, 1u), op.getPackedElement()); assert(op.getPackedElement() % numElemsPerReg == 0); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index b545cff52c..97bd6d4cb0 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -540,7 +540,7 @@ SmallVector lowerLdSt( auto kLane = str_attr("lane"); auto kWarp = str_attr("warp"); auto kOffset = str_attr("offset"); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); auto [elemsPerVec, permutation] = largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems); @@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx, assert(*cvt.getOutDimNames().begin() == str_attr("offset")); auto calcPaddedOffset = [&](Value smemOffset) { TritonLLVMOpBuilder b(loc, rewriter); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); if (auto paddedEnc = dyn_cast( srcTy.getEncoding())) { // Apply the offset needed for padding. diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 16f3215b70..ab5d809bdf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -11,6 +11,16 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; namespace { + +Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) { + if (isa(val.getType()) && + !isa(type)) { + return b.ptrtoint(type, val); + } else { + return b.bitcast(val, type); + } +} + struct SplatOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a @@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern { unsigned ratio = srcBitWidth / cstBitWidth; Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); VectorType vecType = VectorType::get(ratio, intTy); - Value intCst = b.bitcast(constVal, intTy); + Value intCst = bitOrPtrCast(constVal, intTy, b); Value vec = b.undef(vecType); for (unsigned i = 0; i < ratio; ++i) vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); constVal = vec; } - auto llSrc = b.bitcast(constVal, srcType); + Value llSrc = bitOrPtrCast(constVal, srcType, b); size_t elemsPerThread = getTotalElemsPerThread(tensorTy); llvm::SmallVector elems(elemsPerThread, llSrc); return packLLElements(loc, typeConverter, elems, rewriter, resType); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index ac6866157a..c2e89dd8ea 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1438,64 +1438,54 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef shape, } } -LinearLayout chooseScaledWmmaScaleLayout( - MLIRContext *ctx, int dotOperandIdx, - const std::vector> &dotOperandWarpBasis, - ArrayRef dotOperandShape) { +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef warpsPerCTA, + ArrayRef dotOperandShape) { using basisT = std::vector>; unsigned rank = dotOperandShape.size(); auto order = mlir::triton::gpu::getMatrixOrder(rank, 
/*rowMajor=*/true); - auto standardOutDims = standardOutDimNames(ctx, rank); + auto outDimNames = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); StringAttr kLane = StringAttr::get(ctx, "lane"); StringAttr kWarp = StringAttr::get(ctx, "warp"); StringAttr kBlock = StringAttr::get(ctx, "block"); - unsigned int scaleKWidth = dotOperandShape[1]; - // Init register layout. Will be adjusted later - auto regs = - mlir::triton::identityStandardND(kRegister, {1, scaleKWidth}, order); - LinearLayout lanes = LinearLayout::empty(); + // In scaled dot, the shapes of operands(without batch dimension) are, // respectively: // - A: [M, K] // - B: [K, N] // - aScale: [M, K / 32 or 16] // - bScale: [N, K / 32 or 16] - // - // To correctly feed A/B and its scale into instruction, we need to - // distribute aScale/bScale among warps in the same way as A/B. But bScale - // is not transposed like B. So we need to transpose the warp layout of - // bScale. - // - // The tricky part is, our desired outputs are [dim0, dim1], but - // at this position, the layouts are transposed to [dim1, dim0]. So - // instead of reverse bScale's layout, we need to reverse aScale's. There - // will be a transpose in the end to correct everything. - basisT warps = dotOperandWarpBasis; - if (dotOperandIdx == 0) { - for (auto &basis : warps) { - std::reverse(basis.begin(), basis.end()); - } - } + auto dimK = outDimNames[order[0]]; + auto dimNonK = outDimNames[order[1]]; - lanes = LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, - {kWarp, warps}, - {kBlock, {}}}, - {standardOutDims[order[0]], standardOutDims[order[1]]}); - LinearLayout newLL = regs * lanes; + // Each lane holds kWidth=4 consecutive values along the k dim. + // The first 16 lanes are distributed along the non-k dim. We are not using + // the remaining 16 lanes, so just let them duplicate values of the first 16 + // lanes. If the shape along the k dim is larger than kWidth, repeat this + // pattern to fill the k dim. + unsigned scaleKWidth = 4; + auto kSize = dotOperandShape[1]; + LinearLayout tileLayout = + LinearLayout::identity1D(scaleKWidth, kRegister, dimK) * + LinearLayout::identity1D(16, kLane, dimNonK) * + LinearLayout::zeros1D(2, kLane, dimK) * + LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK); - // Adjust register-level layout to fill the shape, at this level, both - // aScale and bScale should align with A operand. - SmallVector repOrder = {1, 0}; - for (auto d : repOrder) { - auto outDim = standardOutDims[d]; - auto dimSize = newLL.getOutDimSize(outDim); - newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister, - outDim); - } - newLL = newLL.transposeOuts(standardOutDims); + auto warpsPerCTANew = (dotOperandIdx == 1) + ? SmallVector{warpsPerCTA[1], warpsPerCTA[0]} + : SmallVector{warpsPerCTA[0], warpsPerCTA[1]}; + + auto warpOrder = (dotOperandIdx == 1) ? 
SmallVector{0, 1} + : SmallVector{1, 0}; + LinearLayout warpLayout = + identityStandardND(kWarp, warpsPerCTANew, warpOrder); + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); - return newLL; + return combineCtaCgaWithShape( + ctaLayout, CTALayoutAttr::getDefault(ctx, /*rank=*/2), dotOperandShape); } // PTX ISA - Warp-level MMA Block Scaling diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index b3bfcd82bb..d0c4d6c810 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -809,6 +809,11 @@ LogicalResult MemDescIndexOp::verify() { return emitError("src and dst must have the same type of encoding"); } + if (dstTy.getAllocShape() != dstTy.getShape() || + srcTy.getAllocShape() != srcTy.getShape()) { + return emitError("alloc shape must match shape for both result and src"); + } + if (isa(srcEnc)) { // We support only 3D -> 2D subviews with only first offset being non-zero. if (srcTy.getRank() != 3 || dstTy.getRank() != 2) { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp index 2789a9352e..bbce977049 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp @@ -88,64 +88,6 @@ class AssignLoadLatencies { scf::ForOp forOp; int numStages; DenseMap &opLatency; - -public: - static bool canHaveSharedEncoding(tt::LoadOp op) { - // If used by an user with DotOp encoding, all the uses must be compatible. - bool incompatible = false; - getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); - return !incompatible; - } - - static bool - isPipeliningBeneficial(Operation *op, Operation *finalUser, - tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, - bool filterSmall) { - if (auto loadOp = dyn_cast(op)) { - if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) { - LDBG("Load " << *loadOp << " is too small for pipelining"); - return false; - } - } - if (isa(op)) - return true; - if (!canHaveSharedEncoding(cast(op))) { - LDBG("Load " << *op << " cannot have shared encoding"); - return false; - } - - ttg::SharedEncodingTrait localAllocEnc; - if (llvm::any_of(op->getUsers(), [&](Operation *user) { - return isa(user); - })) { - for (auto user : op->getUsers()) { - auto localAlloc = dyn_cast(user); - if (!localAlloc) - continue; - auto enc = mlir::cast( - localAlloc.getType().getEncoding()); - if (!localAllocEnc) { - localAllocEnc = enc; - } - if (enc != localAllocEnc) { - // If the load is used by a LocalAllocOp, all the users need to have - // the same encoding. 
- return false; - } - } - } - - if (localAllocEnc) { - auto registerTy = cast(op->getResultTypes()[0]); - auto vecBytes = getCopyVecBytes(registerTy, localAllocEnc); - if (filterSmall && vecBytes < 4) { - // At least 4 bytes need to be consecutive for cp.async - return false; - } - } - - return true; - } }; class AssignMMALatencies { @@ -280,8 +222,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot, if (!seen.insert(op).second || excluded.count(op)) return; if (isa(op)) { - if (!AssignLoadLatencies::isPipeliningBeneficial( - op, finalUser, axisInfoAnalysis, filterSmall)) + if (!isPipeliningBeneficial(op, axisInfoAnalysis, filterSmall)) return; if (loadOpToIndLevel.count(op)) { int level = loadOpToIndLevel[op].first; diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index ea6049ab9b..2f9c76336b 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -453,26 +453,17 @@ scf::ForOp lowerLoads(scf::ForOp forOp, CoarseSchedule &schedule, continue; } SharedEncodingTrait sharedEncoding; - bool canUseAsyncCp = false; - if (!isa(op.getResultTypes()[0])) { - canUseAsyncCp = op.getResultTypes()[0].getIntOrFloatBitWidth() >= 32; - sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( - forOp.getContext(), 1, 1, 1, {0}, - ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0})); - if (canUseAsyncCp) { + bool canUseAsyncCp = + triton::isPipeliningBeneficial(&op, axisInfoAnalysis); + if (canUseAsyncCp) { + if (!isa(op.getResultTypes()[0])) { + sharedEncoding = ttg::SwizzledSharedEncodingAttr::get( + forOp.getContext(), 1, 1, 1, {0}, + ttg::CTALayoutAttr::get(forOp.getContext(), {1}, {1}, {0})); scalarLoads.push_back(&op); + } else { + sharedEncoding = getSharedEncoding(&op); } - } else { - sharedEncoding = getSharedEncoding(&op); - // Do not create async loads for small loads (cp.async requires at least - // 4 bytes) - canUseAsyncCp = - isa(op) && - canBeConvertedToAsyncLoad(cast(op), axisInfoAnalysis); - int copyVecBytes = getCopyVecBytes( - cast(op.getResultTypes()[0]), sharedEncoding); - - canUseAsyncCp &= copyVecBytes >= 4; } if (canUseAsyncCp || isTMALoad(&op)) { if (loadRequiresAdditionalBuffer(&op)) { diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index fd4f26120f..70291738ba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -603,6 +603,10 @@ ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(RankedTensorType ty) { } ttg::SharedEncodingTrait mlir::triton::getSharedEncoding(Operation *op) { + if (!isa(op->getResultTypes()[0])) { + return nullptr; + } + // Try to use local alloc encoding if possible. 
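// (The early return above keeps scalar loads out of this path: a non-tensor
// result has no layout to derive a shared encoding from, and lowerLoads
// instead builds a trivial one-element swizzled encoding for such loads.)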
ttg::SharedEncodingTrait localAllocEnc; if (llvm::any_of(op->getUsers(), [&](Operation *user) { @@ -683,8 +687,7 @@ triton::createSingleBufferView(OpBuilder &builder, Value alloc, Value idx) { allocDescType.getShape().end()); auto viewDescType = ttg::MemDescType::get( shape, allocDescType.getElementType(), allocDescType.getEncoding(), - allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), - /*allocShape=*/allocDescType.getAllocShape()); + allocDescType.getMemorySpace(), allocDescType.getMutableMemory()); return builder.create(alloc.getLoc(), viewDescType, alloc, idx); } @@ -933,3 +936,38 @@ void triton::removePipeliningAttributes(ModuleOp moduleOp) { op->removeAttr(mlir::triton::kScheduledMaxStageAttrName); }); } + +static bool canHaveSharedEncoding(tt::LoadOp op) { + // If used by an user with DotOp encoding, all the uses must be compatible. + bool incompatible = false; + getSharedEncIfAllUsersAreDotEnc(op.getResult(), incompatible); + return !incompatible; +} + +bool triton::isPipeliningBeneficial( + Operation *op, tt::ModuleAxisInfoAnalysis &axisInfoAnalysis, + bool filterSmall) { + if (auto loadOp = dyn_cast(op)) { + if (filterSmall && !canBeConvertedToAsyncLoad(loadOp, axisInfoAnalysis)) { + LDBG("Load " << *loadOp << " is too small for pipelining"); + return false; + } + } + if (isa(op)) + return true; + if (!canHaveSharedEncoding(cast(op))) { + LDBG("Load " << *op << " cannot have shared encoding"); + return false; + } + + if (auto localAllocEnc = getSharedEncoding(op)) { + auto registerTy = cast(op->getResultTypes()[0]); + auto vecBytes = mlir::triton::getCopyVecBytes(registerTy, localAllocEnc); + if (filterSmall && vecBytes < 4) { + // At least 4 bytes need to be consecutive for cp.async + return false; + } + } + + return true; +} diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 2279741404..a46717238d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1566,7 +1566,7 @@ void replaceUsesAndPropagateType( bool isMutable = cast(val.getType()).getMutableMemory(); Type newDstType = ttg::MemDescType::get( oldType.getShape(), oldType.getElementType(), oldType.getEncoding(), - oldType.getMemorySpace(), isMutable, oldType.getAllocShape()); + oldType.getMemorySpace(), isMutable); newVal = builder.create(subview.getLoc(), newDstType, val, subview.getIndex()); } else if (auto subslice = dyn_cast(user)) { diff --git a/python/examples/gluon/01-attention-forward.py b/python/examples/gluon/01-attention-forward.py index 3770c8e009..15c2f3f9c4 100644 --- a/python/examples/gluon/01-attention-forward.py +++ b/python/examples/gluon/01-attention-forward.py @@ -52,10 +52,11 @@ class BarrierCounter: phase: gl.tensor num_barriers: gl.constexpr + @gluon.constexpr_function def __init__(self, index, phase, num_barriers): self.index = index self.phase = phase - self.num_barriers = num_barriers + self.num_barriers = gl.constexpr(num_barriers) @gluon.must_use_result @gluon.jit @@ -79,6 +80,7 @@ class ChannelType: num_buffers: gl.constexpr num_consumers: gl.constexpr + @gluon.constexpr_function def __init__(self, mem, ready_bars, empty_bars, num_buffers, num_consumers): self.mem = mem self.ready_bars = ready_bars @@ -143,6 +145,7 @@ class Producer: channel: ChannelType counter: BarrierCounter + @gluon.constexpr_function def __init__(self, channel, counter): self.channel = channel self.counter = counter @@ -158,6 +161,7 @@ class Consumer: channel: ChannelType counter: 
BarrierCounter + @gluon.constexpr_function def __init__(self, channel, counter): self.channel = channel self.counter = counter @@ -234,6 +238,7 @@ class AttentionConfig: num_kv_buffers: gl.constexpr use_exp2_turnstile: gl.constexpr + @gluon.constexpr_function def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE, dtype, num_warps): self.qk_scale = qk_scale @@ -250,7 +255,7 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE self.num_warps = gl.constexpr(num_warps) self.SPLIT_D_FACTOR = gl.constexpr(2) - self.SPLIT_EXP_FACTOR = 256 // HEAD_DIM + self.SPLIT_EXP_FACTOR = gl.constexpr(256 // HEAD_DIM) self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1) self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2) self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR) @@ -305,6 +310,7 @@ class ProgramScheduler: num_pid_in_group: gl.tensor num_tiles: gl.tensor + @gluon.constexpr_function def __init__(self, config, start_pid, num_pid_n, num_pid_in_group, num_tiles): self.config = config self.start_pid = start_pid @@ -339,6 +345,7 @@ class AttentionProgram: offset_y: gl.tensor qo_offset_y: gl.tensor + @gluon.constexpr_function def __init__(self, config, start_m, off_hz, offset_y, qo_offset_y): self.config = config self.start_m = start_m @@ -840,12 +847,13 @@ def attention_kernel( # chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile) descs = (desc_q, desc_k, desc_v, desc_o) - gl.warp_specialize((config, chnls, descs, M, STAGE), _attn_fwd_correction, (config, chnls, descs, M, STAGE), [ - _attn_fwd_softmax0, - _attn_fwd_softmax1, - _attn_fwd_mma, - _attn_fwd_load, - _attn_fwd_epilogue, + gl.warp_specialize([ + (_attn_fwd_correction, (config, chnls, descs, M, STAGE)), + (_attn_fwd_softmax0, (config, chnls, descs, M, STAGE)), + (_attn_fwd_softmax1, (config, chnls, descs, M, STAGE)), + (_attn_fwd_mma, (config, chnls, descs, M, STAGE)), + (_attn_fwd_load, (config, chnls, descs, M, STAGE)), + (_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)), ], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24]) q_chnl.release() diff --git a/python/src/gluon_ir.cc b/python/src/gluon_ir.cc index a64e5b2fe0..46652a41bb 100644 --- a/python/src/gluon_ir.cc +++ b/python/src/gluon_ir.cc @@ -384,6 +384,16 @@ void init_gluon_ir(py::module &&m) { ctx, version, warpsPerCta, instrShape, transposed, ctaLayout, tilesPerWarp, elementBitWidth); }) + .def("get_amd_mfma_scale_layout", + [](GluonOpBuilder &self, unsigned opIdx, std::vector &shape, + unsigned mfmaMDim, std::vector &tilesPerWarp, + std::vector &warpsPerCTA) -> py::object { + auto ctx = self.getContext(); + auto ll = ttg::chooseScaledMfmaScaleLayout( + ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA); + auto attr = ttg::LinearEncodingAttr::get(ctx, ll); + return layoutToGluon(attr); + }) .def("get_amd_wmma_layout", [](GluonOpBuilder &self, unsigned version, bool transposed, std::vector &warpsPerCta, @@ -397,6 +407,15 @@ void init_gluon_ir(py::module &&m) { return ttg::AMDWmmaEncodingAttr::get( ctx, version, transposed, warpsPerCta, ctaLayout, instrShape); }) + .def("get_amd_wmma_scale_layout", + [](GluonOpBuilder &self, unsigned opIdx, std::vector &shape, + std::vector &warpsPerCTA) -> py::object { + auto ctx = self.getContext(); + auto ll = ttg::chooseScaledWmmaScaleLayout(ctx, opIdx, warpsPerCTA, + shape); + auto attr = ttg::LinearEncodingAttr::get(ctx, ll); + return layoutToGluon(attr); + }) .def("get_intel_dpas_layout", [](GluonOpBuilder 
&self, unsigned repeatCount, unsigned systolicDepth, unsigned executionSize, diff --git a/python/src/ir.h b/python/src/ir.h index b27c41a739..5cac33bab6 100644 --- a/python/src/ir.h +++ b/python/src/ir.h @@ -35,7 +35,7 @@ class TritonOpBuilder { if (!block.empty()) setLastLoc(block.begin()->getLoc()); else - setLastLoc(builder->getUnknownLoc()); + setLastLoc(getLocForBlock(&block)); builder->setInsertionPointToStart(&block); } @@ -43,7 +43,7 @@ class TritonOpBuilder { if (!block.empty()) setLastLoc(block.back().getLoc()); else - setLastLoc(builder->getUnknownLoc()); + setLastLoc(getLocForBlock(&block)); builder->setInsertionPointToEnd(&block); } @@ -53,10 +53,14 @@ class TritonOpBuilder { } void restoreInsertionPoint(mlir::OpBuilder::InsertPoint pt) { - if (pt.isSet() && pt.getPoint() != pt.getBlock()->end()) - setLastLoc(pt.getPoint()->getLoc()); - else - setLastLoc(builder->getUnknownLoc()); + setLastLoc(builder->getUnknownLoc()); + if (pt.isSet()) { + if (pt.getPoint() != pt.getBlock()->end()) + setLastLoc(pt.getPoint()->getLoc()); + else + setLastLoc(getLocForBlock(pt.getBlock())); + } + builder->restoreInsertionPoint(pt); } @@ -87,4 +91,10 @@ class TritonOpBuilder { std::unique_ptr lastLoc; bool lineInfoEnabled = !mlir::triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); + + mlir::Location getLocForBlock(mlir::Block *block) { + if (auto parentOp = block->getParentOp()) + return parentOp->getLoc(); + return builder->getUnknownLoc(); + } }; diff --git a/python/test/gluon/test_consan.py b/python/test/gluon/test_consan.py index d0ea6c3a23..ee9b10cc4c 100644 --- a/python/test/gluon/test_consan.py +++ b/python/test/gluon/test_consan.py @@ -748,8 +748,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) for i in range(2): mbarrier.init(bar.index(i), count=1) - ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout), - [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, FAILURE, blocked_layout)), + (ws_1, (smem, bar, FAILURE, blocked_layout)), + ], [4], [32]) mbarrier.wait(bar.index(1), phase=0) val = smem.index(0).load(blocked_layout) output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout) @@ -802,8 +804,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) for i in range(2): mbarrier.init(bar.index(i), count=1) - ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout), - [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, FAILURE, blocked_layout)), + (ws_1, (smem, bar, FAILURE, blocked_layout)), + ], [4], [32]) mbarrier.wait(bar.index(1), phase=0) val = smem.index(0).load(blocked_layout) output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout) @@ -865,8 +869,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [3, 1], mbarrier.MBarrierLayout()) for i in range(3): mbarrier.init(bar.index(i), count=1) - ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default, - (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2], [4, 4], [32, 32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_1, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_2, (smem, bar, MISSING_BAR, blocked_layout)), + ], [4, 4], [32, 32]) mbarrier.wait(bar.index(2), phase=0) val = 
smem.index(0).load(blocked_layout) output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout) @@ -925,8 +932,11 @@ def kernel(output, FAILURE: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=2) mbarrier.init(bar.index(1), count=1) - ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout), - [ws_1, ws_2], [4, 4], [32, 32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, FAILURE, blocked_layout)), + (ws_1, (smem, bar, FAILURE, blocked_layout)), + (ws_2, (smem, bar, FAILURE, blocked_layout)), + ], [4, 4], [32, 32]) mbarrier.wait(bar.index(1), phase=0) val = smem.index(0).load(blocked_layout) output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout) @@ -1013,8 +1023,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): mbarrier.arrive(bar.index(2), count=1) mbarrier.arrive(bar.index(3), count=1) - ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default, - (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2], [4, 4], [32, 32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_1, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_2, (smem, bar, MISSING_BAR, blocked_layout)), + ], [4, 4], [32, 32]) output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16) kernel[(1, )](output, MISSING_BAR=MISSING_BAR, num_warps=4) @@ -1078,8 +1091,10 @@ def kernel(output, FAILURE: ttgl.constexpr): mbarrier.arrive(bar.index(2), count=1) - ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout), - [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, FAILURE, blocked_layout)), + (ws_1, (smem, bar, FAILURE, blocked_layout)), + ], [4], [32]) output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16) kernel[(1, )](output, FAILURE=FAILURE, num_warps=4) @@ -1166,8 +1181,12 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): mbarrier.arrive(bar.index(2), count=2) mbarrier.arrive(bar.index(3), count=2) - ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default, - (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2, ws_3], [4, 4, 4], [32, 32, 32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_1, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_2, (smem, bar, MISSING_BAR, blocked_layout)), + (ws_3, (smem, bar, MISSING_BAR, blocked_layout)), + ], [4, 4, 4], [32, 32, 32]) output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16) kernel[(1, )](output, MISSING_BAR=MISSING_BAR, num_warps=4) @@ -1231,8 +1250,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr): bar = ttgl.allocate_shared_memory(ttgl.int64, [3, 1], mbarrier.MBarrierLayout()) for i in range(3): mbarrier.init(bar.index(i), count=1) - ttgl.warp_specialize((smem, bar, MISSING_BAR), ws_default, (smem, bar, MISSING_BAR), [ws_1, ws_2], [2, 8], - [32, 32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, MISSING_BAR)), + (ws_1, (smem, bar, MISSING_BAR)), + (ws_2, (smem, bar, MISSING_BAR)), + ], [2, 8], [32, 32]) mbarrier.wait(bar.index(2), phase=0) val = smem.index(0).load(blocked_layout) output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout) @@ -1298,8 +1320,10 @@ def kernel(input, FAILURE: ttgl.constexpr): smem = ttgl.allocate_shared_memory(ttgl.float16, [4, XBLOCK], smem_layout) blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[XBLOCK], threads_per_warp=[32], warps_per_cta=[4], 
order=[0]) - ttgl.warp_specialize((input, smem, FAILURE, blocked_layout, 0), ws_prog, - (input, smem, FAILURE, blocked_layout, 2), [ws_prog], [4], [32]) + ttgl.warp_specialize([ + (ws_prog, (input, smem, FAILURE, blocked_layout, 0)), + (ws_prog, (input, smem, FAILURE, blocked_layout, 2)), + ], [4], [32]) input = torch.randn((XBLOCK, ), device=device, dtype=torch.float16) kernel[(1, )](input, FAILURE=FAILURE, num_warps=4) @@ -1354,8 +1378,10 @@ def kernel(input, FAILURE: ttgl.constexpr): smem = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK], smem_layout) bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=1) - ttgl.warp_specialize((input, smem, bar, FAILURE, blocked_layout), ws_default, - (input, smem, bar, FAILURE, blocked_layout), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (input, smem, bar, FAILURE, blocked_layout)), + (ws_1, (input, smem, bar, FAILURE, blocked_layout)), + ], [4], [32]) input = torch.randn((XBLOCK, ), device=device, dtype=torch.float16) kernel[(1, )](input, FAILURE=FAILURE, num_warps=4) @@ -1410,8 +1436,10 @@ def kernel(FAILURE: ttgl.constexpr): smem = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK, XBLOCK], smem_layout) bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=1) - ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout, mma_layout), ws_default, - (smem, bar, FAILURE, blocked_layout), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (smem, bar, FAILURE, blocked_layout, mma_layout)), + (ws_1, (smem, bar, FAILURE, blocked_layout)), + ], [4], [32]) kernel[(1, )](FAILURE=FAILURE, num_warps=4) @@ -1446,7 +1474,10 @@ def kernel(): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=1) mbarrier.init(bar.index(1), count=1) - ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (bar, )), + (ws_1, (bar, )), + ], [4], [32]) kernel[(1, )](num_warps=4) @@ -1513,7 +1544,10 @@ def kernel(): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=2) mbarrier.init(bar.index(1), count=2) - ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (bar, )), + (ws_1, (bar, )), + ], [4], [32]) kernel[(1, )](num_warps=4) @@ -1549,7 +1583,10 @@ def kernel(): bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=1) mbarrier.arrive(bar.index(0), count=1) - ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (bar, )), + (ws_1, (bar, )), + ], [4], [32]) kernel[(1, )](num_warps=4) @@ -1590,7 +1627,10 @@ def kernel(input_desc): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) mbarrier.init(bar.index(0), count=1) mbarrier.init(bar.index(1), count=1) - ttgl.warp_specialize((input_desc, smem, bar), ws_default, (input_desc, smem, bar), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (input_desc, smem, bar)), + (ws_1, (input_desc, smem, bar)), + ], [4], [32]) input = torch.randn((XBLOCK, XBLOCK), device=device, dtype=torch.float16) shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2) @@ -1629,6 +1669,9 @@ def kernel(): bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout()) 
mbarrier.init(bar.index(0), count=1) mbarrier.init(bar.index(1), count=1) - ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32]) + ttgl.warp_specialize([ + (ws_default, (bar, )), + (ws_1, (bar, )), + ], [4], [32]) kernel[(1, )](num_warps=4) diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 361940e02c..bd39de4b78 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -18,6 +18,7 @@ is_hopper, is_xpu, ) +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from triton.experimental import gluon from triton.experimental.gluon import language as ttgl from triton.experimental.gluon.language.nvidia.ampere import async_copy, mma_v2 @@ -140,6 +141,92 @@ def test_async_copy_mbarrier(device): torch.testing.assert_close(out[20:], torch.zeros((12, 32), **tensor_opts)) +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_device_tma_load(): + + @gluon.jit + def tma_device_load_kernel(input_ptr, output_ptr, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr): + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + + mbarrier.expect(bar, input_desc.block_type.nbytes) + tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + mbarrier.invalidate(bar) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, block_layout))[:, None] + yindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, block_layout))[None, :] + val = smem.load(block_layout) + ttgl.store(output_ptr + yindex + xindex * XBLOCK, val) + + XBLOCK = 16 + input = torch.zeros((XBLOCK, XBLOCK), device="cuda", dtype=torch.float16) + output = torch.ones_like(input) + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + tma_device_load_kernel[(1, )](input, output, XBLOCK, smem_layout) + torch.testing.assert_close(input, output) + + +@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper") +def test_device_tma_store(): + + @gluon.jit + def tma_device_store_kernel(out_ptr, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr): + layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + value = ttgl.full([XBLOCK, XBLOCK], 0, ttgl.float16, layout) + alloc = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout, value) + out_desc = tma.make_tensor_descriptor( + out_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + tma.async_copy_shared_to_global(out_desc, [0, 0], alloc) + tma.store_wait(0) + alloc._keep_alive() + + XBLOCK = 16 + out = torch.ones((XBLOCK, XBLOCK), dtype=torch.float16, device="cuda") + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + 
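    # Device-side make_tensor_descriptor needs a small global-memory scratch
    # buffer for the TMA descriptor; Triton obtains it through this allocator.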
triton.set_allocator(alloc_fn) + + tma_device_store_kernel[(1, )](out, XBLOCK, smem_layout) + torch.testing.assert_close(out, torch.zeros_like(out)) + + @gluon.jit def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr, @@ -547,145 +634,108 @@ def kernel(a_ptr, b_ptr, c_ptr, # @pytest.mark.xfail(not is_hip_cdna4(), reason="Requires CDNA4", run=False) -@pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type) - for rhs_scale in [True, False] - for mxfp_type in ["e2m1"] - for normal_type in ["e4m3", "e5m2"]]) -def test_amd_mfma_scaled(M, N, K, rhs_scale, mxfp_type, normal_type): - device = 'cuda' - - @triton.jit - def triton_kernel(a_base, stride_am, stride_ak, a_scale, # - b_base, stride_bk, stride_bn, b_scale, # - out, # - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - type_a: tl.constexpr, type_b: tl.constexpr): - DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B - a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_am + \ - tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_ak - b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_bk + \ - tl.arange(0, BLOCK_N)[None, :] * stride_bn - - a = tl.load(a_ptr) - b = tl.load(b_ptr) - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 - if a_scale is not None: - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, - SCALE_BLOCK_K)[None, :] - a_scale = tl.load(scale_a_ptr) - if b_scale is not None: - scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, - SCALE_BLOCK_K)[None, :] - b_scale = tl.load(scale_b_ptr) - c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) - out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] - tl.store(out_ptr, c.to(tl.bfloat16)) +@pytest.mark.parametrize("M, N, K", [(32, 32, 128)]) +@pytest.mark.parametrize("a_type, b_type", [(a_type, b_type) + for a_type in ["e2m1", "e4m3", "e5m2"] + for b_type in ["e2m1", "e4m3", "e5m2"]]) +@pytest.mark.parametrize("has_scale", [True, False]) +def test_amd_mfma_scaled(M, N, K, a_type, b_type, has_scale, device='cuda'): @gluon.jit - def gluon_kernel(a_base, stride_am, stride_ak, a_scale, # - b_base, stride_bk, stride_bn, b_scale, # - out, # - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # - type_a: tl.constexpr, type_b: tl.constexpr): - DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 + def kernel(out_ptr, a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, # + M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, # + a_type: tl.constexpr, b_type: tl.constexpr): + DIV_FACTOR_A: tl.constexpr = 2 if a_type == "e2m1" else 1 + DIV_FACTOR_B: tl.constexpr = 2 if b_type == "e2m1" else 1 + K_A: tl.constexpr = K // DIV_FACTOR_A + K_B: tl.constexpr = K // DIV_FACTOR_B + + mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True, + warps_per_cta=[2, 2]) 
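        # e2m1 packs two fp4 values per byte, so e2m1 operands are loaded as
        # K // 2 bytes along the K dim and use the *_packed_layout variants
        # below; fp8 operands use the unpacked ones.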
a_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 8], [4, 1], [1, 0]) a_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [8, 8], [4, 1], [1, 0]) - a_layout: ttgl.constexpr = a_packed_layout if type_a == "e2m1" else a_unpacked_layout - - a_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( - reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[0, 0], [16, 0]], - block_bases=[], shape=[32, 4]) + a_load_layout: ttgl.constexpr = a_packed_layout if a_type == "e2m1" else a_unpacked_layout + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16) + a_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(a_layout, [M, K // 32]) b_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [32, 2], [4, 1], [1, 0]) b_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]) - b_layout: ttgl.constexpr = b_packed_layout if type_b == "e2m1" else b_unpacked_layout - - b_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( - reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[16, 0], [0, 0]], - block_bases=[], shape=[32, 4]) - - mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True, - warps_per_cta=[2, 2]) - - zero = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=mfma_layout) - - a_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_layout))[:, None] * stride_am + \ - ttgl.arange(0, PACKED_BLOCK_K_A, layout=ttgl.SliceLayout(0, a_layout))[None, :] * stride_ak - a = ttgl.amd.cdna4.buffer_load(a_base, a_offsets) - a = ttgl.convert_layout(a, ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16)) - - b_offsets = ttgl.arange(0, PACKED_BLOCK_K_B, layout=ttgl.SliceLayout(1, b_layout))[:, None] * stride_bk + \ - ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, b_layout))[None, :] * stride_bn - b = ttgl.amd.cdna4.buffer_load(b_base, b_offsets) - b = ttgl.convert_layout(b, ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16)) - - if a_scale is not None: - a_scale_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None] * SCALE_BLOCK_K + \ - ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :] - a_scale = ttgl.amd.cdna4.buffer_load(a_scale, a_scale_offsets) + b_load_layout: ttgl.constexpr = b_packed_layout if b_type == "e2m1" else b_unpacked_layout + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16) + b_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(b_layout, [N, K // 32]) + + a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_load_layout))[:, None] + a_offs_k = ttgl.arange(0, K_A, layout=ttgl.SliceLayout(0, a_load_layout))[None, :] + a = ttgl.amd.cdna4.buffer_load(a_ptr, a_offs_m * K_A + a_offs_k) + a = ttgl.convert_layout(a, a_layout) + + b_offs_k = ttgl.arange(0, K_B, layout=ttgl.SliceLayout(1, b_load_layout))[:, None] + b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, b_load_layout))[None, :] + b = ttgl.amd.cdna4.buffer_load(b_ptr, b_offs_k * N + b_offs_n) + b = ttgl.convert_layout(b, b_layout) + + a_scale = None + if a_scale_ptr is not None: + a_scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None] + a_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :] + a_scale = 
ttgl.amd.cdna4.buffer_load(a_scale_ptr, a_scale_offs_m * (K // 32) + a_scale_offs_k) + + b_scale = None + if b_scale_ptr is not None: + b_scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None] + b_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :] + b_scale = ttgl.amd.cdna4.buffer_load(b_scale_ptr, b_scale_offs_n * (K // 32) + b_scale_offs_k) + + zero = ttgl.zeros([M, N], dtype=ttgl.float32, layout=mfma_layout) + c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, a_type, b, b_scale, b_type, zero) + c = c.to(out_ptr.dtype.element_ty) + + out_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None] + out_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :] + ttgl.amd.cdna4.buffer_store(c, out_ptr, out_offs_m * N + out_offs_n) + + def _create_mxfp_operand(operand: int, m: int, n: int, dtype: str): + size = (m, n) + if dtype == 'e4m3': + v = torch.randint(20, 40, size, dtype=torch.uint8) + v_ref = v.view(torch.float8_e4m3fn).to(torch.float32) + elif dtype == 'e5m2': + v = torch.randint(20, 40, size, dtype=torch.uint8) + v_ref = v.view(torch.float8_e5m2).to(torch.float32) else: - a_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=a_scale_layout) - - if b_scale is not None: - b_scale_offsets = ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None] * SCALE_BLOCK_K + \ - ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :] - b_scale = ttgl.amd.cdna4.buffer_load(b_scale, b_scale_offsets) - else: - b_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=b_scale_layout) - - c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, type_a, b, b_scale, type_b, zero) - c = c.to(out.dtype.element_ty) - - out_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None] * BLOCK_N + \ - ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :] - ttgl.amd.cdna4.buffer_store(c, out, out_offsets) + assert dtype == 'e2m1' + pack_dim = 1 if operand == 0 else 0 + v_mxfp4 = MXFP4Tensor(size=size).random() + v = v_mxfp4.to_packed_tensor(pack_dim) + v_ref = v_mxfp4.to(torch.float32) + return v.to(device), v_ref.to(device) + + def _create_mxfp_scale(operand: int, m: int, n: int): + size = (m, n // 32) + scale = MXScaleTensor(size=tuple(size)).random(1 / 32, 32) + scale_ref = scale.to(torch.float32).repeat_interleave(32, dim=1) + scale_ref = scale_ref.T.contiguous() if operand == 1 else scale_ref + return scale.data.to(device), scale_ref.to(device) torch.manual_seed(0) - - type_a = normal_type if rhs_scale else mxfp_type - type_b = mxfp_type if rhs_scale else normal_type - - DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 - DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 - x = torch.randint(20, 40, (M, K // DIV_FACTOR_A), dtype=torch.uint8, device=device) - y = torch.randint(20, 40, (K // DIV_FACTOR_B, N), dtype=torch.uint8, device=device) - - min_scale, max_scale = (0, 142) - scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device) - scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device) - if rhs_scale: - scale_x = None + a, a_ref = _create_mxfp_operand(0, M, K, a_type) + b, b_ref = _create_mxfp_operand(1, K, N, b_type) + + if has_scale: + a_scale, a_scale_ref = _create_mxfp_scale(0, M, K) + b_scale, b_scale_ref = _create_mxfp_scale(1, N, K) + out = torch.empty((M, N), dtype=torch.float32, device=device) + compiled = 
kernel[(1, )](out, a, b, a_scale, b_scale, M, N, K, a_type, b_type, num_warps=4) + out_ref = torch.matmul(a_ref * a_scale_ref, b_ref * b_scale_ref) + torch.testing.assert_close(out, out_ref) else: - scale_y = None - - def make_finite(x, dtype): - if dtype not in ("e5m2", "e4m3"): - return x - mask = 0x7C if dtype == "e5m2" else 0x7F - finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask - x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) - x.copy_(x_finite) - return x - - x = make_finite(x, type_a) - y = make_finite(y, type_b) - - z = torch.zeros((M, N), dtype=torch.bfloat16, device=device) - pgm = gluon_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b) - assert "v_mfma_scale_f32_16x16x128_f8f6f4" in pgm.asm["amdgcn"] - - z_ref = torch.zeros((M, N), dtype=torch.bfloat16, device=device) - triton_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z_ref, M, N, K, type_a, type_b) + out = torch.empty((M, N), dtype=torch.float32, device=device) + compiled = kernel[(1, )](out, a, b, None, None, M, N, K, a_type, b_type, num_warps=4) + out_ref = torch.matmul(a_ref, b_ref) + torch.testing.assert_close(out, out_ref) - torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5) + assert 'v_mfma_scale_f32_16x16x128_f8f6f4' in compiled.asm['amdgcn'] def test_math_fast_expf(device): diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 77cb8707d8..904d418048 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -258,8 +258,8 @@ def test_tensor_memory(): %4 = arith.bitcast %c1_i32 : i32 to i32 %5 = ub.poison : i32 scf.for %arg0 = %2 to %3 step %4 : i32 { - %6 = ttg.memdesc_index %result_2[%arg0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> - %result_4 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 2x128x128> -> tensor<128x128xf32, #blocked> + %6 = ttg.memdesc_index %result_2[%arg0] : !ttg.memdesc<2x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %result_4 = ttng.tmem_load %6 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked> } tt.return } @@ -338,8 +338,8 @@ def test_shared_memory_index(target): %3 = arith.bitcast %c1_i32 : i32 to i32 %4 = ub.poison : i32 scf.for %arg0 = %1 to %2 step %3 : i32 { - %5 = ttg.memdesc_index %0[%arg0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> - %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable, 4x256> -> tensor<256xi32, #blocked> + %5 = ttg.memdesc_index %0[%arg0] : !ttg.memdesc<4x256xi32, #shared, #smem, mutable> -> !ttg.memdesc<256xi32, #shared, #smem, mutable> + %6 = ttg.local_load %5 : !ttg.memdesc<256xi32, #shared, #smem, mutable> -> tensor<256xi32, #blocked> } tt.return } @@ -408,15 +408,15 @@ def test_shared_memory_cast(target): tt.func public @shared_memory_cast_kernel() attributes {noinline = false} { %0 = ttg.local_alloc : () -> !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> %c0_i32 = arith.constant 0 : i32 - %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable, 2x256x128> - %2 = ttg.memdesc_trans %1 {order = array} : !ttg.memdesc<256x128xi8, #shared, #smem, 
mutable, 2x256x128> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256> - tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) -> () + %1 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<256x128xi8, #shared, #smem, mutable> + %2 = ttg.memdesc_trans %1 {order = array} : !ttg.memdesc<256x128xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x256xi8, #shared1, #smem, mutable> + tt.call @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%2) : (!ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) -> () %3 = ttg.local_alloc : () -> !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> %4 = ttg.memdesc_reshape %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared3, #smem, mutable> %5 = ttg.memdesc_reinterpret %3 : !ttg.memdesc<32x1x4x64xf16, #shared2, #smem, mutable> -> !ttg.memdesc<1024xi8, #shared4, #smem, mutable> tt.return } - tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[2, 128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable, 2x128x256>) attributes {noinline = true} { + tt.func private @"test_frontend.anchor_noinline__MDi8S128_256SLNVMMA_64_8_True_False_NVMMALAS[128, 256]ASMD__"(%arg0: !ttg.memdesc<128x256xi8, #shared1, #smem, mutable>) attributes {noinline = true} { tt.return } } @@ -466,17 +466,17 @@ def test_warp_specialize(): # CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32} # CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32} # CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} - # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array + # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]], [[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array # CHECK-NEXT: default { # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}cconstexpr_42{{.*}}([[A]], [[B]], [[C]]) # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2 # CHECK-NEXT: } - # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) { + # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>, %arg3: tensor<1xi32, [[BLOCKED]]>, %arg4: tensor<2xi32, [[BLOCKED]]>, %arg5: tensor<4xi32, [[BLOCKED]]>) num_warps(4) { # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2) # CHECK-NEXT: warp_return # CHECK-NEXT: } - # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) { - # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2) + # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>, %arg3: tensor<1xi32, [[BLOCKED]]>, %arg4: tensor<2xi32, [[BLOCKED]]>, %arg5: tensor<4xi32, [[BLOCKED]]>) num_warps(4) { + # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg3, %arg4, %arg5) # CHECK-NEXT: warp_return # CHECK-NEXT: } # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0) @@ -487,14 +487,20 @@ def test_warp_specialize(): c = 
ttgl.arange(0, 4, layout=layout) pair = Pair(a, b) e: ttgl.constexpr = 42 - a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e), - [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48]) + a, b = ttgl.warp_specialize([ + (warp_specialize_default, (pair, c, e)), + (warp_specialize_worker0, (pair, c, e)), + (warp_specialize_worker1, (pair, c, e)), + ], [4, 4], [24, 48]) anchor(a) anchor(b) # CHECK: ttg.warp_specialize([[A]], [[B]], [[C]]) # CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> () - ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48]) + ttgl.warp_specialize([ + (warp_specialize_worker0, (pair, c, e)), + (warp_specialize_worker1, (pair, c, e)), + ], [4], [48]) @gluon.jit @@ -535,7 +541,11 @@ def test_num_warps_caller_context(): # CHECK: func private @{{.*}}ws_test_worker1{{.*}}_NW1() attributes {noinline = false, "ttg.num-warps" = 1 : i32} # CHECK: func private @{{.*}}ws_body{{.*}}_NW1"() attributes {noinline = false, "ttg.num-warps" = 1 : i32} # CHECK: func private @{{.*}}anchor{{.*}}_NW1(%arg0: tensor<128xi32, [[BLOCKED_NW1]]>) attributes {noinline = false, "ttg.num-warps" = 1 : i32} - ttgl.warp_specialize((), ws_test_default, (), [ws_test_worker0, ws_test_worker1], [2, 1], [80, 80]) + ttgl.warp_specialize([ + (ws_test_default, ()), + (ws_test_worker0, ()), + (ws_test_worker1, ()), + ], [2, 1], [80, 80]) @gluon.jit @@ -920,7 +930,7 @@ def test_tmem_index_constexpr(): tt.func public @tmem_index_kernel() attributes {noinline = false} { %result = ttng.tmem_alloc : () -> !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> %c0_i32 = arith.constant 0 : i32 - %0 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable, 2x256x256> + %0 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<2x256x256xi32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<256x256xi32, #tmem, #ttng.tensor_memory, mutable> tt.return } } @@ -2198,11 +2208,10 @@ def buffer_load_store_kernel(x, y): ttgl.amd.cdna4.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs') -@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) -def test_buffer_load_store(target): +def test_buffer_load_store(): x = MockTensor(ttgl.float32) y = MockTensor(ttgl.float32) - module = run_parser(buffer_load_store_kernel, *make_args(x, y), target=target) + module = run_parser(buffer_load_store_kernel, *make_args(x, y), target=HIP_TARGET_CDNA3) expecttest.assert_expected_inline( anonymize_ir(module.str_nodebug()), """\ @@ -2247,11 +2256,10 @@ def buffer_load_store_with_broadcast_kernel(x, y): ttgl.amd.cdna3.buffer_store(stored_value=a, ptr=y, offsets=offsets, mask=mask, cache='.cs') -@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4]) -def test_buffer_load_store_with_broadcast(target): +def test_buffer_load_store_with_broadcast(): x = MockTensor(ttgl.float16) y = MockTensor(ttgl.float16) - module = run_parser(buffer_load_store_with_broadcast_kernel, *make_args(x, y), target=target) + module = run_parser(buffer_load_store_with_broadcast_kernel, *make_args(x, y), target=HIP_TARGET_CDNA3) expecttest.assert_expected_inline( anonymize_ir(module.str_nodebug()), """\ @@ -2407,16 +2415,15 @@ def test_amd_mfma_scaled(target): def kernel(): mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 
128], transposed=True, warps_per_cta=[1, 1]) - scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([], - [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], - [], [], [16, 4]) - - a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, - k_width=16)) - b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, - k_width=16)) - a_scale = ttgl.full([16, 4], 0x02, ttgl.uint8, scale_layout) - b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16) + a_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(a_layout, [16, 4]) + b_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(b_layout, [16, 4]) + + a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout) + b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout) + a_scale = ttgl.full([16, 4], 0x02, ttgl.uint8, a_scale_layout) + b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, b_scale_layout) acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout) ttgl.amd.cdna4.mfma_scaled(a, a_scale, 'e2m1', b, b_scale, 'e2m1', acc) @@ -2451,21 +2458,70 @@ def test_amd_mfma_scaled_none(target): @gluon.jit def kernel(): mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(4, [16, 16, 128], True, [1, 1]) - scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([], - [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], - [], [], [16, 4]) - a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, mfma_layout, 16)) b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, mfma_layout, 16)) - - b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout) acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout) - ttgl.amd.cdna4.mfma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc) + ttgl.amd.cdna4.mfma_scaled(a, None, 'e2m1', b, None, 'e2m1', acc) - with pytest.raises(CompilationError) as e: - run_parser(kernel, target=target) + module = run_parser(kernel, *make_args(num_warps=1), target=target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#linear = #ttg.linear<{register = [], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [], block = []}> +#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [16, 16, 128], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %c17_i8 = arith.constant 17 : i8 + %cst = arith.constant dense<17> : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %c34_i8 = arith.constant 34 : i8 + %cst_0 = arith.constant dense<34> : tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %cst_1 = arith.constant 0.000000e+00 : f32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c127_i8 = arith.constant 127 : i8 + %cst_3 = arith.constant dense<127> : tensor<16x4xi8, #linear> + %c127_i8_4 = arith.constant 127 : i8 + %cst_5 = arith.constant dense<127> : tensor<16x4xi8, #linear> + %cst_6 = arith.constant 0.000000e+00 : f32 + %0 = tt.dot_scaled %cst scale %cst_3, %cst_0 scale %cst_5, %cst_2 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> * 
tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> -> tensor<16x16xf32, #mma> + tt.return + } +} +""") - assert "Scales must not be None" in str(e.value) + +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA4]) +def test_amd_mfma_scaled_scalar(target): + + @gluon.jit + def kernel(): + mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(4, [16, 16, 128], True, [1, 1]) + a = ttgl.full([16, 64], 0x11, ttgl.uint8, ttgl.DotOperandLayout(0, mfma_layout, 16)) + b = ttgl.full([64, 16], 0x22, ttgl.uint8, ttgl.DotOperandLayout(1, mfma_layout, 16)) + acc = ttgl.full([16, 16], 0, ttgl.float32, mfma_layout) + ttgl.amd.cdna4.mfma_scaled(a, 0x02, 'e2m1', b, 0x01, 'e2m1', acc) + + module = run_parser(kernel, *make_args(num_warps=1), target=target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#linear = #ttg.linear<{register = [], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [], block = []}> +#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 1], instrShape = [16, 16, 128], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %c17_i8 = arith.constant 17 : i8 + %cst = arith.constant dense<17> : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %c34_i8 = arith.constant 34 : i8 + %cst_0 = arith.constant dense<34> : tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %cst_1 = arith.constant 0.000000e+00 : f32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + %c2_i8 = arith.constant 2 : i8 + %cst_3 = arith.constant dense<2> : tensor<16x4xi8, #linear> + %c1_i8 = arith.constant 1 : i8 + %cst_4 = arith.constant dense<1> : tensor<16x4xi8, #linear> + %cst_5 = arith.constant 0.000000e+00 : f32 + %0 = tt.dot_scaled %cst scale %cst_3, %cst_0 scale %cst_4, %cst_2 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> * tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> -> tensor<16x16xf32, #mma> + tt.return + } +} +""") @pytest.mark.parametrize("target", [HIP_TARGET_GFX1250]) @@ -2477,19 +2533,15 @@ def kernel(): instr_shape=[16, 16, 128]) wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(version=3, transposed=True, warps_per_cta=[2, 2], instr_shape=[16, 16, 64]) - a_scale_linear_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( - reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], - warp_bases=[[0, 0], [16, 0]], block_bases=[], shape=[32, 4]) - b_scale_linear_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( - reg_bases=[[0, 1], [0, 2]], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], - warp_bases=[[16, 0], [0, 0]], block_bases=[], shape=[32, 4]) - - a = ttgl.full([32, 64], 0x11, ttgl.uint8, - ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout_packed, k_width=16)) - b = ttgl.full([64, 32], 0x22, ttgl.uint8, - ttgl.DotOperandLayout(operand_index=1, parent=wmma_layout_packed, k_width=16)) - a_scale = ttgl.full([32, 4], 0x02, ttgl.uint8, a_scale_linear_layout) - b_scale = ttgl.full([32, 4], 0x01, ttgl.uint8, b_scale_linear_layout) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=wmma_layout_packed, k_width=16) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, 
parent=wmma_layout_packed, k_width=16) + a_scale_layout: ttgl.constexpr = ttgl.amd.gfx1250.get_wmma_scale_layout(a_layout, [32, 4]) + b_scale_layout: ttgl.constexpr = ttgl.amd.gfx1250.get_wmma_scale_layout(b_layout, [32, 4]) + + a = ttgl.full([32, 64], 0x11, ttgl.uint8, a_layout) + b = ttgl.full([64, 32], 0x22, ttgl.uint8, b_layout) + a_scale = ttgl.full([32, 4], 0x02, ttgl.uint8, a_scale_layout) + b_scale = ttgl.full([32, 4], 0x01, ttgl.uint8, b_scale_layout) acc = ttgl.full([32, 32], 0, ttgl.float32, wmma_layout) ttgl.amd.gfx1250.wmma_scaled(a, a_scale, 'e2m1', b, b_scale, 'e2m1', acc) @@ -2527,23 +2579,81 @@ def test_amd_wmma_scaled_none(target): def kernel(): wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128]) wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64]) - scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout([[0, 1], [0, 2]], - [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], [], [], - [16, 4]) a_layout: ttgl.constexpr = ttgl.DotOperandLayout(0, wmma_layout_packed, 16) b_layout: ttgl.constexpr = ttgl.DotOperandLayout(1, wmma_layout_packed, 16) a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout) b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout) - b_scale = ttgl.full([16, 4], 0x01, ttgl.uint8, scale_layout) acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout) - ttgl.amd.gfx1250.wmma_scaled(a, None, 'e2m1', b, b_scale, 'e2m1', acc) + ttgl.amd.gfx1250.wmma_scaled(a, None, 'e2m1', b, None, 'e2m1', acc) - with pytest.raises(CompilationError) as e: - run_parser(kernel, target=target) + module = run_parser(kernel, *make_args(num_warps=1), target=target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [1, 1], instrShape = [16, 16, 64]}> +#mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [1, 1], instrShape = [16, 16, 128]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %c17_i8 = arith.constant 17 : i8 + %cst = arith.constant dense<17> : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %c34_i8 = arith.constant 34 : i8 + %cst_0 = arith.constant dense<34> : tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %cst_1 = arith.constant 0.000000e+00 : f32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1> + %c127_i8 = arith.constant 127 : i8 + %cst_3 = arith.constant dense<127> : tensor<16x4xi8, #linear> + %c127_i8_4 = arith.constant 127 : i8 + %cst_5 = arith.constant dense<127> : tensor<16x4xi8, #linear> + %cst_6 = arith.constant 0.000000e+00 : f32 + %0 = tt.dot_scaled %cst scale %cst_3, %cst_0 scale %cst_5, %cst_2 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> * tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> -> tensor<16x16xf32, #mma1> + tt.return + } +} +""") - assert "Scales must not be None" in str(e.value) + +@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250]) +def test_amd_wmma_scaled_scalar(target): + + @gluon.jit + def kernel(): + wmma_layout: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 128]) + 
wmma_layout_packed: ttgl.constexpr = ttgl.amd.AMDWMMALayout(3, True, [1, 1], [16, 16, 64]) + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(0, wmma_layout_packed, 16) + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(1, wmma_layout_packed, 16) + + a = ttgl.full([16, 64], 0x11, ttgl.uint8, a_layout) + b = ttgl.full([64, 16], 0x22, ttgl.uint8, b_layout) + acc = ttgl.full([16, 16], 0, ttgl.float32, wmma_layout) + + ttgl.amd.gfx1250.wmma_scaled(a, 0x02, 'e2m1', b, 0x01, 'e2m1', acc) + + module = run_parser(kernel, *make_args(num_warps=1), target=target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [1, 1], instrShape = [16, 16, 64]}> +#mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [1, 1], instrShape = [16, 16, 128]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @kernel() attributes {noinline = false} { + %c17_i8 = arith.constant 17 : i8 + %cst = arith.constant dense<17> : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + %c34_i8 = arith.constant 34 : i8 + %cst_0 = arith.constant dense<34> : tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + %cst_1 = arith.constant 0.000000e+00 : f32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma1> + %c2_i8 = arith.constant 2 : i8 + %cst_3 = arith.constant dense<2> : tensor<16x4xi8, #linear> + %c1_i8 = arith.constant 1 : i8 + %cst_4 = arith.constant dense<1> : tensor<16x4xi8, #linear> + %cst_5 = arith.constant 0.000000e+00 : f32 + %0 = tt.dot_scaled %cst scale %cst_3, %cst_0 scale %cst_4, %cst_2 lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> * tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<16x4xi8, #linear> -> tensor<16x16xf32, #mma1> + tt.return + } +} +""") @gluon.jit @@ -2811,8 +2921,12 @@ def test_get_num_warps(): # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW8 # CHECK-NEXT arith.constant 8 : i32 print_num_warps() - ttgl.warp_specialize((), print_num_warps, (), [print_num_warps, print_num_warps, print_num_warps], [1, 2, 8], - [24, 24, 24]) + ttgl.warp_specialize([ + (print_num_warps, ()), + (print_num_warps, ()), + (print_num_warps, ()), + (print_num_warps, ()), + ], [1, 2, 8], [24, 24, 24]) def test_mismatch_shape_and_layout_rank(): @@ -2952,3 +3066,94 @@ def test_amd_tdm_store(target): } } """) + + +@pytest.mark.parametrize("target", [BLACKWELL_TARGET, HOPPER_TARGET]) +def test_nv_tma_descriptor_load_kernel(target): + + @gluon.jit + def nv_tma_descriptor_load_kernel(input_ptr): + XBLOCK: ttgl.constexpr = 128 + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2) + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8) + tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + 
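+    # Note: a kernel that actually consumes the tile would follow the copy above with a
+    # barrier wait before reading from `smem`. Illustrative sketch only (the exact
+    # `mbarrier.wait` and shared-memory load signatures are assumptions and are not
+    # exercised by this parser test):
+    #     mbarrier.wait(bar, phase=0)
+    #     tile = smem.load(some_distributed_layout)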
+ ptr = MockTensor(ttgl.float32) + module = run_parser(nv_tma_descriptor_load_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @nv_tma_descriptor_load_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c128_i32_0 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + %2 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + ttng.init_barrier %2, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %true = arith.constant true + ttng.barrier_expect %2, 65536, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c0_i32_1 = arith.constant 0 : i32 + %true_2 = arith.constant true + ttng.async_tma_copy_global_to_local %0[%c0_i32, %c0_i32_1] %1, %2, %true_2 : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [BLACKWELL_TARGET, HOPPER_TARGET]) +def test_nv_tma_descriptor_store_kernel(target): + + @gluon.jit + def nv_tma_descriptor_store_kernel(input_ptr): + XBLOCK: ttgl.constexpr = 128 + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2) + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) + tma.async_copy_shared_to_global(input_desc, [0, 0], smem) + tma.store_wait(0) + + ptr = MockTensor(ttgl.float32) + module = run_parser(nv_tma_descriptor_store_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @nv_tma_descriptor_store_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c128_i32_0 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c0_i32_1 = arith.constant 0 : i32 + ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + ttng.async_tma_store_wait {pendings = 0 : i32} + tt.return + } +} +""") diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a18c1116e1..32aea3b632 100644 --- 
a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6699,3 +6699,40 @@ def kernel(): tl.device_assert(tl.sum(x) == x.sum()) kernel[(1, )]() + + +@pytest.mark.interpreter +@pytest.mark.parametrize("rank", [2, 3, 4, 5, 6]) +@pytest.mark.parametrize("trans_a", [False, True]) +@pytest.mark.parametrize("trans_b", [False, True]) +def test_dot_multidim(rank, trans_a, trans_b, device): + + if is_interpreter(): + pytest.skip("bfloat16 is not supported in the interpreter") + + @triton.jit + def kernel(X, Y, Z, RANK: tl.constexpr, TRANS_A: tl.constexpr, TRANS_B: tl.constexpr): + x = tl.load(X + tl.arange(0, 256 << RANK)).reshape([2] * (RANK - 2) + [32, 32]) + y = tl.load(Y + tl.arange(0, 256 << RANK)).reshape([2] * (RANK - 2) + [32, 32]) + if TRANS_A: + x = tl.trans(x) + if TRANS_B: + y = tl.trans(y) + z = tl.dot(x, y) + tl.store(Z + tl.arange(0, 256 << RANK), z.reshape([256 << RANK])) + + shape = (2, ) * (rank - 2) + (32, 32) + + a = torch.randint(-4, 5, shape, dtype=torch.bfloat16, device=device) + b = torch.randint(-4, 5, shape, dtype=torch.bfloat16, device=device) + c = torch.empty(shape, dtype=torch.float32, device=device) + kernel[(1, )](a, b, c, rank, trans_a, trans_b) + + if trans_a: + a = torch.transpose(a, -1, -2) + if trans_b: + b = torch.transpose(b, -1, -2) + + d = a.to(torch.float32) @ b.to(torch.float32) + + assert torch.equal(c, d) diff --git a/python/test/unit/language/test_matmul.py b/python/test/unit/language/test_matmul.py index 02f7690f4e..05bec959e1 100644 --- a/python/test/unit/language/test_matmul.py +++ b/python/test/unit/language/test_matmul.py @@ -779,8 +779,11 @@ def generate_gemm_input(dim0, dim1, dtype): triton_out = triton_out.to(torch.float32) torch.testing.assert_close(torch_out, triton_out, atol=2e-5, rtol=1e-4) if is_hip() and preshuffle: - assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"] assert "ds_read_u8" not in k.asm["amdgcn"] + if mfma_nonkdim == 16: + assert "tilesPerWarp = [2, 2]" in k.asm["ttgir"] + elif mfma_nonkdim == 32: # default tilesPerWarp = [1, 1] + assert "tilesPerWarp" not in k.asm["ttgir"] @pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)]) diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 70448f05ca..6a58ede4f3 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -4,7 +4,7 @@ import triton import triton.language as tl -from triton._internal_testing import is_hopper, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy +from triton._internal_testing import is_hopper, is_sm12x, is_interpreter, numpy_random, to_triton, unwrap_tensor, tma_dtypes, to_numpy from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor from typing import Optional from triton._internal_testing import is_cuda, is_hip, is_hip_cdna3, is_xpu @@ -384,8 +384,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): @pytest.mark.interpreter def test_tensor_descriptor_padding(device): - if not is_cuda(): - pytest.xfail("padding is unsupported") + if is_xpu(): + pytest.skip("FIXME: issue #5400") @triton.jit def device_tma_load(in_ptr, out_ptr, IM, IN, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, @@ -1487,6 +1487,7 @@ def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl. 
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8]) @pytest.mark.parametrize("y", [0, 32, 48]) @pytest.mark.skipif(is_hopper(), reason="TMA Scatter is not supported on hopper") +@pytest.mark.skipif(is_sm12x(), reason="TMA Scatter is not supported on sm120") def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device): if BLOCK_X > X or y + BLOCK_Y > Y: pytest.xfail() diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index ba86218120..562e8df721 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -54,6 +54,10 @@ def is_hopper(): return is_cuda() and torch.cuda.get_device_capability()[0] == 9 +def is_sm12x(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 12 + + def is_hip(): target = get_current_target() return False if target is None else target.backend == "hip" diff --git a/python/triton/experimental/gluon/language/_core.py b/python/triton/experimental/gluon/language/_core.py index 35d87d8c13..89c0051a43 100644 --- a/python/triton/experimental/gluon/language/_core.py +++ b/python/triton/experimental/gluon/language/_core.py @@ -493,26 +493,28 @@ def set_auto_layout(value, layout, _semantic=None): @builtin -def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs, - _semantic=None, _generator=None): +def warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _semantic=None, _generator=None): """ Create a warp-specialized execution region, partitioning work across warps. + This forks the current execution into a "default partition" and an arbitrary number of + "worker partitons". The default partition is executed in the same :code:`num_warps` warps as + the parent region, and may accept tensor arguments and return tensors. Worker partitions are + executed in additional warps, which sit idle while executing the parent region. + + Note that calling warp_specialize recursively is not supported. + Args: - default_args (List[Any]): Arguments for the default region. - default_partition (callable): Function to build the default execution region. - worker_args (List[Any]): Arguments for each warp partition. - worker_partitions (List[callable]): Functions for each warp partition. - worker_num_warps (List[int]): Number of warps per partition. - worker_num_regs (List[int]): Number of registers per partition. + functions_and_args (List[Tuple[Callable, Any]]): List of functions and arguments for each partition. The first of which is the default partition. + worker_num_warps (List[int]): Number of warps used for each worker partition. + worker_num_regs (List[int]): Number of registers for each worker partition. Returns: - Tuple[Any, ...]: Results from the default region. + Tuple[Any, ...]: Results from the default partition. 
""" worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps] worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs] - return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, - worker_num_regs, _generator) + return _semantic.warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _generator) @builtin diff --git a/python/triton/experimental/gluon/language/_semantic.py b/python/triton/experimental/gluon/language/_semantic.py index 4ae2739b72..7e3579dc28 100644 --- a/python/triton/experimental/gluon/language/_semantic.py +++ b/python/triton/experimental/gluon/language/_semantic.py @@ -276,7 +276,7 @@ def memdesc_index(self, mem_desc, index): shape = mem_desc.shape[1:] index = self.to_tensor(index).handle layout = mem_desc.layout - ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, mem_desc.type.alloc_shape) + ty = ttgl.shared_memory_descriptor_type(mem_desc.dtype, shape, layout, shape) builder = self.builder handle = builder.create_memdesc_index(ty.to_ir(builder), mem_desc.handle, index) return ttgl.shared_memory_descriptor(handle, **ty.__dict__) @@ -419,13 +419,17 @@ def gather(self, src: TensorTy, index: TensorTy, axis: int) -> TensorTy: gather = self.builder.create_gather(src.handle, index.handle, axis) return self.wrap_tensor(gather, src.type.scalar, index.type.shape, index.type.layout) - def warp_specialize(self, default_args, default_partition, worker_args, worker_partitions, - worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], generator): - num_partitions = len(worker_partitions) - _check(isinstance(default_args, (tuple, ttgl.tuple)), - lambda: f"default_args must be a tuple of arguments, but got {type(default_args)}") - _check(isinstance(worker_args, (tuple, ttgl.tuple)), - lambda: f"worker_args must be a tuple of arguments, but got {type(worker_args)}") + def warp_specialize(self, functions_and_args, worker_num_warps: Sequence[int], worker_num_regs: Sequence[int], + generator): + for _, args in functions_and_args: + _check(isinstance(args, (tuple, ttgl.tuple)), + lambda: f"function arguments must be a tuple of arguments, but got {type(args)}") + + assert len(functions_and_args) >= 1, "expected at least one function for the default partition" + default_partition, default_args = functions_and_args[0] + num_partitions = len(functions_and_args) - 1 + workers = functions_and_args[1:] + assert num_partitions == len( worker_num_warps ), f"warp specialize got {num_partitions} partitions but {len(worker_num_warps)} warp counts" @@ -447,8 +451,9 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p result_types = [r.get_type() for r in mlir_results] # Create the warp specialize op. 
+ worker_args = [flatten_values_to_ir(args) for _, args in workers] + mlir_args = sum(worker_args, []) builder.restore_insertion_point(insert_pt) - mlir_args = flatten_values_to_ir(worker_args) ws_op = builder.create_warp_specialize(result_types, mlir_args, worker_num_warps) ws_op.get_default_region().push_back(default_block) ws_op.set_requested_registers(worker_num_regs) @@ -457,13 +462,16 @@ def warp_specialize(self, default_args, default_partition, worker_args, worker_p builder.create_block_with_parent(ws_op.get_partition_op_holder(), []) partitions_op = builder.create_warp_specialize_partitions(num_partitions) arg_types = [arg.get_type() for arg in mlir_args] - for i in range(num_partitions): + arg_it = 0 + for i, (func, args) in enumerate(workers): caller_context = GluonCallerContext(num_warps=worker_num_warps[i]) block = builder.create_block_with_parent(partitions_op.get_region(i), arg_types) - block_args = [block.get_argument(j) for j in range(len(mlir_args))] - block_args = unflatten_ir_values(block_args, [arg.type for arg in worker_args]) - generator.call_JitFunction(worker_partitions[i], block_args, kwargs={}, caller_context=caller_context) + mlir_args = worker_args[i] + block_args = [block.get_argument(arg_it + j) for j in range(len(mlir_args))] + block_args = unflatten_ir_values(block_args, [arg.type for arg in args]) + generator.call_JitFunction(func, block_args, kwargs={}, caller_context=caller_context) builder.create_warp_return() + arg_it += len(mlir_args) builder.set_insertion_point_after(ws_op.get_operation()) mlir_results = [ws_op.get_result(i) for i in range(len(result_types))] diff --git a/python/triton/experimental/gluon/language/amd/_ops.py b/python/triton/experimental/gluon/language/amd/_ops.py index ab23105772..56367f6095 100644 --- a/python/triton/experimental/gluon/language/amd/_ops.py +++ b/python/triton/experimental/gluon/language/amd/_ops.py @@ -2,6 +2,7 @@ from triton.experimental.gluon.language import _core as ttgl from triton.experimental.gluon.language._semantic import _check +from .._core import _unwrap_if_constexpr from .._layouts import DotOperandLayout from ._layouts import AMDWMMALayout @@ -34,3 +35,41 @@ def _wmma(version, a, b, acc, semantic): handle = semantic.dot(a, b, acc, input_precision=knobs.language.fp32_default, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle return ttgl.tensor(handle, acc.type) + + +def _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, scale_fn, semantic): + """ Shared implementation for AMD WMMA scaled and MFMA scaled operation. 
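+    Scales may be tensors, Python scalars, or None. A scalar or None scale is broadcast to
+    the expected scale shape (one scale per 32 elements along K, with operand B's scales
+    stored transposed); None defaults to 0x7F, i.e. a scale factor of 1.0 in the E8M0
+    encoding.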
""" + + def _get_scale_shape(op_idx, operand, format): + operand_shape = [s for s in operand.type.shape] + scale_shape = operand_shape + unpack_factor = 2 if format.value == "e2m1" else 1 + if op_idx == 0: + k = scale_shape[-1] * unpack_factor + scale_shape[-1] = k // 32 + else: + k = scale_shape[-2] * unpack_factor + scale_shape[-2] = k // 32 + scale_shape[-2], scale_shape[-1] = scale_shape[-1], scale_shape[-2] + return scale_shape + + def _create_and_broadcast_default_scale(op_idx, scale, format): + operand = a if op_idx == 0 else b + + scale_shape = _get_scale_shape(op_idx, operand, format) + scale_layout = scale_fn(operand.type.layout, scale_shape, semantic) + + if isinstance(scale, ttgl.tensor) and scale.numel.value != 1: + assert scale.type.shape == scale_shape, \ + f"Expect scale tensor to have shape {scale_shape}, but got {scale.type.shape}" + return scale + + scale_value = _unwrap_if_constexpr(scale) + scale_value = 0x7F if scale_value is None else scale_value + return semantic.full(scale_shape, scale_value, ttgl.uint8, scale_layout) + + a_scale = _create_and_broadcast_default_scale(0, a_scale, a_format) + b_scale = _create_and_broadcast_default_scale(1, b_scale, b_format) + output = semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, fast_math=False, lhs_k_pack=True, + rhs_k_pack=True, out_dtype=ttgl.float32) + return ttgl.tensor(output.handle, acc.type) diff --git a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py index 021b770a53..78c0b88802 100644 --- a/python/triton/experimental/gluon/language/amd/cdna4/__init__.py +++ b/python/triton/experimental/gluon/language/amd/cdna4/__init__.py @@ -1,13 +1,26 @@ -from triton.experimental.gluon.language import _core as ttgl -from ..._core import builtin, float32, _unwrap_if_constexpr +from ..._core import builtin, _unwrap_if_constexpr from ..._layouts import DotOperandLayout from .._layouts import AMDMFMALayout +from .._ops import _mma_scaled from ..cdna3 import _buffer_atomic_rmw_impl from ..cdna3 import * # NOQA: F403 from ..cdna3 import __all__ as __cdna3_all from . import async_copy -__all__ = [*__cdna3_all, "async_copy", "mfma_scaled"] +__all__ = [*__cdna3_all, "async_copy", "mfma_scaled", "get_mfma_scale_layout"] + + +def _get_mfma_scale_layout(dot_operand_layout, shape, semantic): + dot_operand_layout = _unwrap_if_constexpr(dot_operand_layout) + shape = _unwrap_if_constexpr(shape) + + op_idx = dot_operand_layout.operand_index + parent = dot_operand_layout.parent + assert isinstance(parent, AMDMFMALayout), "Expected parent to be an instance of AMDMFMALayout" + mdim = parent.instr_shape[0] + tiles_per_warp = parent.tiles_per_warp + warps_per_cta = parent.warps_per_cta + return semantic.builder.get_amd_mfma_scale_layout(op_idx, shape, mdim, tiles_per_warp, warps_per_cta) @builtin @@ -26,11 +39,11 @@ def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) Args: a (tensor): The operand A to be multiplied. - a_scale (tensor): Scale factor for operand A. + a_scale (Optional[tensor]): Scale factor for operand A. a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`. b (tensor): The operand B to be multiplied. - b_scale (tensor): Scale factor for operand B. Available formats: `e2m1`, `e4m3`, `e5m2`. - b_format (str): Format of the operand B. + b_scale (Optional[tensor]): Scale factor for operand B. + b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`. 
acc (tensor): Accumulator tensor. """ layout = acc.type.layout @@ -43,14 +56,21 @@ def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}" assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}" - a_scale = _unwrap_if_constexpr(a_scale) - b_scale = _unwrap_if_constexpr(b_scale) - assert a_scale is not None and b_scale is not None, "Scales must not be None" + return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _get_mfma_scale_layout, _semantic) + - tensor = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, False, True, True, float32) +@builtin +def get_mfma_scale_layout(dot_operand_layout, shape, _semantic=None): + """ Get the scale layout for MFMA scaled operands. + + Args: + dot_operand_layout (DotOperandLayout): The dot operand layout. + shape (List[int]): The shape of the scale tensor. - ret_ty = ttgl.distributed_type(tensor.dtype, tensor.shape, layout) - return ttgl.tensor(tensor.handle, ret_ty) + Return: + layout (DistributedLinearLayout): The scale layout. + """ + return _get_mfma_scale_layout(dot_operand_layout, shape, _semantic) """ diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py index f247a76729..2492eefb78 100644 --- a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py +++ b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py @@ -1,10 +1,21 @@ from ..._core import builtin, _unwrap_if_constexpr -from .._ops import _wmma, _verify_wmma -from triton.experimental.gluon.language import _core as ttgl +from .._ops import _wmma, _verify_wmma, _mma_scaled from .._layouts import AMDWMMALayout +from ..cdna3 import buffer_load, buffer_store from . import tdm -__all__ = ["tdm", "wmma", "wmma_scaled"] +__all__ = ["tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"] + + +def _get_wmma_scale_layout(dot_operand_layout, shape, semantic): + dot_operand_layout = _unwrap_if_constexpr(dot_operand_layout) + shape = _unwrap_if_constexpr(shape) + + op_idx = dot_operand_layout.operand_index + parent = dot_operand_layout.parent + assert isinstance(parent, AMDWMMALayout), "Expected parent to be an instance of AMDWMMALayout" + warps_per_cta = parent.warps_per_cta + return semantic.builder.get_amd_wmma_scale_layout(op_idx, shape, warps_per_cta) @builtin @@ -35,10 +46,10 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) Args: a (tensor): The operand A to be multiplied. - a_scale (tensor): Scale factor for operand A. + a_scale (Optional[tensor]): Scale factor for operand A. a_format (str): Format of the operand A. Available formats: `e2m1`, `e4m3`, `e5m2`. b (tensor): The operand B to be multiplied. - b_scale (tensor): Scale factor for operand B. + b_scale (Optional[tensor]): Scale factor for operand B. b_format (str): Format of the operand B. Available formats: `e2m1`, `e4m3`, `e5m2`. acc (tensor): Accumulator tensor. 
""" @@ -59,10 +70,18 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None) assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}" assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}" - a_scale = _unwrap_if_constexpr(a_scale) - b_scale = _unwrap_if_constexpr(b_scale) - assert a_scale is not None and b_scale is not None, "Scales must not be None" + return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _get_wmma_scale_layout, _semantic) - handle = _semantic.dot_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, fast_math=False, lhs_k_pack=True, - rhs_k_pack=True, out_dtype=acc.dtype).handle - return ttgl.tensor(handle, acc.type) + +@builtin +def get_wmma_scale_layout(dot_operand_layout, shape, _semantic=None): + """ Get the scale layout for WMMA scaled operands. + + Args: + dot_operand_layout (DotOperandLayout): The dot operand layout. + shape (List[int]): The shape of the scale tensor. + + Return: + layout (DistributedLinearLayout): The scale layout. + """ + return _get_wmma_scale_layout(dot_operand_layout, shape, _semantic) diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py index b3244d5a6b..2636d1d72e 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py @@ -334,7 +334,7 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip builder = _semantic.builder shape = self.shape[1:] layout = self.layout - ret = tensor_memory_descriptor(None, self.dtype, shape, layout, self.type.alloc_shape) + ret = tensor_memory_descriptor(None, self.dtype, shape, layout, shape) ret.handle = builder.create_memdesc_index(ret.type.to_ir(builder), self.handle, index.handle) return ret diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py b/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py index 87d4bf8197..c06b103f36 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/float2.py @@ -72,6 +72,7 @@ def _fma_f32x2(a, b, c): class Float2Tensor: value: ttgl.tensor + @constexpr_function def __init__(self, value: ttgl.tensor): self.value = value diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 60a0724c55..717331e53c 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -5,6 +5,7 @@ store_wait, tensor_descriptor, tensor_descriptor_type, + make_tensor_descriptor, ) __all__ = [ @@ -15,6 +16,7 @@ "store_wait", "tensor_descriptor", "tensor_descriptor_type", + "make_tensor_descriptor", ] diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index ea26e20bd2..a7d3cd4e50 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -22,6 +22,14 @@ class tensor_descriptor_type(base_type): def __str__(self) -> str: return f"tensor_descriptor<{self.block_type}, {self.layout}>" + def _to_ir(self, builder: ir.builder) -> ir.type: + is_signed = 
self.block_type.element_ty.is_int_signed() + return builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: handle = handles[cursor] cursor += 1 @@ -95,3 +103,66 @@ def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): def store_wait(pendings, _semantic=None): pendings = _unwrap_if_constexpr(pendings) _semantic.builder.create_async_tma_store_wait(pendings) + + +@builtin +def make_tensor_descriptor( + base: ttgl.tensor, + shape: List[ttgl.tensor], + strides: List[ttgl.tensor], + block_shape: List[ttgl.constexpr], + layout: NVMMASharedLayout, + padding_option="zero", + _semantic=None, +) -> tensor_descriptor: + padding_option = _unwrap_if_constexpr(padding_option) + + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + assert isinstance(base.dtype, ttgl.pointer_type) + elem_size = base.dtype.element_ty.primitive_bitwidth // 8 + contig_dim_size = ttgl._unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + last_stride = ttgl._unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + shape = [_semantic.make_scalar(x, ttgl.int32) for x in shape] + strides = [_semantic.make_scalar(ttgl._unwrap_if_constexpr(x), ttgl.int64) for x in strides] + + # Check whether `block_shape` is static + block_shape = ttgl._unwrap_shape(block_shape) + + assert isinstance(base.type, ttgl.pointer_type) + block_type = ttgl.block_type(base.type.element_ty, block_shape) + base_handle = base.handle + + padding = _semantic._str_to_padding_option(padding_option) + + layout = _unwrap_if_constexpr(layout) + assert isinstance(layout, NVMMASharedLayout), \ + "Expected layout to be a NVMMASharedLayout" + + shape_type = ttgl.tuple(shape).type + strides_type = ttgl.tuple(strides).type + ty = tensor_descriptor_type(block_type, shape_type, strides_type, layout) + + if base.type.element_ty.is_int() and padding == ttgl.ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer blocks") + handle = _semantic.builder.create_make_tensor_descriptor( + ty._to_ir(_semantic.builder), + base_handle, + [s.handle for s in shape], + [s.handle for s in strides], + padding, + ) + return tensor_descriptor(handle, shape, strides, block_type, layout) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 8768ca028b..2dfa88b2ba 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1544,13 +1544,15 @@ def _get_instance(this_cls): def __new__(this_cls, *args, _semantic=None, _generator=None, **kwargs): # Call into the user-defined constructor. 
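+        # A @triton.jit __init__ is no longer rejected here; it is simply called without
+        # the _semantic/_generator keywords that are forwarded to plain Python
+        # constructors below.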
instance = this_cls._get_instance() - if isinstance(cls.__init__, JITCallable): - raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function") extra_kwargs = {} - if "_semantic" in inspect.signature(cls.__init__).parameters: - extra_kwargs["_semantic"] = _semantic - if "_generator" in inspect.signature(cls.__init__).parameters: - extra_kwargs["_generator"] = _generator + if isinstance(cls.__init__, JITCallable): + # raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function") + pass + else: + if "_semantic" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_semantic"] = _semantic + if "_generator" in inspect.signature(cls.__init__).parameters: + extra_kwargs["_generator"] = _generator cls.__init__(instance, *args, **extra_kwargs, **kwargs) # Require that the user-defined constructor initialized all fields. @@ -1577,11 +1579,15 @@ def type(self): return _aggregate_type(aggregate_value, [(name, getattr(self, name).type) for name in cls.__annotations__.keys()]) + hash_attrs = [cls.__init__] + for (name, member) in inspect.getmembers(cls): if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITCallable): if name != "__init__": setattr(aggregate_value, name, member) + hash_attrs.append(member) + aggregate_value.hash_attrs = hash_attrs aggregate_value.__name__ = cls.__name__ aggregate_value.__module__ = cls.__module__ aggregate_value.__qualname__ = cls.__qualname__ @@ -1725,8 +1731,9 @@ def trans(input: tensor, *dims, _semantic=None): """ Permutes the dimensions of a tensor. - If the parameter :code:`dims` is not specified, the function defaults to a (1,0) permutation, - effectively transposing a 2D tensor. + If the parameter :code:`dims` is not specified, the function defaults to + swapping the last two axes, thereby performing an (optionally batched) + 2D transpose. :param input: The input tensor. :param dims: The desired ordering of dimensions. For example, @@ -1743,7 +1750,10 @@ def trans(input: tensor, *dims, _semantic=None): """ dims = _unwrap_iterable(dims) if not dims: - dims = (1, 0) + n = len(input.shape) + if n < 2: + raise ValueError("tl.trans invoked with a 0- or 1-dimensional tensor") + dims = list(builtins.range(n - 2)) + [n - 1, n - 2] return _semantic.permute(input, dims) @@ -1765,7 +1775,7 @@ def permute(input, *dims, _semantic=None): permute(x, 2, 1, 0) :py:func:`trans` is equivalent to this function, except when - :code:`dims` is empty, it tries to do a (1,0) permutation. + :code:`dims` is empty, it tries to swap the last two axes. 
""" dims = _unwrap_iterable(dims) return _semantic.permute(input, dims) @@ -2018,7 +2028,36 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i out_dtype = _unwrap_if_constexpr(out_dtype) max_num_imprecise_acc = _unwrap_if_constexpr(max_num_imprecise_acc) acc = _unwrap_if_constexpr(acc) - return _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype) + + # check shapes make sense: + a_shape = list(input.shape) + b_shape = list(other.shape) + assert len(a_shape) == len(b_shape) >= 2, "input and other must have equal ranks >= 2" + assert a_shape[:-2] == b_shape[:-2], "input and other must have equal batch shapes" + assert a_shape[-1] == b_shape[-2], "input and other must have equal reduction dimensions" + + # compute shape of accumulator: + c_shape = a_shape[:-1] + [b_shape[-1]] + if acc is not None: + assert list(acc.shape) == c_shape, "accumulator shape is incompatible" + rank = len(c_shape) + + if rank >= 4: + batch_size = 1 + for i in builtins.range(rank - 2): + batch_size *= c_shape[i] + input = _semantic.reshape(input, [batch_size] + a_shape[-2:], can_reorder=False) + other = _semantic.reshape(other, [batch_size] + b_shape[-2:], can_reorder=False) + if acc is not None: + acc = _semantic.reshape(acc, [batch_size] + c_shape[-2:], can_reorder=False) + + res = _semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype) + + if rank >= 4: + res = _semantic.reshape(res, c_shape, can_reorder=False) + + assert list(res.shape) == c_shape, "output shape is unexpected" + return res @builtin diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index e12509f4f2..0c4d710496 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -36,7 +36,7 @@ def __init__(self, fn, arg_names, configs, key, reset_to_zero, restore_value, pr self.keys = key self.cache: Dict[Tuple, Config] = {} self.arg_names = arg_names - self.cache_results = cache_results or (knobs.autotuning.cache and not knobs.runtime.interpret) + self.cache_results = (cache_results or knobs.autotuning.cache) and not knobs.runtime.interpret # Reset to zero or restore values self.reset_to_zero = [] diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index a0b5d43c69..d5be6dc864 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -122,6 +122,11 @@ def record_reference(self, val, var_dict=None, name=None): if val is None or type(val) is ModuleType: return + if getattr(val, "__triton_aggregate__", False): + for attr in val.hash_attrs: + self.record_reference(attr) + return + if getattr(val, "__triton_builtin__", False): return @@ -735,7 +740,6 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o # TODO(jlebar): Remove uses of these fields outside this file, then # remove the fields here. 
self.arg_names = [p.name for p in self.params] - self.constexprs = [p.num for p in self.params if p.is_constexpr] # Hooks that will be called prior to executing "run" self.pre_run_hooks = [] diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py index a640e07d71..11c50716ad 100644 --- a/python/triton_kernels/bench/bench_mlp.py +++ b/python/triton_kernels/bench/bench_mlp.py @@ -69,6 +69,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x_dtype = torch.float8_e4m3fnuz input_x = torch.randn((batch // DP, dim1), device=dev) + expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev)) # run layer fpath = Path(tempfile.mktemp()) proton.start(str(fpath), hook="triton") @@ -78,7 +79,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d if n_expts_tot > 1: # sparse logits = matmul_ogs(xg, wg, bg, precision_config=pcg) x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP, - TP=TP) + TP=TP, expt_assignment=expt_assignment) else: # dense x = triton_dist.all_gather(input_x, dim=0) rdata, gather_indx, scatter_indx, metadata = None, None, None, None @@ -86,7 +87,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act) x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2) - x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0) + x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) proton.finalize() return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*") @@ -136,6 +137,8 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d parser.add_argument("--name", type=str, choices=["dense", "gpt-oss-x2"]) parser.add_argument("--quantized", action="store_true", default=False) args = parser.parse_args() + if args.tp > 1: + raise NotImplementedError("TP>1 is not supported yet in distributed mode.") dtypes = quantized_dtypes if args.quantized else dense_dtypes if args.name == "dense": assert args.ep == 1, "EP must be 1 for dense" diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 75099b2949..924952a83b 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -5,56 +5,41 @@ import torch.multiprocessing as mp from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Optional -import triton -import triton.language as tl import triton_kernels import triton_kernels.swiglu +from triton_kernels.reduce import reduce from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx -from triton_kernels.topk import topk_torch from triton_kernels.topk import topk from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq from triton_kernels.tensor_details import layout -from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata +from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata +from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, 
convert_dp_to_ep, convert_ep_to_dp, ExptAssignment from bench_utils import quantize_weight -def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act): - sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix) - dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx - combine_indx = sparse_logits.mask_metadata.row_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, - ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx - - -def legacy_routing(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None): - if sm_first: - logits = torch.softmax(logits, dim=-1) - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) - return legacy_routing_from_bitmatrix(sparse_logits.mask, sparse_logits.vals, sparse_logits.indx, logits.shape[-1], - n_expts_act) - - @dataclass class ReduceScatterMetadata: - input_split_sizes: list[int] - ep_indx: torch.Tensor - EP: int = 1 - TP: int = 1 + mode: str + active_indx: Optional[torch.Tensor] = None + dispatch_indx: Optional[torch.Tensor] = None + combine_indx: Optional[torch.Tensor] = None def _is_distributed_launch() -> bool: return int(os.environ.get("WORLD_SIZE", "1")) > 1 +def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> Optional[ExptAssignment]: + if not _is_distributed_launch(): + return None + expt_dict = make_expt_dict_uniform(EP, n_expts_tot) + return make_expt_assignment(EP, n_expts_tot, expt_dict, device) + + def setup() -> Tuple[int, int]: if _is_distributed_launch(): world_size = int(os.environ["WORLD_SIZE"]) @@ -102,482 +87,79 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor: def reduce_scatter( input_tensor: torch.Tensor, - metadata: ReduceScatterMetadata = None, + n_expts_act: int, + metadata: ReduceScatterMetadata, + expt_assignment: Optional[ExptAssignment] = None, dim: int = 0, op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, ) -> torch.Tensor: if _is_distributed_launch(): - - def dtype_cast(dtype: torch.dtype) -> torch.dtype: - # check if dtype is fp8, then convert it to float16 before reducing - if dtype in [torch.float16, torch.bfloat16, torch.float32]: - return dtype - else: - return torch.float16 - - world_size = dist.get_world_size() - original_dtype = input_tensor.dtype - intermediate_dtype = dtype_cast(original_dtype) - if metadata and metadata.input_split_sizes: - assert dim == 0, "metadata only works with dim=0" - input_list = list(input_tensor.split(metadata.input_split_sizes, dim=0)) - output_list = all_to_all(input_list, dim=0) - n_tokens = metadata.ep_indx.size(dim) - other_dims = input_tensor.shape[1:] - output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype) - for i in range(world_size): - ep_rank = i // metadata.TP - mask = torch.any(metadata.ep_indx == ep_rank, dim=1) - if op == dist.ReduceOp.SUM: - output_tensor[mask] += output_list[i].to(intermediate_dtype) - else: - raise NotImplementedError(f"Reduce operation {op} is not implemented.") - return output_tensor.to(original_dtype) + if metadata.mode and metadata.mode == "ep_sharding": + if dim != 0 or op != dist.ReduceOp.SUM: + raise 
NotImplementedError("Only dim=0 and op=SUM are supported for MoE reduce_scatter.") + output = convert_ep_to_dp(input_tensor, expt_assignment, metadata.active_indx, metadata.combine_indx) + # weighted average of the output token from experts + output = output.view(-1, n_expts_act, output.shape[-1]) + output, _ = reduce(output, dim=1) + return output else: - input_list = list(input_tensor.chunk(world_size, dim=dim)) - shape = input_list[0].shape - input_list = [x.to(intermediate_dtype) for x in input_list] - output_tensor = input_tensor.new_empty(shape, dtype=intermediate_dtype) - dist.reduce_scatter(output_tensor, input_list, op=op) - return output_tensor.to(original_dtype) + raise NotImplementedError(f"Distributed reduce_scatter mode {metadata.mode} is not implemented yet.") else: return input_tensor -def all_to_all(input_list: list[torch.Tensor], dim: int = 0) -> list[torch.Tensor]: - if _is_distributed_launch(): - # Check if all tensors have only one dimension with different sizes - for t in input_list: - for d in range(t.dim()): - if d != dim and t.size(d) != input_list[0].size(d): - raise ValueError("All tensors must have the same size in all dimensions except the specified one.") - input_sizes = [t.size(dim) for t in input_list] - input_sizes = torch.tensor(input_sizes, device=input_list[0].device).unsqueeze(0) - input_sizes = all_gather(input_sizes, dim=0) - output_split_sizes = input_sizes[:, dist.get_rank()].tolist() - other_dims = list(input_list[0].shape[:dim] + input_list[0].shape[dim + 1:]) - output_list = [ - torch.empty([size] + other_dims, dtype=input_list[0].dtype, device=input_list[0].device) - for size in output_split_sizes - ] - dist.all_to_all(output_list, input_list) - return output_list - else: - return input_list - - -def _apply_parallelism( - expt_scal: torch.Tensor, - expt_indx: torch.Tensor, - x: torch.Tensor, - chunk_size: int, - EP: int = 1, - TP: int = 1, -): - if EP > 1: - # Distributed Expert Parallelism - ep_indx = expt_indx // chunk_size - - # Partition tokens by expert parallelism rank - expt_scal_list = [] - expt_indx_list = [] - x_list = [] - - for i in range(EP): - mask = torch.any(ep_indx == i, dim=1) - expt_scal_masked = expt_scal[mask] - expt_indx_masked = expt_indx[mask] - x_masked = x[mask] - - for _ in range(TP): - expt_scal_list.append(expt_scal_masked) - expt_indx_list.append(expt_indx_masked) - x_list.append(x_masked) - - # Exchange data across processes - expt_scal_list = all_to_all(expt_scal_list, dim=0) - expt_indx_list = all_to_all(expt_indx_list, dim=0) - x_list = all_to_all(x_list, dim=0) - - output_split_sizes = [x.size(0) for x in expt_scal_list] - expt_scal = torch.cat(expt_scal_list, dim=0) - expt_indx = torch.cat(expt_indx_list, dim=0) - x = torch.cat(x_list, dim=0) - - # Filter for local experts only - ep_rank = dist.get_rank() // TP - mask = (expt_indx // chunk_size) == ep_rank - expt_indx -= ep_rank * chunk_size - expt_scal = expt_scal.masked_fill(~mask, 0) - expt_indx = expt_indx.masked_fill(~mask, chunk_size) - else: - # Distributed Data Parallelism - ep_indx = None - output_split_sizes = None - x = all_gather(x, dim=0) - expt_scal = all_gather(expt_scal, dim=0) - expt_indx = all_gather(expt_indx, dim=0) - - return expt_scal, expt_indx, ep_indx, x, output_split_sizes - - -def routing_torch(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1): - _, n_expts_tot = logits.shape - - if n_rows: - logits = logits[:n_rows] - if sm_first: - logits = torch.softmax(logits, dim=-1) - - expt_scal, expt_indx = 
topk_torch(logits, n_expts_act, expt_indx, has_user_provided_indx=expt_indx is not None) - expt_indx = expt_indx.int() - if not sm_first: - expt_scal = torch.softmax(expt_scal, dim=-1) - - # Sort each token's selections by expert - expt_indx, sort_indices = torch.sort(expt_indx, dim=1, stable=True) - expt_scal = torch.gather(expt_scal, 1, sort_indices) - - chunk_size = n_expts_tot // EP - - expt_scal, expt_indx, ep_indx, x, output_split_sizes = _apply_parallelism(expt_scal, expt_indx, x, chunk_size, - EP=EP, TP=TP) - - # Flatten topk data - expt_scal = expt_scal.reshape(-1) - expt_indx = expt_indx.reshape(-1).to(torch.int32) - - # Sort by expert_id for contiguous experts in matmul - expt_indx, topk_indx = torch.sort(expt_indx, stable=True) - gate_indx = torch.argsort(topk_indx, stable=True) - - mask = expt_indx != chunk_size - topk_indx[~mask] = -1 - gate_indx[gate_indx >= mask.sum()] = -1 - gate_scal = expt_scal[topk_indx] - hist = torch.histc(expt_indx[mask], bins=chunk_size, min=0, max=chunk_size - 1) - - # Pack the matmul data structures - gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) - scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) - n_gates = mask.sum().item() - expt_data = make_ragged_tensor_metadata(hist, n_gates) - - return ( - x, - RoutingData(gate_scal, hist, chunk_size, n_expts_act, expt_data=expt_data), - gather_indx, - scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP), - ) - - -@triton.jit -def pack_bitmatrix( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - sentinel: tl.constexpr, -): - """ - Packs expt_indx into a bitmatrix. - """ - pid_m = tl.program_id(0) - offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offsets_k = tl.arange(0, BLOCK_SIZE_K) - offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] - mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] - indices = tl.load(expt_indx + offsets, mask=mask, other=sentinel) - div = indices // 32 - rem = indices % 32 - iters = tl.cdiv(sentinel, BLOCK_SIZE_K) - for i in range(iters): - offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) - x = tl.where(div[:, :, None] == offs[None, None, :], (1 << rem)[:, :, None], 0) - y = tl.reduce_or(x, axis=1) - bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * n_cols + offs[None, :] - tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) - - -@triton.jit -def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr): - pid_m = tl.program_id(0) - cutoff_word = cutoff // 32 - cutoff_bit = cutoff % 32 - cutoff_mask = (1 << (cutoff_bit)) - 1 - for start_n in range(0, shape_bn, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn) - values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values) - values = tl.where(offs_n > cutoff_word, 0, values) - tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn) - - -class PruneRouting(torch.autograd.Function): - - @staticmethod - def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): - from triton_kernels.compaction import compaction - n_tokens_pad = expt_scal.shape[0] - assert n_expts_tot % simulated_ep == 0 - _routing_clear_bitmatrix[(n_tokens_pad, )]( - bitmatrix.storage.data, - 
bitmatrix.storage.data.stride(0), - bitmatrix.storage.data.stride(1), - bitmatrix.storage.data.shape[1], - n_expts_tot // simulated_ep, - BLOCK_N=512, - ) - # perform compaction to update expt_scal / expt_indx - expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) - n_expts_tot = n_expts_tot // simulated_ep - bitmatrix.shape[-1] = n_expts_tot - return expt_scal, expt_indx, bitmatrix - - -def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): - return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep) - - -def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1): - _, n_expts_tot = logits.shape - - if sm_first: - logits = torch.softmax(logits, dim=-1) - - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) - expt_scal = sparse_logits.vals - expt_indx = sparse_logits.indx - expt_indx = expt_indx.int() - - chunk_size = n_expts_tot // EP - - expt_scal, expt_indx, ep_indx, x, output_split_sizes = _apply_parallelism(expt_scal, expt_indx, x, chunk_size, - EP=EP, TP=TP) - - # TODO: Skip all the following if `EP == 1` - # Recover bitmatrix for local experts - BLOCK_SIZE_M = 512 - BLOCK_SIZE_K = 32 - # The sentinel value is chunk_size + 1 instead of chunk_size to ensure the bitmatrix buffer - # doesn't overflow. For example, cdiv(32, 32) is 1, while the 32th bit is on the second column. - sentinel = chunk_size + 1 - n_cols = triton.cdiv(sentinel, BLOCK_SIZE_K) - n_rows = expt_indx.size(0) - bitmatrix = torch.zeros((n_rows, n_cols), dtype=torch.uint32, device=expt_indx.device) - grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) - - pack_bitmatrix[grid]( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_K=BLOCK_SIZE_K, - sentinel=sentinel, - ) - bitmatrix_shape = [n_rows, triton.cdiv(chunk_size, BLOCK_SIZE_K) * 32] - bitmatrix_shape_max = [n_rows, None] - bitmatrix = Bitmatrix(bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max) - expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, EP) - routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, - n_expts_tot // EP, n_expts_act) - - return ( - x, - routing_data, - gather_indx, - scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP), - ) - - -def routing(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1, - backend="triton") -> Tuple[RoutingData, GatherIndx, ScatterIndx, ReduceScatterMetadata]: +# TODO: support TP > 1 +# TODO: clean up duplicate code with triton_kernels.test_distributed.py +# TODO: Support nonuniform expert assignment +def routing( + x, logits, n_expts_act, sm_first: bool = False, y_indx: Optional[torch.Tensor] = None, EP: int = 1, TP: int = 1, + expt_assignment: Optional[ExptAssignment] = None, mode: str = "ep_sharding" +) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]: + n_expts_tot = logits.shape[-1] if _is_distributed_launch(): - assert backend in ["torch", "triton"], "backend must be either 'torch' or 'triton'" - if backend == "torch": - return routing_torch(x, logits, n_expts_act, sm_first, expt_indx, n_rows, EP, TP) - elif backend == "triton": - return routing_triton(x, logits, n_expts_act, sm_first, expt_indx, n_rows, EP, TP) + if mode == "ep_sharding": + if not expt_assignment: + raise 
ValueError("expt_assignment must be provided for distributed routing.") + if TP > 1: + raise NotImplementedError("TP > 1 is not supported in distributed MoE benchmark yet.") + rank = dist.get_rank() + expt_map = expt_assignment.expt_map[rank, :] + logits_global = topk( + logits, + n_expts_act, + apply_softmax=not sm_first, + y_indx=y_indx, + all_gather=True, + ) + active_indx = logits_global.indx + expt_sizes = logits_global.mask_metadata.col_sum + dispatch_indx = logits_global.mask_metadata.col_sorted_indx + combine_indx = logits_global.mask_metadata.row_sorted_indx + logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0]) + x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx) + logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map) + gate_scal = logits_global.vals.flatten()[combine_indx] + rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata) + reduce_scatter_metadata = ReduceScatterMetadata( + mode=mode, + active_indx=active_indx, + dispatch_indx=dispatch_indx, + combine_indx=combine_indx, + ) + return x, rdata, None, None, reduce_scatter_metadata else: - raise ValueError(f"Unknown backend: {backend}") + raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.") else: - return x, *legacy_routing(logits, n_expts_act, sm_first, expt_indx, n_rows), None - - -# The following dummy methods simulate the behavior of distributed operations -# in a non-distributed environment for testing purposes. -# Assuming each rank has the same data for simplicity. - - -def dummy_all_gather(out, x): - out[0].copy_(x) - out[1].copy_(x) - - -def dummy_all_to_all(output_list, input_list): - output_list[0].copy_(input_list[0]) - output_list[1].copy_(input_list[0]) - - -def dummy_reduce_scatter(out, x_list, op): - out.copy_(x_list[0] * 2) - - -def test_all_gather_non_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "1") - x = torch.randn(4, 5) - result = all_gather(x, dim=0) - torch.testing.assert_close(result, x) - - -@pytest.mark.parametrize("dim", [0, 1]) -def test_all_gather_distributed(monkeypatch, dim): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - - x = torch.randn(4, 4) - result = all_gather(x, dim=dim) - expected = torch.cat([x, x], dim=dim) - torch.testing.assert_close(result, expected) - - -def test_reduce_scatter_non_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "1") - x = torch.randn(4, 6) - result = reduce_scatter(x, dim=0) - torch.testing.assert_close(result, x) - - -def test_reduce_scatter_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "reduce_scatter", dummy_reduce_scatter) - - x = torch.randn(4, 6) - expected = x.chunk(2, dim=0)[0] * 2 - - result = reduce_scatter(x, dim=0) - torch.testing.assert_close(result, expected) - - -def test_reduce_scatter_distributed_with_metadata(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - monkeypatch.setattr(dist, 
"all_gather", dummy_all_gather) - - input_split_sizes = [1, 1] - ep_indx = torch.tensor([[0], [1]]) - metadata = ReduceScatterMetadata(input_split_sizes=input_split_sizes, ep_indx=ep_indx, EP=2) - # Assume the current ep rank is 0. - # [1, 2] comes from rank 0 - # [3, 4] comes from rank 1. - x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) - - result = reduce_scatter(x, metadata=metadata, dim=0) - torch.testing.assert_close(result, torch.tensor([[1, 2], [1, 2]], dtype=torch.float32)) - - -def test_routing_distributed_EP(monkeypatch): - # Test distributed routing with EP=1 (token_mask should be None) - monkeypatch.setenv("WORLD_SIZE", "2") - # Set environment for local rank and distributed group - monkeypatch.setenv("LOCAL_RANK", "0") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - - # NOTE: must set `device="cuda"` since `routing` expects CUDA tensors. - logits = torch.tensor([[0.1, 0.2, 0.4, 0.3], [0.5, 0.4, 0.3, 0.1]], device="cuda", dtype=torch.float16) - x = torch.randn_like(logits, device="cuda", dtype=torch.float16) - n_expts_act = 2 - EP = 2 - expt_indx = torch.tensor([[0, 1], [0, 1]], device="cuda").reshape(-1) - topk_indx = torch.argsort(expt_indx, stable=True) - gate_indx = torch.argsort(topk_indx, stable=True) - _, rdata, gather_indx, scatter_indx, metadata = routing(x, logits, n_expts_act, EP=EP) - assert torch.equal(gather_indx.src_indx, topk_indx.int()) - assert torch.equal(gather_indx.dst_indx, gate_indx.int()) - assert torch.equal(scatter_indx.src_indx, gate_indx.int()) - assert torch.equal(scatter_indx.dst_indx, topk_indx.int()) - - -def test_all_to_all(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setenv("LOCAL_RANK", "0") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - - input_list = [torch.tensor([1, 2], dtype=torch.float32), torch.tensor([3, 4], dtype=torch.float32)] - output_list = all_to_all(input_list) - assert torch.equal(output_list[0], torch.tensor([1, 2], dtype=torch.float32)) - assert torch.equal(output_list[1], torch.tensor([1, 2], dtype=torch.float32)) - assert len(output_list) == 2 - - -def test_pack_bitmatrix(): - # Test parameters - n_rows, n_expts_act = 4, 3 - sentinel = 63 # We have experts 0-62, and 63 is a dummy value - - expt_indx = torch.tensor([[0, 33, 63], [31, 32, 33], [5, 10, 15], [0, 62, 63]], dtype=torch.int32, device="cuda") - n_cols = triton.cdiv(sentinel, 32) - bitmatrix = torch.zeros((n_rows, n_cols), dtype=torch.uint32, device="cuda") - - BLOCK_SIZE_M = 128 - BLOCK_SIZE_K = 32 - grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) - - pack_bitmatrix[grid]( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_K=BLOCK_SIZE_K, - sentinel=sentinel, - ) - # Prune the bitmatrix to remove dummy values - _routing_clear_bitmatrix[(n_rows, )]( - bitmatrix, - bitmatrix.stride(0), - bitmatrix.stride(1), - bitmatrix.shape[1], - sentinel, - BLOCK_N=128, - ) - - # Old pytorch version do not have "bitwise_and_cpu" not implemented for 'UInt32' - bitmatrix = bitmatrix.cpu().numpy() - - # Verify bit packing - def 
is_bit_set(row, expert_id): - word_idx, bit_idx = expert_id // 32, expert_id % 32 - return (bitmatrix[row, word_idx] & (1 << bit_idx)) != 0 - - # Check specific cases - assert is_bit_set(0, 0) and is_bit_set(0, 33) and not is_bit_set(0, 63) # Token 0 - assert is_bit_set(1, 31) and is_bit_set(1, 32) and is_bit_set(1, 33) # Token 1 - assert is_bit_set(2, 5) and is_bit_set(2, 10) and is_bit_set(2, 15) # Token 2 - assert is_bit_set(3, 0) and not is_bit_set(3, 63) and is_bit_set(3, 62) # Token 3 + logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first) + dispatch_indx = logits.mask_metadata.col_sorted_indx + combine_indx = logits.mask_metadata.row_sorted_indx + ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) + gate_scal = logits.vals.flatten()[combine_indx] + routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, + ragged_batch_metadata) + gather_indx = GatherIndx(combine_indx, dispatch_indx) + scatter_indx = ScatterIndx(dispatch_indx, combine_indx) + return x, routing_data, gather_indx, scatter_indx, None def gather_ep(rank, world_size, param, TP, EP): @@ -679,13 +261,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac } xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype]) x0 = all_gather(xd, dim=0) + expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev)) # single-GPU pass def single(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: logits = matmul_ogs(xg, wg, bg, precision_config=pcg) - rdata, gi, si = legacy_routing(logits, n_expts_act) + x, rdata, gi, si, _ = routing(x, logits, n_expts_act) else: rdata = gi = si = None x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) @@ -696,13 +279,13 @@ def distributed(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: # sparse logits = matmul_ogs(xg, wg, bg, precision_config=pcg) - x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP) + x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment) else: # dense x = all_gather(x, dim=0) rdata = gi = si = metadata = None x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) - x = reduce_scatter(x, metadata=metadata, dim=0) + x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) # gather the result from all GPUs, just for verification return all_gather(x, dim=0) @@ -721,31 +304,20 @@ def distributed(x): @pytest.mark.parametrize( "batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP", - [ - # dense cases - test parallelism - (1024, 1024, 1024, 1, 1, "bf16", "bf16", 1, 1), - (1024, 1024, 1024, 1, 1, "bf16", "bf16", 4, 1), - ] + - # dense cases - test precision - [(1024, 1024, 1024, 1, 1, "fp8", "fp8", 1, 1), (1024, 1024, 1024, 1, 1, "fp8", "fp8", 4, 1)] + # dense cases + [(1024, 1024, 1024, 1, 1, "bf16", "bf16", 1, 1), (1024, 1024, 1024, 1, 1, "fp8", "fp8", 1, 1)] # moe cases - test parallelism + [ (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1), - (1024, 1024, 1024, 128, 2, "bf16", "bf16", 4, 1), (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 4), - (1024, 1024, 1024, 128, 2, "bf16", "bf16", 2, 2), ] + # moe cases - test precision ([ (1024, 
1024, 1024, 128, 2, "fp8", "mx4", 1, 1), - (1024, 1024, 1024, 128, 2, "fp8", "mx4", 4, 1), (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4), - (1024, 1024, 1024, 128, 2, "fp8", "mx4", 2, 2), ] if has_native_mx4 else [ (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1), - (1024, 1024, 1024, 128, 2, "bf16", "mx4", 4, 1), (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4), - (1024, 1024, 1024, 128, 2, "bf16", "mx4", 2, 2), ]), ) def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, monkeypatch): @@ -756,8 +328,8 @@ def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, T pytest.skip("Test requires CUDA compute capability >= 9.0.") if is_hip() and get_cdna_version() == 4 and EP > 1: pytest.skip("[TODO] Unknown issue with CDNA 4 and EP > 1") - if TP > 1 and x_dtype == "fp8": - pytest.skip("[TODO] Testing FP8 is not supported for TP > 1.") + if TP > 1: + pytest.skip("[TODO] TP > 1 is not supported yet in distributed mode.") monkeypatch.setenv("WORLD_SIZE", f"{parallelism}") monkeypatch.setenv("MASTER_ADDR", "127.0.0.1") diff --git a/python/triton_kernels/reduce.py b/python/triton_kernels/reduce.py deleted file mode 100644 index e408ff5d76..0000000000 --- a/python/triton_kernels/reduce.py +++ /dev/null @@ -1,280 +0,0 @@ -from dataclasses import dataclass -import torch -import triton -import triton.language as tl -from triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn -from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale -from triton_kernels.numerics import InFlexData, OutFlexData, MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 -from typing import Optional -import types -import sys -from .specialize import specialize - -_kernels = dict() - - -@dataclass(frozen=True) -class FnSpecs: - name: str - fn: "triton.runtime.jit.JITFunction" - fn_arg_names: tuple[str] - fn_arg_do_not_specialize: tuple[str] = tuple() - - @staticmethod - def default(): - return FnSpecs("dflt", None, tuple()) - - -@dataclass(frozen=True) -class PostprocessFn: - specs: FnSpecs = FnSpecs.default() - fn_args: tuple[object] = tuple() - - -def get_kernels(fn_specs: FnSpecs = FnSpecs.default()): - global _kernels - key = (fn_specs.name, ) - if key in _kernels: - return _kernels[key] - spec_constants = {"POSTPROCESS_FN": fn_specs.fn} - spec_tuples = {"postprocess_fn_args": fn_specs.fn_arg_names} - do_not_specialize = fn_specs.fn_arg_do_not_specialize - module = types.ModuleType(f"reduce{'_'.join(key)}") - sys.modules[module.__name__] = module - module._reduce = specialize(_reduce, module, spec_constants, spec_tuples, do_not_specialize=do_not_specialize) - _kernels[key] = module - return module - - -@triton.jit -def _reduce(X, stride_xr, stride_x0, stride_x1, # x tensor (input) - XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale - Y, stride_y0, stride_y1, # y tensor (output) - YMx, stride_ymx0, stride_ymx1, # y mx scale - Mask, stride_mr, stride_m0, stride_m1, # mask tensor - Scale, stride_sr, stride_s0, stride_s1, # scale tensor - K, S0, S1, # shape (K = reduction dim; S0, S1 = output dims) - POSTPROCESS_FN: tl.constexpr, postprocess_fn_args, XFlex, # x flex (global) scale - YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale - IS_MASK_NONE: tl.constexpr, # - BROADCAST_R: tl.constexpr, # - BROADCAST_S0: tl.constexpr, # - BROADCAST_S1: tl.constexpr, # - IS_SCALE_NONE: tl.constexpr, # - SCALE_BROADCAST_R: tl.constexpr, # - SCALE_BROADCAST_S0: tl.constexpr, # - 
SCALE_BROADCAST_S1: tl.constexpr, # - BLOCK_S0: tl.constexpr, # - BLOCK_S1: tl.constexpr, # - ): - pid_s0 = tl.program_id(0) - pid_s1 = tl.program_id(1) - tl.static_assert(BLOCK_S1 % 32 == 0) - BLOCK_SMX1: tl.constexpr = BLOCK_S1 // 32 - offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0) - offs_s1 = pid_s1 * BLOCK_S1 + tl.arange(0, BLOCK_S1) - offs_smx1 = pid_s1 * BLOCK_SMX1 + tl.arange(0, BLOCK_SMX1) - valid_s0 = offs_s0 < S0 - valid_s1 = offs_s1 < S1 - valid_smx1 = offs_smx1 < tl.cdiv(S1, 32) - y = tl.zeros((BLOCK_S0, BLOCK_S1), dtype=tl.float32) - x_flex_scale = load_scale(XFlex) - for k in tl.range(0, K, num_stages=2): - x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_s1[None, :] * stride_x1 - x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=0.0) - x = x.to(tl.float32) - if XMx is not None: - xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_smx1[None, :] * stride_xmx1 - xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_smx1[None, :], other=0.0) - xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True) - x = (xmx[:, :, None] * x.reshape([BLOCK_S0, BLOCK_S1 // 32, 32])).reshape([BLOCK_S0, BLOCK_S1]) - x = x * x_flex_scale - if not IS_SCALE_NONE: - k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr) - s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0) - s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_s1[None, :] * stride_s1) - s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s - s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) - x = x * s - if not IS_MASK_NONE: - k_term = 0 if BROADCAST_R else (k * stride_mr) - s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0) - s1_term = 0 if BROADCAST_S1 else (offs_s1[None, :] * stride_m1) - m_ptrs = Mask + k_term + s0_term + s1_term - m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) - x = tl.where(m != 0, x, 0.0) - y += x - if POSTPROCESS_FN is not None: - y = POSTPROCESS_FN(y, *postprocess_fn_args) - y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF) - y_ptrs = Y + offs_s0[:, None] * stride_y0 + offs_s1[None, :] * stride_y1 - if YMx is not None: - y, y_scale = quantize_mxfp8_fn(y, valid_s1[None, :]) - y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_smx1[None, :] * stride_ymx1 - tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_smx1[None, :]) - tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_s1[None, :]) - - -def reduce( - x: torch.Tensor, - dim: int, - mask: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - x_mxscale: Optional[torch.Tensor] = None, - x_flex: Optional[InFlexData] = InFlexData(), - y_flex: Optional[OutFlexData] = OutFlexData(), - y_flex_saturate_inf: bool = False, - postprocess_fn: Optional[PostprocessFn] = None, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Performs a reduction over the specified dimension of the input tensor, - optionally multiplied by `scale` and ignoring masked elements. - - Arguments: - - x: Tensor - input tensor to reduce. - - dim: int - dimension along which `x` should be reduce. - - mask: Optional[torch.Tensor] - integer mask of the same shape as `x` (or broadcastable to it). - entries that are `0` are ignored in the reduction. - if `mask is None`, all elements are included. - - scale: Optional[torch.Tensor] - scale factors of the same shape as `x` (or broadcastable to it). - the reduction is performed over `x * scale`. 
If `scale is None`, - a value of 1 is used everywhere. - - Returns: - - output: torch.Tensor - The reduced tensor with `dim` removed. - """ - if x.ndim != 3: - raise NotImplementedError("reduce only supports 3D inputs in this implementation") - if dim < 0: - dim += x.ndim - if dim not in (0, 1, 2): - raise ValueError("dim must be in {0,1,2}") - if x_mxscale is not None: - if dim == 2: - raise ValueError("reduction over the micro-scaled dimension not supported") - assert x.shape[:-2] == x_mxscale.shape[:-2] - assert triton.cdiv(x.shape[-1], 32) * 32 == x_mxscale.shape[-1] * 32 - assert dim != -1 - # assert not y_flex.is_per_batch - if postprocess_fn is None: - postprocess_fn = PostprocessFn() - if y_flex is None: - y_flex = OutFlexData() - if x_flex is None: - x_flex = InFlexData() - # input shapes - dims = (0, 1, 2) - nonred = tuple(d for d in dims if d != dim) - S0, S1 = x.shape[nonred[0]], x.shape[nonred[1]] - y = torch.empty((S0, S1), device=x.device, dtype=x.dtype) - y_mxscale = None - if x_mxscale is not None: - y_mxscale = torch.empty((S0, triton.cdiv(S1, 32)), device=x.device, dtype=x_mxscale.dtype) - # Strides for X along reduced and non-reduced dims - stride_xr = x.stride(dim) - stride_x0 = x.stride(nonred[0]) - stride_x1 = x.stride(nonred[1]) - # Strides for X mx scales - stride_xmxr = None if x_mxscale is None else x_mxscale.stride(dim) - stride_xmx0 = None if x_mxscale is None else x_mxscale.stride(nonred[0]) - stride_xmx1 = None if x_mxscale is None else x_mxscale.stride(nonred[1]) - # Strides for Y mx scales - stride_ymx0 = None if y_mxscale is None else y_mxscale.stride(0) - stride_ymx1 = None if y_mxscale is None else y_mxscale.stride(1) - # Mask strides (broadcast allowed via stride 0) - if mask is not None: - mstr0, mstr1, mstr2 = mask.stride() - stride_mr = (mstr0 if dim == 0 else (mstr1 if dim == 1 else mstr2)) - stride_m0 = (mstr0 if nonred[0] == 0 else (mstr1 if nonred[0] == 1 else mstr2)) - stride_m1 = (mstr0 if nonred[1] == 0 else (mstr1 if nonred[1] == 1 else mstr2)) - else: - stride_mr = stride_m0 = stride_m1 = 0 - # Scale strides (broadcast allowed via stride 0) - if scale is not None: - sstr0, sstr1, sstr2 = scale.stride() - stride_sr = (sstr0 if dim == 0 else (sstr1 if dim == 1 else sstr2)) - stride_s0 = (sstr0 if nonred[0] == 0 else (sstr1 if nonred[0] == 1 else sstr2)) - stride_s1 = (sstr0 if nonred[1] == 0 else (sstr1 if nonred[1] == 1 else sstr2)) - else: - stride_sr = stride_s0 = stride_s1 = 0 - K = x.shape[dim] - # Always use the 2D tiled kernel with constexpr metaprogramming for mask broadcasting - BLOCK_S0 = 64 - BLOCK_S1 = 128 - grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(S1, BLOCK_S1)) - mask_arg = mask if mask is not None else x - scale_arg = scale if scale is not None else x - reduce_kernel = get_kernels(postprocess_fn.specs)._reduce - reduce_kernel[grid]( - x, stride_xr, stride_x0, stride_x1, # - x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1, # - y, y.stride(0), y.stride(1), # - y_mxscale, stride_ymx0, stride_ymx1, # - mask_arg, stride_mr, stride_m0, stride_m1, # - scale_arg, stride_sr, stride_s0, stride_s1, # - K, S0, S1, # - *postprocess_fn.fn_args, x_flex.scale, y_flex.expected_scale, y_flex.actual_scale, y_flex.checksum_scale, - y_flex_saturate_inf, # - IS_MASK_NONE=(mask is None), # - BROADCAST_R=(stride_mr == 0), # - BROADCAST_S0=(stride_m0 == 0), # - BROADCAST_S1=(stride_m1 == 0), # - IS_SCALE_NONE=(scale is None), # - SCALE_BROADCAST_R=(stride_sr == 0), # - SCALE_BROADCAST_S0=(stride_s0 == 0), # - SCALE_BROADCAST_S1=(stride_s1 == 0), 
# - BLOCK_S0=BLOCK_S0, # - BLOCK_S1=BLOCK_S1, # - num_warps=4 # - ) - return y, y_mxscale - - -def compute_actual_scale(x, dtype, per_batch_scale=False): - max_finite = { - torch.float8_e5m2: MAX_FINITE_FLOAT8E5, - torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, - torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, - }[dtype] - maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() - return maxvals / max_finite - - -def reduce_torch(x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, # - scale: Optional[torch.Tensor] = None, # - x_mxscale: Optional[torch.Tensor] = None, # - x_flex: Optional[InFlexData] = InFlexData(), y_flex: Optional[OutFlexData] = OutFlexData(), - y_flex_saturate_inf: bool = False, postprocess_fn: Optional[callable] = None): - from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch, upcast_from_mxfp_torch - x_dtype = x.dtype - # upcast input - if x_mxscale is not None: - x = upcast_from_mxfp_torch(x, x_mxscale, torch.float32, axis=-1) - x = x.to(torch.float32) - if x_flex is not None: - x *= x_flex.scale - # upcast scale - if scale is None: - scale = torch.ones(1, dtype=torch.float32, device=x.device) - scale = scale.to(torch.float32) - # initialize mask - if mask is None: - mask = torch.ones(1, dtype=torch.bool, device=x.device) - mask = mask.to(torch.bool) - ret = torch.where(mask, x * scale, 0).sum(dim=dim) - if postprocess_fn is not None: - ret = postprocess_fn(ret) - if y_flex is not None: - y_flex.actual_scale.copy_(compute_actual_scale(ret, x_dtype, y_flex.is_per_batch)) - ret = (ret / y_flex.expected_scale).to(x_dtype) - # downcast output - ret_mxscale = None - if x_mxscale is not None: - assert y_flex is None - ret, ret_mxscale = downcast_to_mxfp_torch(ret, torch.float8_e4m3fn, axis=-1) - return ret.to(x_dtype), ret_mxscale diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py index 06fec8fbfc..f9eb3be995 100644 --- a/python/triton_kernels/tests/test_matmul.py +++ b/python/triton_kernels/tests/test_matmul.py @@ -245,7 +245,8 @@ class Case: Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), - Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), + # Cover (N or K) % 128 == 64 (https://github.com/triton-lang/triton/pull/7203) + Case(1, 1472, 1472, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), @@ -316,6 +317,24 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile, x_transpose, w_transpose, y_transpose, device, opt_flags_scope): + # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to + # the frame that called pytest.skip, including all the tensors, leading to OOM. 
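The OOM comment above relies on the fact that a caught exception's traceback keeps the locals of every enclosed frame alive; since pytest holds on to the skip exception until reporting, tensors created inside the test body would otherwise stay referenced. A small framework-free illustration of that mechanism (heavy() and the allocation size are made up):

def heavy():
    big = bytearray(10 ** 6)          # stands in for large CUDA tensors
    raise RuntimeError("skip me")     # stands in for pytest.skip(...)

try:
    heavy()
except RuntimeError as e:
    msg = str(e)                      # keep only the message
    frame_locals = e.__traceback__.tb_next.tb_frame.f_locals
    print("caught:", msg)
    print("'big' still reachable via the traceback:", "big" in frame_locals)
# Once the except block exits, `e` (and its traceback) are released, and so is `big`;
# re-raising the skip from an outer, tensor-free frame avoids pinning the test's locals.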
+ skip_message = None + try: + _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, + n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + x_transpose, w_transpose, y_transpose, + device, opt_flags_scope) + except pytest.skip.Exception as e: + skip_message = str(e) + + if skip_message is not None: + pytest.skip(skip_message) + +def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot, + n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile, + x_transpose, w_transpose, y_transpose, + device, opt_flags_scope): # TODO: remove when Triton FP8 supports proper RTNE if is_cuda(): if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: @@ -325,8 +344,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o if weight_dtype_str.startswith("mx"): if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10: pytest.skip("float8 x mx not supported with cuda capability < 10") - if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9: - pytest.skip("Not enough memory on A100") elif is_hip(): if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4(): @@ -365,8 +382,21 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).") expt_is_inner = (inner_expt_opt is not None) - if expt_is_inner and (mode != "ragged" or "mx" in act_dtype_str or "mx" in weight_dtype_str): - pytest.skip("Not supported yet") + if expt_is_inner: + if mode != "ragged": + pytest.skip("inner_expt_opt only meaningful with ragged") + if "mx" in act_dtype_str and inner_expt_opt != "pad_x": + pytest.skip("inner_expt_opt and act mx only supported with pad_x") + if "mx" in weight_dtype_str: + if inner_expt_opt != "pad_w": + pytest.skip("inner_expt_opt and weight mx only supported with pad_w") + if is_persistent and not hbm_swizzling: + pytest.skip("FIXME: Fatal Python error: Aborted") + if is_hip(): + if act_dtype_str == "bfloat16": + pytest.skip("FIXME: failed to translate module to LLVM IR") + if hbm_swizzling: + pytest.skip("NYI: nner_expt_opt and HBM swizzling") # launch metadata for batched / mx types may not work yet. 
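Taken together, the new inner_expt_opt skips above encode the following support matrix. The helper below is only a restatement of the test's skip logic (not a triton_kernels API), with the skip messages paraphrased:

def inner_expt_supported(mode, act_is_mx, weight_is_mx, inner_expt_opt,
                         is_persistent, hbm_swizzling, on_hip, act_dtype_str):
    # Returns True if the configuration is exercised, otherwise the reason it is skipped.
    if mode != "ragged":
        return "inner_expt_opt is only meaningful for ragged matmuls"
    if act_is_mx and inner_expt_opt != "pad_x":
        return "microscaled activations require inner_expt_opt == 'pad_x'"
    if weight_is_mx:
        if inner_expt_opt != "pad_w":
            return "microscaled weights require inner_expt_opt == 'pad_w'"
        if is_persistent and not hbm_swizzling:
            return "known crash (FIXME in the test)"
        if on_hip:
            if act_dtype_str == "bfloat16":
                return "LLVM IR translation failure on HIP (FIXME in the test)"
            if hbm_swizzling:
                return "inner_expt_opt with HBM swizzling not implemented yet"
    return True

assert inner_expt_supported("ragged", False, False, None, False, False, False, "bfloat16") is True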
torch.manual_seed(0) @@ -398,6 +428,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o opt_flags.update_opt_flags_constraints(constraints) weight_mxfp = weight_dtype_str.startswith("mx") + weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str if weight_mxfp: weight_dtype_str = weight_dtype_str[2:] act_mxfp8 = act_dtype_str.startswith("mx") @@ -421,6 +452,13 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o rdata = gindx = sindx = None padding_block_k = 32 + if hbm_swizzling: + if torch.cuda.get_device_capability()[0] >= 10: + # Blackwell scale swizzling constraint + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45 + padding_block_k = 128 + elif not is_persistent: + padding_block_k = 64 x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode, torch.bfloat16 if act_mxfp8 else act_dtype, # torch.bfloat16 if weight_mxfp else weight_dtype, @@ -456,11 +494,12 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o # compute layouts w_layout, w_layout_opts = layout.StridedLayout, dict() w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() - if hbm_swizzling and "float4" in weight_dtype_str: + if hbm_swizzling and weight_mxfp4: w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( mx_axis=mx_axis, num_warps=8) # downcast to mxfp +<<<<<<< HEAD w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype @@ -469,6 +508,74 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o # convert layouts w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) +======= + w_tri_orig = w_tri + if colmajor_mxfp_weight: + w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) + w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) + w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype + w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) + w_scale_tri = wrap_torch_tensor(w_scale_tri) + # convert layouts + w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) + w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) + else: + if torch.cuda.get_device_capability()[0] < 10: + pytest.skip("transposed mxfp weight not supported with cuda capability < 10") + if block_m == 16: + pytest.skip("PassManager::run failed from Triton compiler") + # TODO: swizzling for rowmajor + + # A typical use case is we already quantized col-major weight, + # and we want matmul with its transposed row-major weight w/o + # requantization. 
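Earlier in this hunk, the test widens padding_block_k whenever HBM swizzling is in play. Restated as a standalone sketch (the thresholds and values come from the diff; the function itself is hypothetical):

def padding_block_k_for(hbm_swizzling: bool, is_persistent: bool, cc_major: int) -> int:
    # Default K padding used by the test when building ragged inputs.
    padding = 32
    if hbm_swizzling:
        if cc_major >= 10:
            # Blackwell scale swizzling operates on 128-wide K tiles.
            padding = 128
        elif not is_persistent:
            padding = 64
    return padding

assert padding_block_k_for(True, True, 10) == 128   # swizzled weights on Blackwell
assert padding_block_k_for(True, False, 9) == 64    # swizzled, non-persistent, pre-Blackwell
assert padding_block_k_for(False, True, 9) == 32    # no swizzling: default padding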
+ + # put abs_max of each 32x32 block to diagonal so scales of transposed agree + w_ndim = w_tri.ndim + if w_ndim == 2: + w_tri = w_tri.unsqueeze(0) + BLOCK_SIZE = int(MXFP_BLOCK_SIZE) + for e, i, j in itertools.product(range(w_tri.shape[0]), range(0, w_tri.shape[1], BLOCK_SIZE), range(0, w_tri.shape[2], BLOCK_SIZE)): + i_end = min(i+BLOCK_SIZE, w_tri.shape[1]) + j_end = min(j+BLOCK_SIZE, w_tri.shape[2]) + block = w_tri[e, i:i_end, j:j_end] + m_abs = block.abs().max() + i_len = i_end - i + j_len = j_end - j + min_len = min(i_len, j_len) + signs = torch.randint(0, 2, (max(i_len, j_len),), device=w_tri.device) * 2 - 1 + block.diagonal(dim1=-2, dim2=-1)[:] = signs[:min_len] * m_abs + if j_len > i_len: + block[i_len - 1, i_len:] = signs[min_len:] * m_abs + elif i_len > j_len: + block[j_len:, j_len - 1] = signs[min_len:] * m_abs + if w_ndim == 2: + w_tri = w_tri.squeeze(0) + + # matmul with rowmajor weight expects scale is separately + # constructed (not much additional memory needed). + _, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) + # reuse quantized value from colmajor + w_tri_rowmajor, w_scale_tri_rowmajor = downcast_to_mxfp(w_tri.mT.contiguous(), weight_dtype, axis=mx_axis) + w_ref = upcast_from_mxfp(w_tri_rowmajor, w_scale_tri_rowmajor, torch.bfloat16, axis=mx_axis).mT.contiguous() + w_tri = w_tri_rowmajor.data.mT + + def _pad_and_block(x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.pad(x, (0, x.shape[-1] % BLOCK_SIZE), mode="replicate") + return x.view(*x.shape[:-1], x.shape[-1] // BLOCK_SIZE, BLOCK_SIZE) + + # check if generated scale is transpose-invariant as intended construction + # [cdiv(K, 32), N] -> dedup to [cdiv(K, 32), cdiv(N, 32)] + w_scale_tri_blocked = _pad_and_block(w_scale_tri) + w_scale_tri_sampled = w_scale_tri_blocked[..., 0:1] + # [cdiv(N, 32), K] -> dedup to [cdiv(N, 32), cdiv(K, 32)] + w_scale_tri_rowmajor_blocked = _pad_and_block(w_scale_tri_rowmajor) + w_scale_tri_rowmajor_sampled = w_scale_tri_rowmajor_blocked[..., 0:1] + assert torch.equal(w_scale_tri_sampled.expand_as(w_scale_tri_blocked), w_scale_tri_blocked) + assert torch.equal(w_scale_tri_rowmajor_sampled.expand_as(w_scale_tri_rowmajor_blocked), w_scale_tri_rowmajor_blocked) + assert torch.equal(w_scale_tri_sampled.squeeze(-1), w_scale_tri_rowmajor_sampled.squeeze(-1).mT) + +>>>>>>> 9f21c06d55b5c2eccd872d92e9335c4eb13969c5 precision_opt.weight_scale = w_scale_tri epilogue = None if act_mxfp8: @@ -509,8 +616,13 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue, y=y_tri_in, inner_routing_data=inner_routing_data) +<<<<<<< HEAD except (opt_flags.InapplicableConstraint, NotImplementedError): pytest.xfail("inapplicable opt_flags constraint") +======= + except (opt_flags.InapplicableConstraint, NotImplementedError) as e: + pytest.skip(f"inapplicable opt_flags constraint {e}") +>>>>>>> 9f21c06d55b5c2eccd872d92e9335c4eb13969c5 if y_tri_in is not None: assert tri_y.data_ptr() == y_tri_in.data_ptr() assert tri_y.shape == y_tri_in.shape @@ -543,7 +655,7 @@ def scale(val, scal): ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) maxtol = 4e-1 rmstol = 4e-2 - elif weight_mxfp and "float4_e2m1" in weight_dtype_str: + elif weight_mxfp4: if act_is_float8: maxtol = 8e-2 else: diff --git a/python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py 
b/python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py index dd02c71d2e..36f4d86a26 100644 --- a/python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py +++ b/python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py @@ -1,7 +1,7 @@ import pytest from triton._internal_testing import is_cuda, is_xpu from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4 -from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout, HopperAmpereMXValueLayout +from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper @@ -26,7 +26,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version): x = x.mT if x.shape[1 - mx_axis] < 32: pytest.skip("Not enough elements along non-mx axis") - layout = HopperAmpereMXValueLayout(x.shape, mx_axis, mma_version) + layout = HopperMXValueLayout(x.shape, mx_axis, mma_version) res = layout.unswizzle_data(layout.swizzle_data(x)) assert (res == x).all() @@ -37,7 +37,7 @@ def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version): @pytest.mark.xfail(condition=not is_cuda(), reason="Only supported on CUDA", run=False) def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps): x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") - layout = HopperAmpereMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps) + layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps) res = layout.unswizzle_data(layout.swizzle_data(x)) assert (res[:shape[0], :shape[1]] == x).all() @@ -87,8 +87,8 @@ def test_upcast_mxfp4_to_bf16(): x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis) x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4) x_fp4_scale = wrap_torch_tensor(x_fp4_scale) - x_fp4_val = convert_layout(x_fp4_val, HopperAmpereMXValueLayout, mx_axis=mx_axis) - x_fp4_scale = convert_layout(x_fp4_scale, HopperAmpereMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps) + x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis) + x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps) y = torch.empty_like(x_bf16) _upcast_mxfp4_to_bf16[(1, )]( y, x_fp4_val.storage.data, x_fp4_scale.storage.data, # diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index 12acf6bfb1..9b26b20205 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -644,9 +644,17 @@ def matmul_ogs(x, w, bias, w_has_tma = opt_flags.is_persistent w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n], "dense") if w_has_tma else w_storage.data # create tma descriptor for w_scale - w_scale_tensor_or_tma = w_scale w_scale_has_tma = opt_flags.is_persistent and w_scale is not None - w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k], "dense") if w_scale_has_tma else w_scale + w_transpose = w_storage.data.stride()[-2] == 1 + if w_scale_has_tma: + w_scale_storage = w_scale.storage + w_scale_tma_block_size = [opt_flags.block_n, opt_flags.block_k] if w_transpose else [opt_flags.block_k, opt_flags.block_n] + if isinstance(w_scale.storage.layout, 
StridedLayout): + w_scale_storage = _canonicalize_storage(w_scale.storage, 3, None) + w_scale_tma_block_size = [1] + w_scale_tma_block_size + w_scale_tensor_or_tma = w_scale_storage.make_tma(w_scale_tma_block_size, "dense") + else: + w_scale_tensor_or_tma = w_scale # canonicalize strides x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride()) x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None) @@ -661,7 +669,6 @@ def matmul_ogs(x, w, bias, # (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose # is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs. # w_transpose = w_storage.data.stride()[-1] != 1 - w_transpose = w_storage.data.stride()[-2] == 1 fused_comm_kwargs = { "pYPtrs": fused_comm.out_handles, "ScatterShardIndx": fused_comm.scatter_shard_indx, diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py index d411d3255c..1d71b9bc27 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py @@ -98,13 +98,14 @@ def _load_tile_attrs( tl.static_assert(M is not None) expt_id, pid_z, pid_z_out, start_m, block_id, eM = 0, 0, pid_e, 0, pid_m, M k_tiles = tl.cdiv(tl.load(ExptHist + pid_e), BLOCK_K) - padded_start_off = tl.load(ExptTileOffs + pid_e) * BLOCK_K + padded_start_off_raw = tl.load(ExptTileOffs + pid_e) + padded_start_off = padded_start_off_raw * BLOCK_K unpadded_start_off = tl.load(ExptOffs + pid_e) off_k_x = padded_start_off if X_IS_PADDED else unpadded_start_off # K_W is only used for non-TMA kernel (W bound is handled by TMA on TMA kernel). if W_IS_PADDED: - off_k_w = padded_start_off - K_W = tl.load(ExptTileOffs + pid_e + 1) * BLOCK_K + off_k_w = padded_start_off_raw * PACKED_BLOCK_K_W + K_W = tl.load(ExptTileOffs + pid_e + 1) * PACKED_BLOCK_K_W else: off_k_w = unpadded_start_off K_W = tl.load(ExptOffs + pid_e + 1) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py index 21884a1dfb..5b671cb11b 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -131,7 +131,7 @@ def _matmul_ogs( tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), "mx_weight_ptr must be uint8 or fp8") tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") - tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, f"{BLOCK_K=} must be a multiple of {MX_PACK_DIVISOR=}") tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values") # TODO: refactor if/else when triton front end improves @@ -247,7 +247,6 @@ def _matmul_ogs( # TODO: refactor if/else when triton front end improves if is_w_microscaled: - tl.static_assert(not EXPT_IS_INNER, "Not supported yet") WMxScale += expt_id * stride_w_mx_e if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": @@ -281,7 +280,8 @@ def _matmul_ogs( offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) # K dimension must be the last dimension for the scales - offs_k_scale = PACKED_MX_BLOCK * 
pid_k + tl.arange(0, PACKED_MX_BLOCK) + tl.static_assert(not EXPT_IS_INNER or W_IS_PADDED) + offs_k_scale = off_k_w // PACKED_BLOCK_K_W * PACKED_MX_BLOCK + tl.arange(0, PACKED_MX_BLOCK) WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n else: WMxScalePtrs = None @@ -295,7 +295,7 @@ def _matmul_ogs( XMxScale += start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScale += start_m * stride_x_mx_m - offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_x_k_scale = off_k_x // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k else: XMxScalePtrs = None diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py index 079c298631..3c9f05f499 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -125,7 +125,6 @@ def _p_matmul_ogs( tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8") tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales") - tl.static_assert(not EXPT_IS_INNER, "Not supported yet") # We have pack 2 fp4 values in a byte W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1 @@ -249,7 +248,7 @@ def _p_matmul_ogs( XMxScalePtrs = XMxScale + start_z.to(index_type) * stride_x_mx_z if GatherIndx is None: XMxScalePtrs += start_m * stride_x_mx_m - offs_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + offs_k_scale = off_k_x0 // MXFP_BLOCK_SIZE + tl.arange(0, MX_SCALE_BLOCK_K) XMxScalePtrs += (offs_x_m if USE_GATHER_TMA else offs_m).to(index_type)[:, None] * stride_x_mx_m XMxScalePtrs += offs_k_scale.to(index_type)[None, :] * stride_x_mx_k else: diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py index 10186e3408..13d0975333 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py @@ -291,7 +291,10 @@ def make_default_opt_flags_nvidia( else: if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None: # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once. 
- block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64)) + if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and tokens_per_expt >= 16 and torch.cuda.get_device_capability()[0] >= 10: + block_m = max(16, min(triton.next_power_of_2(8 * tokens_per_expt), 128)) + else: + block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64)) else: block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128)) # block n @@ -312,14 +315,13 @@ def make_default_opt_flags_nvidia( is_persistent = False block_n = block_n_tma if is_persistent else block_n # block k - if constraints.get("block_k", None) is not None: - block_k = constraints["block_k"] - else: - block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) - if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1: + block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config, has_y_acc_in) + if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1 and lhs_dtype != torch.bfloat16: # Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large. # TODO: swizzle the HBM layout of the weights instead block_n, block_k = block_k, block_n + if constraints.get("block_k", None) is not None: + block_k = constraints["block_k"] # split_k if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None: split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k")) @@ -404,6 +406,10 @@ def reset_opt_flags_constraints(): global _opt_flags_constraints _opt_flags_constraints = dict() +def reset_opt_flags(): + global _opt_flags + _opt_flags = None + def set_opt_flags(opt_flags: OptFlags): global _opt_flags assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override" diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py index a9964a625a..2cc0f3d41d 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py @@ -1,9 +1,9 @@ import torch import triton from triton_kernels import target_info -from triton_kernels.tensor import get_layout, bitwidth, FP4 -from triton_kernels.tensor_details.layout import HopperAmpereMXScaleLayout from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE +from triton_kernels.tensor import FP4, bitwidth, get_layout +from triton_kernels.tensor_details.layout import HopperMXScaleLayout def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): @@ -18,8 +18,11 @@ def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n): def compute_block_n(n: int, arch, precision_config): # block_n: layout = get_layout(precision_config.weight_scale) - if isinstance(layout, HopperAmpereMXScaleLayout) and layout.num_warps == 4: - return 128, 128 + if isinstance(layout, HopperMXScaleLayout): + if layout.num_warps in [4, 8]: + # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py#L265 + 
block_n = 2 * layout.num_warps * 2 * 8 + return block_n, block_n elif precision_config.max_num_imprecise_acc is None and n > 128: return 256, 256 else: @@ -60,7 +63,7 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int: def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config): layout = get_layout(precision_config.weight_scale) - if isinstance(layout, HopperAmpereMXScaleLayout): + if isinstance(layout, HopperMXScaleLayout): return layout.num_warps return max(block_m * block_n // 4096, 4 if is_persistent else 1) diff --git a/python/triton_kernels/triton_kernels/reduce.py b/python/triton_kernels/triton_kernels/reduce.py index e408ff5d76..c712a13536 100644 --- a/python/triton_kernels/triton_kernels/reduce.py +++ b/python/triton_kernels/triton_kernels/reduce.py @@ -147,6 +147,8 @@ def reduce( Returns: - output: torch.Tensor The reduced tensor with `dim` removed. + - output_mxscale: Optional[torch.Tensor] + The output mx scale if input is micro-scaled, else None. """ if x.ndim != 3: raise NotImplementedError("reduce only supports 3D inputs in this implementation") diff --git a/python/triton_kernels/triton_kernels/tensor.py b/python/triton_kernels/triton_kernels/tensor.py index 64dbc73b05..48e762a69b 100644 --- a/python/triton_kernels/triton_kernels/tensor.py +++ b/python/triton_kernels/triton_kernels/tensor.py @@ -2,12 +2,13 @@ from typing import Type import torch -from triton.tools.tensor_descriptor import TensorDescriptor from triton.tools.ragged_tma import create_ragged_descriptor +from triton.tools.tensor_descriptor import TensorDescriptor + from .target_info import cuda_capability_geq -from .tensor_details.layout import Layout, StridedLayout -from .tensor_details import ragged_tensor as ragged_tensor_details from .tensor_details import bitmatrix as bitmatrix_details +from .tensor_details import ragged_tensor as ragged_tensor_details +from .tensor_details.layout import BlackwellMXValueLayout, Layout, StridedLayout from .tensor_details.ragged_tensor import RaggedTensorMetadata @@ -46,26 +47,28 @@ def is_tma_compliant(self): compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim] return all(compliant) - def make_dense_tma(self, block_shape, transpose=False): + def make_dense_tma(self, block_shape): strides = list(self.data.stride()) shape = list(self.data.shape) - transpose = self.data.stride()[-1] != 1 + transpose = strides[-1] != 1 if transpose: block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]] shape = shape[:-2] + [shape[-1], shape[-2]] strides = strides[:-2] + [strides[-1], strides[-2]] - if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE": + if self.data.dtype == torch.uint8 and (self.layout.name is None or "_SCALE" not in self.layout.name): indx = strides.index(1) block_shape[indx] = block_shape[indx] // 2 - if shape[-1] % 128 != 0: - raise ValueError("inner shape need to be multiple of 128 for " - "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.") + if isinstance(self.layout, BlackwellMXValueLayout): + if shape[-1] % 128 != 0: + raise ValueError( + "inner shape need to be multiple of 128 for mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs." 
+ ) block_shape = self.layout.swizzle_block_shape(block_shape) return TensorDescriptor(self.data, shape, strides, block_shape) - def make_tma(self, block_shape, mode, transpose=False): + def make_tma(self, block_shape, mode): if mode in ["dense", "gather", "scatter"]: - return self.make_dense_tma(block_shape, transpose) + return self.make_dense_tma(block_shape) assert mode == "ragged" ragged_dim = len(self.data.shape) - 2 return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim) @@ -195,6 +198,7 @@ class RaggedTensor: A ragged `tensor` is a collection of 2D tensors that share the same number of columns. Each tensor in this collection is called a `slice`. """ + # slice_sizes[i] is the number of rows in slice `i` slice_sizes: torch.Tensor # ragged tensors are stored in memory as (potentially padded) 2D tensors of shape diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout.py b/python/triton_kernels/triton_kernels/tensor_details/layout.py index d68a37d979..9398a6477f 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout.py @@ -1,8 +1,8 @@ from .layout_details.base import Layout from .layout_details.blackwell_scale import BlackwellMXScaleLayout from .layout_details.blackwell_value import BlackwellMXValueLayout -from .layout_details.hopper_scale import HopperAmpereMXScaleLayout -from .layout_details.hopper_value import HopperAmpereMXValueLayout +from .layout_details.hopper_scale import HopperMXScaleLayout +from .layout_details.hopper_value import HopperMXValueLayout from .layout_details.cdna4_scale import CDNA4MXScaleLayout from .layout_details.strided import StridedLayout from ..target_info import cuda_capability_geq, is_hip_cdna4 @@ -11,8 +11,8 @@ "Layout", "BlackwellMXValueLayout", "BlackwellMXScaleLayout", - "HopperAmpereMXScaleLayout", - "HopperAmpereMXValueLayout", + "HopperMXScaleLayout", + "HopperMXValueLayout", "CDNA4MXScaleLayout", "StridedLayout", ] @@ -21,8 +21,8 @@ def make_default_matmul_mxfp4_w_layout(mx_axis: int): if cuda_capability_geq(10): return BlackwellMXValueLayout, dict() - elif cuda_capability_geq(8): - return HopperAmpereMXValueLayout, {"mx_axis": mx_axis} + elif cuda_capability_geq(9): + return HopperMXValueLayout, {"mx_axis": mx_axis} else: return StridedLayout, dict() @@ -33,7 +33,7 @@ def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8): else: if cuda_capability_geq(10): return BlackwellMXScaleLayout, dict() - elif cuda_capability_geq(8): - return HopperAmpereMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps} + elif cuda_capability_geq(9): + return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps} return StridedLayout, dict() diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py index 7df29947fc..ec2637c750 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -1,7 +1,9 @@ import math + +import torch import triton import triton.language as tl -import torch + from .base import Layout SWIZZLE_ALIGN_INNER = tl.constexpr(8) @@ -14,7 +16,11 @@ class BlackwellMXScaleLayout(Layout): def __init__(self, shape) -> None: super().__init__(shape) - *self.leading_shape, self.K, self.N, = shape + ( + *self.leading_shape, + self.K, + 
self.N, + ) = shape self.B = math.prod(self.leading_shape) self.ALIGN_K = 8 self.ALIGN_N = 128 @@ -42,13 +48,17 @@ def unswizzle_data(self, data): def swizzle_block_shape(self, block_shape): MX_PACK_DIVISOR = 32 MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR + assert block_shape[0] >= 128, f"{block_shape[0]=} must be >= 128" return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256] @triton.jit -def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, - SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, - ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER): +def unswizzle_mx_scale_bw( + x, + SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, + SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, + ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER, +): shape_0: tl.constexpr = x.shape[0] shape_1: tl.constexpr = x.shape[1] tl.static_assert(shape_1 % SIZE_OUTER == 0) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py index 3674691782..7211468faa 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_scale.py @@ -4,7 +4,7 @@ from .base import Layout -class HopperAmpereMXScaleLayout(Layout): +class HopperMXScaleLayout(Layout): name: str = "HOPPER_SCALE" def __init__(self, shape, mx_axis, num_warps=8) -> None: diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py index 2706c75c72..a4b3a7c0bd 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py @@ -82,7 +82,7 @@ def _unpack_bits(x, mx_axis: int): # ----------------------------------------------------------------------- -class HopperAmpereMXValueLayout(Layout): +class HopperMXValueLayout(Layout): name: str = "HOPPER_VALUE" def __init__(self, shape, mx_axis, mma_version=3): diff --git a/python/tutorials/09-persistent-matmul.py b/python/tutorials/09-persistent-matmul.py index f354769615..3078e691b7 100644 --- a/python/tutorials/09-persistent-matmul.py +++ b/python/tutorials/09-persistent-matmul.py @@ -276,9 +276,9 @@ def matmul_tma(a, b, warp_specialize: bool): # A dummy block value that will be overwritten when we have the real block size dummy_block = [1, 1] - a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) - b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block) - c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block) + a_desc = TensorDescriptor.from_tensor(a, dummy_block) + b_desc = TensorDescriptor.from_tensor(b, dummy_block) + c_desc = TensorDescriptor.from_tensor(c, dummy_block) def grid(META): BLOCK_M = META["BLOCK_SIZE_M"] @@ -485,9 +485,9 @@ def matmul_tma_persistent(a, b, warp_specialize: bool): # A dummy block value that will be overwritten when we have the real block size dummy_block = [1, 1] - a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block) - b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block) - c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block) + a_desc = TensorDescriptor.from_tensor(a, dummy_block) + b_desc = TensorDescriptor.from_tensor(b, dummy_block) + c_desc = TensorDescriptor.from_tensor(c, dummy_block) def grid(META): nonlocal a_desc, b_desc, c_desc 
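For reference, a minimal host-side sketch of how the TMA descriptors used in the tutorial above can be built (this assumes a CUDA device and a Triton build exposing TensorDescriptor.from_tensor; the tensor sizes are illustrative only). The explicit constructor and the from_tensor convenience form are intended to produce the same descriptor; the latter simply reads the shape and strides from the tensor itself.

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

# Illustrative sizes; any contiguous 2D CUDA tensor works here.
a = torch.randn(512, 256, device="cuda", dtype=torch.float16)

# A dummy block value that will be overwritten when the real block size is known.
dummy_block = [1, 1]

# Explicit form: pass shape and strides by hand.
a_desc_explicit = TensorDescriptor(a, a.shape, a.stride(), dummy_block)

# Convenience form: derive shape and strides from the tensor itself.
a_desc = TensorDescriptor.from_tensor(a, dummy_block)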
diff --git a/python/tutorials/gluon/07-persistence.py b/python/tutorials/gluon/07-persistence.py index bf86cfc65c..624fedb2fd 100644 --- a/python/tutorials/gluon/07-persistence.py +++ b/python/tutorials/gluon/07-persistence.py @@ -97,6 +97,7 @@ class WGMMA: acc: Union[warpgroup_mma_accumulator, gl.tensor] use_acc: gl.tensor + @gluon.constexpr_function def __init__(self, acc, use_acc): self.acc = acc self.use_acc = use_acc @@ -136,12 +137,13 @@ class MMAv5: counter: gl.tensor reg_layout: gl.constexpr + @gluon.constexpr_function def __init__(self, use_acc, acc_tmem, bar, counter, reg_layout): self.use_acc = use_acc self.acc_tmem = acc_tmem self.bar = bar self.counter = counter - self.reg_layout = reg_layout + self.reg_layout = gl.constexpr(reg_layout) @gluon.jit def initialize(dtype: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, num_warps: gl.constexpr): @@ -342,6 +344,7 @@ class PersistentTileScheduler: pid_end: gl.tensor num_pid_m: gl.tensor + @gluon.constexpr_function def __init__(self, pid_start, pid_end, num_pid_m): self.pid_start = pid_start self.pid_end = pid_end @@ -523,6 +526,7 @@ class GroupedPersistentTileSchedulerImpl: num_pid_in_group: gl.tensor num_pid: gl.tensor + @gluon.constexpr_function def __init__(self, start_pid, num_pid_m, num_pid_in_group, num_pid): self.start_pid = start_pid self.num_pid_m = num_pid_m diff --git a/python/tutorials/gluon/08-warp-specialization.py b/python/tutorials/gluon/08-warp-specialization.py index 3479694057..45059c40a8 100644 --- a/python/tutorials/gluon/08-warp-specialization.py +++ b/python/tutorials/gluon/08-warp-specialization.py @@ -276,15 +276,11 @@ def elementwise_add_warp_specialized_kernel( # # warps to reduce the amount of registers allocated. The default partition # receives whatever registers are left over, based on `maxnreg` passed to # the kernel. - gl.warp_specialize( - default_args=(barriers, buffers, ynumel, YBLOCK, layout), - default_partition=compute_partition, - worker_args=(descs, barriers, buffers, xoff, numel, YBLOCK), - worker_partitions=[load_partition, store_partition], - worker_num_warps=[1, 1], - # Registers must be allocated in multiples of 8, between [24, 256]. - worker_num_regs=[24, 24], - ) + gl.warp_specialize([ + (compute_partition, (barriers, buffers, ynumel, YBLOCK, layout)), + (load_partition, (descs, barriers, buffers, xoff, numel, YBLOCK)), + (store_partition, (descs, barriers, buffers, xoff, numel, YBLOCK)), + ], [1, 1], [24, 24]) def elementwise_add_warp_specialized(a, b, c, XBLOCK=32, YBLOCK=64, # @@ -404,6 +400,7 @@ class PartitionArgs: SUBTILE_FACTOR: gl.constexpr num_warps: gl.constexpr + @gluon.constexpr_function def __init__(self, a_desc, b_desc, c_desc, a_bufs, b_bufs, load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, SUBTILE_FACTOR, num_warps): self.a_desc = a_desc @@ -416,8 +413,8 @@ def __init__(self, a_desc, b_desc, c_desc, a_bufs, b_bufs, load_empty_bars, load self.acc_bufs = acc_bufs self.acc_empty_bars = acc_empty_bars self.acc_ready_bars = acc_ready_bars - self.SUBTILE_FACTOR = SUBTILE_FACTOR - self.num_warps = num_warps + self.SUBTILE_FACTOR = gl.constexpr(SUBTILE_FACTOR) + self.num_warps = gl.constexpr(num_warps) # Counter abstraction for tracking barrier index and phase. 
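A plain-Python analogue of the Counter abstraction mentioned in the comment above, for intuition only. This is a sketch, not the Gluon implementation: it assumes the usual mbarrier bookkeeping in which the slot index wraps modulo num_barriers and the parity phase flips on each wrap.

class CounterSketch:
    """Tracks which barrier slot to use next and the expected parity phase."""

    def __init__(self, index: int, phase: int, num_barriers: int):
        self.index = index
        self.phase = phase
        self.num_barriers = num_barriers

    def next(self) -> "CounterSketch":
        # Advance to the next barrier slot; flip the phase when wrapping around.
        index = self.index + 1
        if index == self.num_barriers:
            return CounterSketch(0, self.phase ^ 1, self.num_barriers)
        return CounterSketch(index, self.phase, self.num_barriers)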
@@ -427,10 +424,11 @@ class Counter: phase: gl.tensor num_barriers: gl.constexpr + @gluon.constexpr_function def __init__(self, index, phase, num_barriers): self.index = index self.phase = phase - self.num_barriers = num_barriers + self.num_barriers = gl.constexpr(num_barriers) @gluon.jit def create(phase, num_barriers: gl.constexpr): @@ -588,14 +586,11 @@ def matmul_warp_specialized_kernel(a_desc, b_desc, c_desc, SchedulerImpl: gl.con p = PartitionArgs(a_desc, b_desc, c_desc, a_bufs, b_bufs, load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, SUBTILE_FACTOR, num_warps) - gl.warp_specialize( - default_args=(p, SchedulerImpl), - default_partition=matmul_epilogue_partition, - worker_args=(p, SchedulerImpl), - worker_partitions=[matmul_load_partition, matmul_mma_partition], - worker_num_warps=[1, 1], - worker_num_regs=[24, 24], - ) + gl.warp_specialize([ + (matmul_epilogue_partition, (p, SchedulerImpl)), + (matmul_load_partition, (p, SchedulerImpl)), + (matmul_mma_partition, (p, SchedulerImpl)), + ], [1, 1], [24, 24]) def matmul_warp_specialized(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, SUBTILE_FACTOR, num_warps, SchedulerImpl): diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 0e93214c31..a5a9b6e4ba 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -458,7 +458,7 @@ tt.func @max_min() { %4 = arith.constant dense<8> : tensor<128xi32> // expected-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4}} %5 = arith.constant dense<4> : tensor<128xi32> - // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8}} + // expected-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} %6 = arith.maxsi %4, %5 : tensor<128xi32> tt.return } @@ -974,3 +974,86 @@ tt.func public @trans_4d_tensor_kernel(%arg0: tensor<32x32x32x32xi32> {tt.contig %102 = tt.trans %arg0 {order = array} : tensor<32x32x32x32xi32> -> tensor<32x32x32x32xi32> tt.return } + +// ----- + +tt.func @unrealized_conversion_cast(%arg0: tensor<128x128xi32> {tt.contiguity = dense<[16, 32]> : tensor<2xi32>}) { + // Case 1: AxisInfo is propagated through a sequence of + // unrealized_conversion_cast ops. + // expected-remark @below {{contiguity = [16, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %0 = builtin.unrealized_conversion_cast %arg0 : tensor<128x128xi32> to !llvm.struct<(i32, i32, i32, i32)> + // expected-remark @below {{contiguity = [16, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %1 = builtin.unrealized_conversion_cast %0 : !llvm.struct<(i32, i32, i32, i32)> to tensor<128x128xi32> + + // Case 2: AxisInfo is falling back to the pessimistic state if the + // propagated AxisInfo would be invalid. + // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %2 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + // expected-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i32, i32, i32, i32)> to tensor<128x128xi32> + // expected-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %4 = tt.trans %3 {order = array} : tensor<128x128xi32> -> tensor<128x128xi32> + tt.return +} + +// ----- + +// Axis analysis does not support multi-dimensional function arguments. 
Make +// sure that we don't crash. +tt.func @callee(%arg0: tensor<128x1xi32>) { + tt.return +} + +tt.func @caller() { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expected-remark @below {{contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = }} + %1 = tt.expand_dims %0 {axis = 1: i32} : tensor<128xi32> -> tensor<128x1xi32> + tt.call @callee(%1) : (tensor<128x1xi32>) -> () + tt.return +} + +// ----- + +tt.func @mul_zero_constancy() { + %range = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %zeros = arith.constant dense<0> : tensor<128xi32> + // expected-remark @below {{constancy = [128]}} + %product = arith.muli %zeros, %range : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @max_constancy() { + %c5 = arith.constant dense<5> : tensor<4xi32> + %c7 = arith.constant dense<7> : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 7}} + %max = arith.maxsi %c5, %c7 : tensor<4xi32> + tt.return +} + +// ----- + +tt.func @select_same_value_constancy() { + %range = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %two = arith.constant dense<2> : tensor<4xi32> + %mod = arith.remsi %range, %two : tensor<4xi32> + %zero = arith.constant dense<0> : tensor<4xi32> + %cond = arith.cmpi ne, %mod, %zero : tensor<4xi32> + %lhs = arith.constant dense<42> : tensor<4xi32> + %rhs = arith.constant dense<42> : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 42}} + %sel = arith.select %cond, %lhs, %rhs : tensor<4xi1>, tensor<4xi32> + tt.return +} + +// ----- + +tt.func @cmp_after_max_constancy() { + %c5 = arith.constant dense<5> : tensor<4xi32> + %c7 = arith.constant dense<7> : tensor<4xi32> + %max = arith.maxsi %c5, %c7 : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 1}} + %cmp = arith.cmpi sgt, %max, %c5 : tensor<4xi32> + tt.return +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f0c6d652ed..0140d9cb7b 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -155,6 +155,18 @@ tt.func @preallocate(%A : !tt.ptr) { tt.return } +// expected-remark @below {{memdesc_ptr}} +// expected-remark @below {{size = 6144}} +tt.func @memdesc_ptr() { + // expected-remark @below {{offset = 0, size = 4096}} + %a0 = ttg.local_alloc : () -> !ttg.memdesc<32x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 4096, size = 2048}} + %a1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a0 : !ttg.memdesc<32x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a1 : !ttg.memdesc<1x16x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + // Unused tensors are immediately released // expected-remark @below {{unused}} // expected-remark @below {{size = 1024}} @@ -279,9 +291,9 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { } -// expected-remark @below {{alloc}} +// expected-remark @below {{alloc_ptr}} // expected-remark @below {{size = 512}} -tt.func @alloc(%A : !tt.ptr) { +tt.func @alloc_ptr(%A : !tt.ptr) { // expected-remark @below {{offset = 0, size = 512}} %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> diff --git a/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir 
b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir new file mode 100644 index 0000000000..5e450d5027 --- /dev/null +++ b/test/Conversion/amd/async_ops_to_llvm_gfx1250.mlir @@ -0,0 +1,35 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx1250 | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 8, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { + // CHECK-LABEL: async_copy_with_swizzle + tt.func public @async_copy_with_swizzle(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, + %arg2: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { + // We need the splat to allow the AxisAnalysis to work during lowering + %1 = tt.splat %arg0 : !tt.ptr -> tensor<32x32x!tt.ptr, #blocked> + // Each thread needs to load 8 elements and we load 1 (sizePerThread) per global.load.lds + // CHECK-COUNT-8: llvm.amdgcn.global.load.async.to.lds.b32 + // CHECK-NOT: llvm.amdgcn.global.load.async.to.lds + %2 = ttg.async_copy_global_to_local %1, %arg2 : tensor<32x32x!tt.ptr, #blocked> -> <32x32xf32, #shared, #smem, mutable> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 8192 : i32} { + // CHECK-LABEL: async_load_strided_into_lds_with_swizzle + tt.func public @async_load_strided_into_lds_with_swizzle(%arg0: tensor<32x32x!tt.ptr, #blocked> {tt.divisibility = dense<[16, 16]> : tensor<2xi32>, tt.contiguity = dense<[16, 16]> : tensor<2xi32>, tt.constancy = dense<[1, 1]> : tensor<2xi32>}, + %arg1: !ttg.memdesc<32x32xf32, #shared, #smem, mutable>) { + // Each thread loads 256 contiguous bits so we split into 2 128bit loads. This was not possible on GFX9 + // CHECK-COUNT-2: llvm.amdgcn.global.load.async.to.lds.b128 + // CHECK-NOT: llvm.amdgcn.global.load.async.to.lds + %6 = ttg.async_copy_global_to_local %arg0, %arg1 : tensor<32x32x!tt.ptr, #blocked> -> <32x32xf32, #shared, #smem, mutable> + tt.return + } +} diff --git a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir index ad4c215b2c..3696bf6801 100644 --- a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir +++ b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -150,6 +150,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // COMMON: llvm.cond_br // COMMON: llvm.store + // Make sure branch condition is set properly when there is other value. 
+ // COMMON: [[AND:%.*]] = llvm.and + // COMMON: llvm.cond_br [[AND]] + // COMMON: rocdl.raw.ptr.buffer.load.lds // COMMON: llvm.cond_br // COMMON: llvm.store diff --git a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir index bceb77fde2..6337fec57e 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir @@ -7,8 +7,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_scaled_dot_fp4 - tt.func @wmma_scaled_dot_fp4(%arg0: tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x4xi8, #linear>, %arg2: tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg3: tensor<16x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + tt.func @wmma_scaled_dot_fp4(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // Matrix A // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> @@ -18,21 +18,185 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> // Scale A - // CHECK-COUNT-2: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> - // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 // Scale B - // CHECK-COUNT-2: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> - // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 // Matrix C // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<8xi32>, i32, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> - %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<16x4xi8, #linear> * tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, tensor<16x4xi8, #linear1> -> tensor<16x16xf32, #mma> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma> // CHECK-COUNT-8: llvm.extractelement {{.*}} : 
vector<8xf32> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> - %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<16x16x!tt.ptr, #mma> - tt.store %ptr0, %c : tensor<16x16x!tt.ptr, #mma> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> +#mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 64]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp4_fp8 + tt.func @wmma_scaled_dot_fp4_fp8(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> + // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> + // Matrix B + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 
0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8 + tt.func @wmma_scaled_dot_fp8(%arg0: tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 
4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_k64 + tt.func @wmma_scaled_dot_fp8_k64(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x2xi8, #linear>, %arg2: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x2xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Adjust for acc + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i8) : i8 + // Matrix A + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK-COUNT-32: llvm.insertelement %[[ZERO]], {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK-COUNT-32: llvm.insertelement %[[ZERO]], {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x2xi8, #linear> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x2xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_repeat_k + tt.func 
@wmma_scaled_dot_fp8_repeat_k(%arg0: tensor<32x256xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x8xi8, #linear>, %arg2: tensor<256x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x8xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-128: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-128: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-8: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-8: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x256xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear> * tensor<256x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> tt.return 
} } diff --git a/test/Conversion/tma_to_llvm.mlir b/test/Conversion/tma_to_llvm.mlir index 1f43127e57..8552bebf02 100644 --- a/test/Conversion/tma_to_llvm.mlir +++ b/test/Conversion/tma_to_llvm.mlir @@ -65,7 +65,7 @@ tt.func @tma_gather_simple(%arg0: !tt.tensordesc>, // CHECK: [[OFFSET0:%.*]] = zext nneg i32 [[WARP_STRIDE]] to i64 // CHECK: [[BASEPTR0:%.*]] = getelementptr bfloat, ptr addrspace(3) [[BASE_PTR]], i64 [[OFFSET0]] - // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cluster.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r" + // CHECK: "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared::cta.global.mbarrier::complete_tx::bytes [$1], [$2, {$3, $4, $5, $6, $7}], [$8];", "b,r,l,r,r,r,r,r,r" // CHECK-SAME: (i1 [[PRED]], ptr addrspace(3) [[BASEPTR0]], ptr nonnull %0, i32 [[Y0]], i32 [[IDX0]], i32 [[IDX1]], i32 [[IDX2]], i32 [[IDX3]], ptr addrspace(3) [[BAR]]) // CHECK: [[BASEPTR1:%.*]] = getelementptr i8, ptr addrspace(3) [[BASEPTR0]], i64 4096 diff --git a/test/NVWS/lower_aref.mlir b/test/NVWS/lower_aref.mlir index 58d7c1c7c9..298acdf6e1 100644 --- a/test/NVWS/lower_aref.mlir +++ b/test/NVWS/lower_aref.mlir @@ -224,8 +224,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %c0_i32 = arith.constant 0 : i32 %true = arith.constant true %result, %token = ttng.tmem_alloc : () -> (!ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.async.token) - %1 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> - %2 = ttng.tmem_store %cst, %1[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128> + %1 = ttg.memdesc_index %result[%c0_i32] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> + %2 = ttng.tmem_store %cst, %1[%token], %true : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> // CHECK: [[BUF_A:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> // CHECK: [[BUF_B:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #shared, #smem, mutable> // CHECK: [[TMA_EMPTY:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x1xi64, #shared1, #smem, mutable> @@ -236,8 +236,8 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %6 = nvws.aref.create %5 : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> %7 = arith.subi %arg0, %c1_i32 : i32 %8 = ttg.local_alloc : () -> !ttg.memdesc<1x1xi64, #shared1, #smem, mutable> - %9 = ttg.memdesc_index %8[%c0_i32] : !ttg.memdesc<1x1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable, 1x1> - ttng.init_barrier %9, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable, 1x1> + %9 = ttg.memdesc_index %8[%c0_i32] : !ttg.memdesc<1x1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + ttng.init_barrier %9, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> %10 = scf.for %arg5 = %c0_i32 to %arg0 step %c1_i32 iter_args(%arg6 = %2) -> (!ttg.async.token) : i32 { %11 = arith.muli %arg5, %c64_i32 {loop.cluster = 1 : i32, loop.stage = 0 : i32} : i32 // CHECK-COUNT-1: ttng.wait_barrier {{.*}}, {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} @@ -246,23 +246,23 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = 
"cuda:100"} { // CHECK: [[TMA_FULL_SLICE:%.*]] = ttg.memdesc_index [[TMA_FULL]] // CHECK: ttng.async_tma_copy_global_to_local {{.*}} [[BUF_A_SLICE]], [[TMA_FULL_SLICE]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} // CHECK: ttng.async_tma_copy_global_to_local {{.*}} [[BUF_B_SLICE]], [[TMA_FULL_SLICE]], {{.*}} {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} - %buffers, %token_2 = nvws.aref.put.enter %4[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64>, !ttg.async.token - nvws.descriptor_load %arg3[%arg1, %11] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64> + %buffers, %token_2 = nvws.aref.put.enter %4[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token + nvws.descriptor_load %arg3[%arg1, %11] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %4[%c0_i32], %token_2 [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token - %buffers_3, %token_4 = nvws.aref.get.enter %4[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.async.token - %buffers_5, %token_6 = nvws.aref.put.enter %6[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64>, !ttg.async.token - nvws.descriptor_load %arg4[%arg2, %11] 16384 %buffers_5 {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64> + %buffers_3, %token_4 = nvws.aref.get.enter %4[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token + %buffers_5, %token_6 = nvws.aref.put.enter %6[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token + nvws.descriptor_load %arg4[%arg2, %11] 16384 %buffers_5 {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %6[%c0_i32], %token_6 [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token - %buffers_7, %token_8 = nvws.aref.get.enter %6[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.async.token + %buffers_7, %token_8 = 
nvws.aref.get.enter %6[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token // CHECK-COUNT-1: ttng.wait_barrier {{.*}}, {{.*}} {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} // CHECK: [[BUF_A_SLICE:%.*]] = ttg.memdesc_index [[BUF_A]] // CHECK: [[BUF_B_SLICE:%.*]] = ttg.memdesc_index [[BUF_B]] // CHECK: [[BUF_B_SLICE_TRANS:%.*]] = ttg.memdesc_trans [[BUF_B_SLICE]] {loop.cluster = 0 : i32, loop.stage = 1 : i32 - %12 = ttg.memdesc_trans %buffers_7 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64> -> !ttg.memdesc<64x128xf16, #shared2, #smem, 1x64x128> + %12 = ttg.memdesc_trans %buffers_7 {loop.cluster = 0 : i32, loop.stage = 1 : i32, order = array, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> !ttg.memdesc<64x128xf16, #shared2, #smem> %13 = arith.cmpi eq, %arg5, %7 : i32 // CHECK: ttng.tc_gen5_mma [[BUF_A_SLICE]], [[BUF_B_SLICE_TRANS]] - %14 = ttng.tc_gen5_mma %buffers_3, %12, %1[], %true, %true, %9[%13] {is_async, loop.cluster = 0 : i32, loop.stage = 1 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.memdesc<64x128xf16, #shared2, #smem, 1x64x128>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x128x128>, !ttg.memdesc<1xi64, #shared1, #smem, mutable, 1x1> + %14 = ttng.tc_gen5_mma %buffers_3, %12, %1[], %true, %true, %9[%13] {is_async, loop.cluster = 0 : i32, loop.stage = 1 : i32, tt.self_latency = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared2, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> // CHECK: [[TMA_EMPTY_SLICE:%.*]] = ttg.memdesc_index [[TMA_EMPTY]] // CHECK-COUNT-1: ttng.tc_gen5_commit [[TMA_EMPTY_SLICE]] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} nvws.aref.get.exit %6[%c0_i32], %token_8 [#nvws.async_op] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token @@ -291,22 +291,22 @@ module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} { %0 = ttg.local_alloc : () -> !ttg.memdesc<1x128x64xf16, #shared, #smem, mutable> %1 = nvws.aref.create %0 : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> scf.for %arg2 = %c0_i32 to %arg1 step %c1_i32 : i32 { - %buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64>, !ttg.async.token - nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, !ttg.memdesc<128x64xf16, #shared, #smem, mutable, 1x128x64> + %buffers, %token = nvws.aref.put.enter %1[%c0_i32, %c0_i32] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>, !ttg.async.token + nvws.descriptor_load %arg0[%arg2, %arg2] 16384 %buffers {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : !tt.tensordesc>, i32, i32, 
!ttg.memdesc<128x64xf16, #shared, #smem, mutable> nvws.aref.put.exit %1[%c0_i32], %token [#nvws.async_op] {loop.cluster = 1 : i32, loop.stage = 0 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token - %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.async.token - %2 = ttg.local_load %buffers_0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64> -> tensor<128x64xf16, #blocked> + %buffers_0, %token_1 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token + %2 = ttg.local_load %buffers_0 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : !ttg.memdesc<128x64xf16, #shared, #smem> -> tensor<128x64xf16, #blocked> // CHECK: ttng.fence_async_shared {bCluster = false, loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} // CHECK: [[EMPTYSLICE:%.*]] = ttg.memdesc_index [[EMPTY]] // CHECK: ttng.arrive_barrier [[EMPTYSLICE]], 1 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} nvws.aref.get.exit %1[%c0_i32], %token_1 [#nvws.async_op] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token - %buffers_2, %token_3 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>, !ttg.async.token + %buffers_2, %token_3 = nvws.aref.get.enter %1[%c0_i32, %c0_i32] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]> -> !ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.async.token "use1"(%2) {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (tensor<128x64xf16, #blocked>) -> () // CHECK: "use2" // CHECK: ttng.fence_async_shared {bCluster = false, loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} // CHECK: [[EMPTYSLICE:%.*]] = ttg.memdesc_index [[EMPTY]] // CHECK: ttng.arrive_barrier [[EMPTYSLICE]], 1 {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} - "use2"(%buffers_2) {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (!ttg.memdesc<128x64xf16, #shared, #smem, 1x128x64>) -> () + "use2"(%buffers_2) {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : (!ttg.memdesc<128x64xf16, #shared, #smem>) -> () nvws.aref.get.exit %1[%c0_i32], %token_3 [#nvws.async_op] {loop.cluster = 0 : i32, loop.stage = 1 : i32, ttg.partition = array} : <[!ttg.memdesc<1x128x64xf16, #shared, #smem, mutable>]>, !ttg.async.token } {tt.num_stages = 2 : i32, tt.scheduled_max_stage = 1 : i32, tt.warp_specialize, ttg.partition.stages = [0 : i32, 1 : i32, 0 : i32], ttg.warp_specialize.tag = 0 : i32} tt.return diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir index 1f1ca95182..5232cb859d 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir +++ 
b/test/TritonGPU/amd/accelerate-amd-matmul-mfma-gfx950.mlir @@ -320,3 +320,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ tt.return } } + +// ----- +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked5 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked6 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked7 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 2, 2, 1], threadsPerWarp = [1, 1, 4, 16, 1, 1, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 5, 4, 3, 2, 1, 0]}> +#blocked8 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 2, 1, 1], threadsPerWarp = [1, 1, 16, 1, 1, 4, 1], warpsPerCTA = [4, 1, 1, 1, 1, 1, 1], order = [6, 1, 4, 2, 5, 3, 0]}> +#linear = #ttg.linear<{register = [[16, 0], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp = [[32, 0], [64, 0]], block = []}> + +// MFMA16: [[$linear1:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}0, 0], [0, 0{{]]}}, block = []}> +// MFMA16: [[$linear2:#.*]] = #ttg.linear<{register = {{\[\[}}0, 4], [16, 0{{]]}}, lane = {{\[\[}}1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2{{]]}}, warp = {{\[\[}}32, 0], [64, 0{{]]}}, block = []}> +// MFMA16: [[$mma:#.*]] = #ttg.amd_mfma<{version = 4, warpsPerCTA = [1, 4], instrShape = [16, 16, 128], isTransposed = true, tilesPerWarp = [1, 2]}> +// MFMA16-LABEL: mfma_dot_scaled_fp8_mxfp4 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @mfma_dot_scaled_fp8_mxfp4( + %arg0: tensor<16x256xf8E4M3FN, #blocked6>, + %arg1: tensor<4x256x!tt.ptr, #blocked5>, + %arg2: tensor<128x128xi8, #blocked1>, + %arg3: tensor<16x128x!tt.ptr, #blocked1> + ) { + // MFMA16: [[SCALE0:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<16x8xi8, [[$linear1]]> + // MFMA16: [[SCALE1:%.+]] = ttg.convert_layout {{.*}} : {{.*}} -> tensor<128x8xi8, [[$linear2]]> + // MFMA16: tt.dot_scaled {{.*}} scale [[SCALE0]], {{.*}} scale [[SCALE1]], {{.*}} -> tensor<16x128xf32, [[$mma]]> + %cst0 = arith.constant dense<127> : tensor<16x8xi8, #blocked> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x128xf32, #blocked1> + %load = tt.load %arg1 : tensor<4x256x!tt.ptr, #blocked5> + %reshape0 = tt.reshape %load : tensor<4x256xi8, #blocked5> -> tensor<4x1x4x16x2x2x1xi8, #blocked7> + %trans = tt.trans %reshape0 {order = array} : tensor<4x1x4x16x2x2x1xi8, #blocked7> -> tensor<4x2x16x1x2x4x1xi8, #blocked8> + %reshape1 = tt.reshape %trans : tensor<4x2x16x1x2x4x1xi8, #blocked8> -> tensor<128x8xi8, #linear> + %scale = ttg.convert_layout %reshape1 : tensor<128x8xi8, #linear> -> tensor<128x8xi8, #blocked> + %1 = tt.dot_scaled %arg0 scale %cst0, %arg2 scale %scale, %cst1 lhs = e4m3 rhs = e2m1 {fastMath = true} : tensor<16x256xf8E4M3FN, #blocked6>, tensor<16x8xi8, #blocked> * tensor<128x128xi8, #blocked1>, tensor<128x8xi8, #blocked> -> tensor<16x128xf32, #blocked1> + tt.store %arg3, %1 : tensor<16x128x!tt.ptr, #blocked1> + tt.return + } +} diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir index 2a521f7475..0816308c8d 100644 --- 
a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx1250 matrix-instruction-size=16" | FileCheck %s --check-prefixes CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx1250" | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -32,3 +32,165 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK{LITERAL}: #mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 64]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp4_mxfp8 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp4_mxfp8( + %arg0: tensor<32x64xi8, #blocked>, + %arg1: tensor<128x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x4xi8, #blocked2>, + %arg3: tensor<32x4xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x64xi8, #blocked> -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x32xf8E4M3FN, #blocked1> -> tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e2m1 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e2m1 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #blocked>, tensor<32x4xi8, #blocked2> * tensor<128x32xf8E4M3FN, #blocked1>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], 
warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8( + %arg0: tensor<32x128xf8E4M3FN, #blocked>, + %arg1: tensor<128x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x4xi8, #blocked2>, + %arg3: tensor<32x4xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x128xf8E4M3FN, #blocked> -> tensor<32x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x32xf8E4M3FN, #blocked1> -> tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x128xf8E4M3FN, #blocked>, tensor<32x4xi8, #blocked2> * tensor<128x32xf8E4M3FN, #blocked1>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_k64 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", 
"ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_k64( + %arg0: tensor<32x64xf8E4M3FN, #blocked>, + %arg1: tensor<64x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x2xi8, #blocked2>, + %arg3: tensor<32x2xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x64xf8E4M3FN, #blocked> -> tensor<32x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x32xf8E4M3FN, #blocked1> -> tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x2xi8, #blocked2> -> tensor<32x2xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x2xi8, #blocked2> -> tensor<32x2xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x64xf8E4M3FN, #blocked>, tensor<32x2xi8, #blocked2> * tensor<64x32xf8E4M3FN, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_repeat_k +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_repeat_k( + %arg0: tensor<32x256xf8E4M3FN, #blocked>, + %arg1: tensor<256x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x8xi8, #blocked2>, + %arg3: tensor<32x8xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x256xf8E4M3FN, #blocked> -> tensor<32x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<256x32xf8E4M3FN, #blocked1> -> tensor<256x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x8xi8, #blocked2> -> tensor<32x8xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x8xi8, 
#blocked2> -> tensor<32x8xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x256xf8E4M3FN, #blocked>, tensor<32x8xi8, #blocked2> * tensor<256x32xf8E4M3FN, #blocked1>, tensor<32x8xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_repeat_mn +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_repeat_mn( + %arg0: tensor<64x128xf8E4M3FN, #blocked>, + %arg1: tensor<128x64xf8E4M3FN, #blocked1>, + %arg2: tensor<64x4xi8, #blocked2>, + %arg3: tensor<64x4xi8, #blocked2>, + %arg4: tensor<64x64x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xf8E4M3FN, #blocked> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xf8E4M3FN, #blocked1> -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<64x4xi8, #blocked2> -> tensor<64x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<64x4xi8, #blocked2> -> tensor<64x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<64x128xf8E4M3FN, #blocked>, tensor<64x4xi8, #blocked2> * tensor<128x64xf8E4M3FN, #blocked1>, tensor<64x4xi8, #blocked2> -> tensor<64x64xf32, #blocked3> + tt.store %arg4, %1 : tensor<64x64x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir b/test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir index a095775288..514f67da9c 100644 --- a/test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir +++ b/test/TritonGPU/amd/amd-block-pingpong-chained-dots.mlir @@ -40,23 +40,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, 
ttg.targ %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> - %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %4 = ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %5:9 = scf.for %arg14 = %c0_i32 to %arg1 step %arg2 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %arg3, %arg19 = %arg3, %arg20 = %2, %arg21 = %4, %arg22 = %arg3, %arg23 = %3) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>) : i32 { + %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %4 = ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %5:9 = scf.for %arg14 = %c0_i32 to %arg1 step %arg2 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %arg3, %arg19 = %arg3, %arg20 = %2, %arg21 = %4, %arg22 = %arg3, %arg23 = %3) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>) : i32 { %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %7 = ttg.async_wait %arg18 {num = 0 : i32} - %8 = ttg.local_load %arg20 token %7 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %9 = ttg.memdesc_index %0[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %10 = ttg.async_copy_global_to_local %arg0, %9 : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable, 2x64x16> + %8 = ttg.local_load %arg20 token %7 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %9 = ttg.memdesc_index %0[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %10 = ttg.async_copy_global_to_local %arg0, %9 : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> %11 = ttg.async_commit_group tokens %10 %12 = tt.dot %arg10, %8, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> 
tensor<128x16xf32, #mma> %13 = ttg.async_wait %arg22 {num = 0 : i32} - %14 = ttg.local_load %arg23 token %13 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %15 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %16 = ttg.async_copy_global_to_local %arg0, %15 : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable, 2x64x16> + %14 = ttg.local_load %arg23 token %13 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %15 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %16 = ttg.async_copy_global_to_local %arg0, %15 : tensor<64x16x!tt.ptr, #blocked> -> <64x16xf16, #shared, #smem, mutable> %17 = ttg.async_commit_group tokens %16 - scf.yield %12, %6, %14, %arg19, %17, %arg21, %15, %11, %9 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> + scf.yield %12, %6, %14, %arg19, %17, %arg21, %15, %11, %9 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.async.token, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.async.token, !ttg.memdesc<64x16xf16, #shared, #smem, mutable> } ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> @@ -107,21 +107,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> - %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %4 = ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 { + %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %4 = 
ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 { %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> - ttg.local_store %arg21, %arg18 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %7 = ttg.local_load %arg18 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %8 = ttg.memdesc_index %0[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> + ttg.local_store %arg21, %arg18 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %7 = ttg.local_load %arg18 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %8 = ttg.memdesc_index %0[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> %9 = tt.load %arg1 : tensor<64x16x!tt.ptr, #blocked> %10 = tt.dot %arg10, %7, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> - ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %12 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> + ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %12 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> %13 = tt.load %arg1 : tensor<64x16x!tt.ptr, #blocked> - scf.yield %10, %6, %11, %arg19, %12, %8, %9, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked> + scf.yield %10, %6, %11, %arg19, %12, %8, %9, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, 
#smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked> } ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> @@ -147,17 +147,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %c0_i32 = arith.constant 0 : i32 %0 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> %1 = ttg.local_alloc : () -> !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> - %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %4 = ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 { + %2 = ttg.memdesc_index %1[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %3 = ttg.memdesc_index %0[%c0_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %4 = ttg.memdesc_index %1[%c1_i32] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %5:8 = scf.for %arg14 = %c0_i32 to %arg2 step %arg3 iter_args(%arg15 = %arg4, %arg16 = %arg4, %arg17 = %arg7, %arg18 = %2, %arg19 = %4, %arg20 = %3, %arg21 = %arg0, %arg22 = %arg0) -> (tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked>) : i32 { %6 = tt.dot %arg10, %arg17, %arg15 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> %10 = tt.dot %arg10, %arg17, %arg16 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x16xf32, #mma> - ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> - %11 = ttg.local_load %arg20 : !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %12 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16> + ttg.local_store %arg22, %arg20 : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> + %11 = ttg.local_load %arg20 : 
!ttg.memdesc<64x16xf16, #shared, #smem, mutable> -> tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %12 = ttg.memdesc_index %1[%arg6] : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable> %13 = tt.load %arg1 : tensor<64x16x!tt.ptr, #blocked> - scf.yield %10, %6, %11, %arg19, %12, %12, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 2x64x16>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked> + scf.yield %10, %6, %11, %arg19, %12, %12, %13, %13 : tensor<128x16xf32, #mma>, tensor<128x16xf32, #mma>, tensor<64x16xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, !ttg.memdesc<64x16xf16, #shared, #smem, mutable>, tensor<64x16xf16, #blocked>, tensor<64x16xf16, #blocked> } ttg.local_dealloc %1 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> ttg.local_dealloc %0 : !ttg.memdesc<2x64x16xf16, #shared, #smem, mutable> diff --git a/test/TritonGPU/amd/amd-update-async-wait-count.mlir b/test/TritonGPU/amd/amd-update-async-wait-count.mlir index 1ad3f8a65b..afaf10d99f 100644 --- a/test/TritonGPU/amd/amd-update-async-wait-count.mlir +++ b/test/TritonGPU/amd/amd-update-async-wait-count.mlir @@ -369,3 +369,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Test mixing async_copy and async_tdm_copy + +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} { + // CHECK-LABEL: mix_async_copy_and_async_tdm_copy + tt.func public @mix_async_copy_and_async_tdm_copy(%memDesc: !ttg.memdesc<128x16xf16, #shared, #smem, mutable>, %tensorDesc: !tt.tensordesc>, %mask: i1, %ptr: tensor<128x16x!tt.ptr, #blocked> + ) { + %c0_i32 = arith.constant 0 : i32 + + // Each async_tdm_copy only emits a single instruction (-> counts 1) + %1 = amdgpu.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, %mask : !tt.tensordesc> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + + %2 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %21 = ttg.async_commit_group tokens %2 + + %3 = amdgpu.async_tdm_copy_global_to_local %tensorDesc[%c0_i32, %c0_i32] into %memDesc, %mask : !tt.tensordesc> -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable> + + %4 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %5 = ttg.async_copy_global_to_local %ptr, %memDesc : tensor<128x16x!tt.ptr, #blocked> -> <128x16xf16, #shared, #smem, mutable> + %51 = ttg.async_commit_group tokens %4, %5 + + // Check that we do not take other TDM loads into account (they use a different HW counter) + + // CHECK: ttg.async_wait {{.*}} {num = 2 + %cw1 = ttg.async_wait %21 {num = 0 : i32} + + // CHECK: ttg.async_wait {{.*}} {num = 0 + %cw2 = ttg.async_wait %51 {num = 0 : i32} + tt.return + } +} diff 
--git a/test/TritonGPU/invalid.mlir b/test/TritonGPU/invalid.mlir index e2e714aaa7..02d1d8ff90 100644 --- a/test/TritonGPU/invalid.mlir +++ b/test/TritonGPU/invalid.mlir @@ -73,6 +73,17 @@ tt.func public @result_rank_too_large(%arg0: !ttg.memdesc<3x8x16xf32, #shared, # %a = ttg.memdesc_index %arg0[%zero] : !ttg.memdesc<3x8x16xf32, #shared, #smem> -> !ttg.memdesc<3x8x16xf32, #shared, #smem> tt.return } + +// ----- + +#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}> +#smem = #ttg.shared_memory +tt.func public @memdesc_index_result_alloc_shape_mismatch(%arg0: !ttg.memdesc<3x8x16xf32, #shared, #smem>) { + %zero = arith.constant 0 : i32 + // expected-error @+1 {{alloc shape must match shape for both result and src}} + %a = ttg.memdesc_index %arg0[%zero] : !ttg.memdesc<3x8x16xf32, #shared, #smem> -> !ttg.memdesc<8x16xf32, #shared, #smem, 3x8x16> + tt.return +} // ----- #shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0]}> diff --git a/test/TritonGPU/loop-pipeline-hip.mlir b/test/TritonGPU/loop-pipeline-hip.mlir index 7e34439471..3eeee75394 100644 --- a/test/TritonGPU/loop-pipeline-hip.mlir +++ b/test/TritonGPU/loop-pipeline-hip.mlir @@ -250,7 +250,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // Check that the stream pipeliner updates the resulting memory layout of transpose ops to mutable if immutable local buffers are replaced // COMMON-LABEL: loop_with_dot_and_transpose // COMMON: ttg.local_alloc {{.*}}, mutable> -// COMMON: ttg.memdesc_trans {{.*}}, mutable, {{.*}} -> {{.*}}, mutable +// COMMON: ttg.memdesc_trans {{.*}}, mutable> -> {{.*}}, mutable> #blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index a6cc102fe1..294895ed0f 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -21,11 +21,11 @@ // CHECK: %[[BBUFFER:.*]] = ttg.local_alloc // CHECK-DAG: %[[LOOP_COND_0:.*]] = arith.cmpi slt, %[[LB:.*]], %[[UB:.*]] // CHECK-DAG: %[[LOOP_COND_0_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_0]] -// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_index %[[ABUFFER]]{{\[}}%[[CONSTANT_0]]{{\]}} : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 2x128x32> -// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #smem, mutable, 2x128x32> +// CHECK-DAG: %[[ASUB:.*]] = ttg.memdesc_index %[[ABUFFER]]{{\[}}%[[CONSTANT_0]]{{\]}} : !ttg.memdesc<2x128x32xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable> +// CHECK: %[[T_A0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[ASUB]] mask %[[LOOP_COND_0_SPLAT_A]] : tensor<128x32x!tt.ptr, #blocked1> -> <128x32xf16, #shared, #smem, mutable> // CHECK-DAG: %[[LOOP_COND_0_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_0]] // CHECK-DAG: %[[BSUB:.*]] = ttg.memdesc_index %[[BBUFFER]]{{\[}}%[[CONSTANT_0]]{{\]}} -// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, %[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable, 2x32x128> +// CHECK: %[[T_B0:.*]] = ttg.async_copy_global_to_local %{{.*}}, 
%[[BSUB]] mask %[[LOOP_COND_0_SPLAT_B]] other %{{.*}} : tensor<32x128x!tt.ptr, #blocked> -> <32x128xf16, #shared1, #smem, mutable> // CHECK-DAG: %[[IV_1:.*]] = arith.addi %[[LB]], %[[STEP:.*]] // CHECK-DAG: %[[LOOP_COND_1:.*]] = arith.cmpi slt, %[[IV_1]], %[[UB]] // CHECK-DAG: %[[LOOP_COND_1_SPLAT_A:.*]] = tt.splat %[[LOOP_COND_1]] diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index fc047f54c2..e3f5c1a152 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -984,7 +984,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { // CHECK: ttg.async_wait {{.*}} {num = 1 : i32} // CHECK: %[[NEXT_BUFFER_1:.*]] = tt.addptr %{{.*}}, {{.*}} // CHECK: ttg.async_copy_global_to_local %[[NEXT_BUFFER_1]] -// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_index {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable, 1x16> +// CHECK: %[[IND_BUFFER_0:.*]] = ttg.memdesc_index {{.*}} : !ttg.memdesc<1x16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> -> !ttg.memdesc<16xi64, #[[$SHARED_LAYOUT]], #smem, mutable> // CHECK: %[[IND_BUFFER_1:.*]] = ttg.local_load %[[IND_BUFFER_0]] // CHECK: %[[IND_BUFFER_2:.*]] = tt.expand_dims %[[IND_BUFFER_1]] {axis = 1 : i32} // CHECK: %[[IND_BUFFER_3:.*]] = tt.broadcast %[[IND_BUFFER_2]] diff --git a/test/TritonGPU/memdesc-subview-split.mlir b/test/TritonGPU/memdesc-subview-split.mlir index cb27ace665..74ae6d18db 100644 --- a/test/TritonGPU/memdesc-subview-split.mlir +++ b/test/TritonGPU/memdesc-subview-split.mlir @@ -21,8 +21,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ %9 = ttg.memdesc_subslice %1 [128, 96] : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> !ttg.memdesc<128x32xf16, #shared, #smem, mutable, 256x128> %padded = ttg.local_alloc : () -> !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> - %padded_indexed_explicit_alloc_shape = ttg.memdesc_index %padded[%c0_i32] : !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<256x128xf16, #padded, #smem, mutable, 1x256x128> - %10 = ttg.memdesc_subslice %padded_indexed_explicit_alloc_shape [128, 96] : !ttg.memdesc<256x128xf16, #padded, #smem, mutable, 1x256x128> -> !ttg.memdesc<128x32xf16, #padded, #smem, mutable, 1x256x128> + %padded_indexed_explicit_alloc_shape = ttg.memdesc_index %padded[%c0_i32] : !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<256x128xf16, #padded, #smem, mutable> + %10 = ttg.memdesc_subslice %padded_indexed_explicit_alloc_shape [128, 96] : !ttg.memdesc<256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<128x32xf16, #padded, #smem, mutable, 256x128> %padded_indexed_implicit_alloc_shape = ttg.memdesc_index %padded[%c0_i32] : !ttg.memdesc<1x256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<256x128xf16, #padded, #smem, mutable> %11 = ttg.memdesc_subslice %padded_indexed_implicit_alloc_shape [128, 96] : !ttg.memdesc<256x128xf16, #padded, #smem, mutable> -> !ttg.memdesc<128x32xf16, #padded, #smem, mutable, 256x128> tt.return diff --git a/test/TritonGPU/pipeline-assign-latencies.mlir b/test/TritonGPU/pipeline-assign-latencies.mlir index 82f5136992..181a454862 100644 --- a/test/TritonGPU/pipeline-assign-latencies.mlir +++ b/test/TritonGPU/pipeline-assign-latencies.mlir @@ -1147,3 +1147,21 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +// Test that ub.poison producing a memdesc does not get treated 
like a tensor +// value in AxisInfo analysis. +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-warps" = 4 : i32} { + tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> { + %c1 = arith.constant 1 : i32 + %poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + %result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 { + scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } + tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable> + } +} diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir index a00b102d9d..a01e35fa88 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -58,26 +58,26 @@ // CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -// CHECK: %[[VAL_46:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.init_barrier %[[VAL_46]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_47:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.init_barrier %[[VAL_47]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_48:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_7]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.init_barrier %[[VAL_48]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.init_barrier %[[VAL_46]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.init_barrier %[[VAL_47]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_7]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.init_barrier %[[VAL_48]], 1 : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 -// CHECK: %[[VAL_50:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], 
#[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_51:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_52:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_52]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_52:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_52]], %[[VAL_50]], %[[VAL_49]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_53:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 -// CHECK: %[[VAL_54:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.barrier_expect %[[VAL_54]], 49152, %[[VAL_53]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_55:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_55]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_56:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, 
#[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_56]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.barrier_expect %[[VAL_54]], 49152, %[[VAL_53]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_55:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_55]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_56]], %[[VAL_54]], %[[VAL_53]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_57:.*]]:5 = scf.for %[[VAL_58:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_59:.*]] = %[[VAL_19]], %[[VAL_60:.*]] = %[[VAL_13]], %[[VAL_61:.*]] = %[[VAL_15]], %[[VAL_62:.*]] = %[[VAL_8]], %[[VAL_63:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { // CHECK: %[[VAL_64:.*]] = arith.subi %[[VAL_42]], %[[VAL_7]] : i32 // CHECK: %[[VAL_65:.*]] = arith.cmpi slt, %[[VAL_58]], %[[VAL_64]] : i32 @@ -86,32 +86,32 @@ // CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_67]], %[[VAL_12]], %[[VAL_66]] : i32 // CHECK: %[[VAL_69:.*]] = arith.xori %[[VAL_63]], %[[VAL_15]] : i32 // CHECK: %[[VAL_70:.*]] = arith.select %[[VAL_67]], %[[VAL_69]], %[[VAL_63]] : i32 -// CHECK: %[[VAL_71:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.wait_barrier %[[VAL_71]], %[[VAL_70]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_72:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -// CHECK: %[[VAL_73:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_74:.*]] = ttg.memdesc_trans %[[VAL_72]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> -// CHECK: %[[VAL_75:.*]] = ttng.warp_group_dot %[[VAL_73]], %[[VAL_74]], %[[VAL_59]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> * 
!ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> -> tensor<128x256xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_76:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_75]], %[[VAL_73]], %[[VAL_74]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable, 3x64x256> +// CHECK: %[[VAL_71:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.wait_barrier %[[VAL_71]], %[[VAL_70]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_72:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_73:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_68]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_74:.*]] = ttg.memdesc_trans %[[VAL_72]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_75:.*]] = ttng.warp_group_dot %[[VAL_73]], %[[VAL_74]], %[[VAL_59]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_76:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_75]], %[[VAL_73]], %[[VAL_74]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #[[$ATTR_5]], mutable> // CHECK: %[[VAL_77:.*]] = arith.addi %[[VAL_60]], %[[VAL_13]] : i32 // CHECK: %[[VAL_78:.*]] = arith.addi %[[VAL_61]], %[[VAL_15]] : i32 // CHECK: %[[VAL_79:.*]] = arith.cmpi sge, %[[VAL_78]], %[[VAL_6]] : i32 // CHECK: %[[VAL_80:.*]] = arith.select %[[VAL_79]], %[[VAL_12]], %[[VAL_78]] : i32 -// CHECK: %[[VAL_81:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.barrier_expect %[[VAL_81]], 49152, %[[VAL_65]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_82:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_77]]] %[[VAL_82]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x128x64> -// CHECK: %[[VAL_83:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 3x256x64> -// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_77]]] %[[VAL_83]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable, 
3x256x64> +// CHECK: %[[VAL_81:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.barrier_expect %[[VAL_81]], 49152, %[[VAL_65]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_82:.*]] = ttg.memdesc_index %[[VAL_43]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_35]]{{\[}}%[[VAL_39]], %[[VAL_77]]] %[[VAL_82]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_83:.*]] = ttg.memdesc_index %[[VAL_44]]{{\[}}%[[VAL_80]]{{\]}} : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_36]]{{\[}}%[[VAL_40]], %[[VAL_77]]] %[[VAL_83]], %[[VAL_81]], %[[VAL_65]] : !tt.tensordesc>, !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: scf.yield %[[VAL_76]]#0, %[[VAL_77]], %[[VAL_80]], %[[VAL_68]], %[[VAL_70]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 // CHECK: } // CHECK: %[[VAL_84:.*]] = ttng.warp_group_dot_wait %[[VAL_85:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_86:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.inval_barrier %[[VAL_86]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_87:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.inval_barrier %[[VAL_87]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: %[[VAL_88:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_7]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> -// CHECK: ttng.inval_barrier %[[VAL_88]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable, 3x1> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_12]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.inval_barrier %[[VAL_86]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_87:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_15]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.inval_barrier %[[VAL_87]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_index %[[VAL_45]]{{\[}}%[[VAL_7]]{{\]}} : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> +// CHECK: ttng.inval_barrier %[[VAL_88]] : !ttg.memdesc<1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttg.local_dealloc %[[VAL_45]] : !ttg.memdesc<3x1xi64, #[[$ATTR_3]], #[[$ATTR_5]], mutable> // CHECK: ttg.local_dealloc 
%[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> // CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #[[$ATTR_5]], mutable> diff --git a/test/TritonNvidiaGPU/mma_lowering.mlir b/test/TritonNvidiaGPU/mma_lowering.mlir index e6fd7929ee..5f0ad43ea4 100644 --- a/test/TritonNvidiaGPU/mma_lowering.mlir +++ b/test/TritonNvidiaGPU/mma_lowering.mlir @@ -88,8 +88,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { !ttg.memdesc<128x256xf8E5M2, #shared1, #ttg.shared_memory>, !ttg.memdesc<128x256xf32, #tmem, #ttng.tensor_memory, mutable> ttng.tc_gen5_commit %barrier, %barrierPred : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable> - %barrier_slice = ttg.memdesc_index %barrier2[%c0_i32] : !ttg.memdesc<2x1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<1xi64, #shared2, #smem, mutable, 2x1> - ttng.tc_gen5_commit %barrier_slice : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable, 2x1> + %barrier_slice = ttg.memdesc_index %barrier2[%c0_i32] : !ttg.memdesc<2x1xi64, #shared2, #smem, mutable> -> !ttg.memdesc<1xi64, #shared2, #smem, mutable> + ttng.tc_gen5_commit %barrier_slice : !ttg.memdesc<1xi64, #shared2, #ttg.shared_memory, mutable> ttng.tc_gen5_mma %a, %b, %c, %accUse, %pred {is_async} : !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, diff --git a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir index 0c22310bcc..52132242b8 100644 --- a/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir +++ b/test/TritonNvidiaGPU/test_tensor_memory_allocation.mlir @@ -320,11 +320,11 @@ tt.func @alloc_warp_specialize_explicit_capture_subview() { %c0_i32 = arith.constant 0 : i32 %b = ttg.memdesc_index %arg0[%c0_i32] : !ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xbf16, #shared1, #smem> - %a = ttg.memdesc_index %arg1[%c0_i32] : !ttg.memdesc<1x64x128xbf16, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xbf16, #tmem, #ttng.tensor_memory, mutable, 1x64x128> - %d = ttg.memdesc_index %arg2[%c0_i32] : !ttg.memdesc<1x64x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x64x128> + %a = ttg.memdesc_index %arg1[%c0_i32] : !ttg.memdesc<1x64x128xbf16, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xbf16, #tmem, #ttng.tensor_memory, mutable> + %d = ttg.memdesc_index %arg2[%c0_i32] : !ttg.memdesc<1x64x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable> %barrier = ttg.memdesc_index %arg3[%c0_i32] : !ttg.memdesc<2x1xi64, #shared, #smem, mutable> -> !ttg.memdesc<1xi64, #shared, #smem, mutable> - ttng.tc_gen5_mma %a, %b, %d, %true, %true, %barrier[%true] {is_async} : !ttg.memdesc<64x128xbf16, #tmem, #ttng.tensor_memory, mutable, 1x64x128>, !ttg.memdesc<128x128xbf16, #shared1, #smem>, !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable, 1x64x128>, !ttg.memdesc<1xi64, #shared, #smem, mutable> + ttng.tc_gen5_mma %a, %b, %d, %true, %true, %barrier[%true] {is_async} : !ttg.memdesc<64x128xbf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x128xbf16, #shared1, #smem>, !ttg.memdesc<64x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1xi64, #shared, #smem, mutable> ttg.warp_return } : (!ttg.memdesc<2x128x128xbf16, #shared1, #smem, mutable>, !ttg.memdesc<1x64x128xbf16, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x64x128xf32, #tmem, #ttng.tensor_memory, 
mutable>, !ttg.memdesc<2x1xi64, #shared, #smem, mutable>) -> () tt.return diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 2ef0816ff4..3cfd5b9bea 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -39,8 +39,8 @@ using ValueTable = std::map, Value>; ValueTable getValuesFromDotOperandLayoutStruct( ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, - int wmmaVer, Value value, int batch, int n0, int n1, int kBase, Type type, - Location loc) { + int wmmaVer, Value value, int batch, int n0, int n1, int kBase, + int kPadding, Type type, Location loc) { auto tb = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); ValueTable vals; @@ -50,11 +50,18 @@ ValueTable getValuesFromDotOperandLayoutStruct( Type elemTy = typeConverter->convertType(type); Type ty = vec_ty(elemTy, kBase); Value rawElems = tb.undef(ty); + for (int k = 0; k < kBase; ++k) { - rawElems = tb.insert_element( - ty, rawElems, - elems[n0 * n1 * kBase * b + kBase * (n1 * i + j) + k], - tb.i32_val(k)); + int idx = n0 * n1 * kBase * b + kBase * (n1 * i + j) + k; + if (k < kBase - kPadding) { + rawElems = + tb.insert_element(ty, rawElems, elems[idx], tb.i32_val(k)); + } else { + // pad with zeros + Value zero = rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)); + rawElems = tb.insert_element(ty, rawElems, zero, tb.i32_val(k)); + } } Value convertedElems; @@ -65,12 +72,14 @@ ValueTable getValuesFromDotOperandLayoutStruct( // Before wmma v3, bf16 is converted to i16 if (wmmaVer < 3) convertedElems = tb.bitcast(rawElems, vec_ty(i16_ty, kBase)); - } else if (kBase == 4 && type.getIntOrFloatBitWidth() == 8) { - convertedElems = tb.bitcast(rawElems, i32_ty); } else { - convertedElems = tb.bitcast( - rawElems, vec_ty(i32_ty, kBase * type.getIntOrFloatBitWidth() / - i32_ty.getIntOrFloatBitWidth())); + auto elems = kBase * type.getIntOrFloatBitWidth() / + i32_ty.getIntOrFloatBitWidth(); + assert(elems >= 1 && "unexpected number of elements"); + if (elems == 1) + convertedElems = tb.bitcast(rawElems, i32_ty); + else + convertedElems = tb.bitcast(rawElems, vec_ty(i32_ty, elems)); } vals[{b, i, j}] = convertedElems; } @@ -254,13 +263,19 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto numRepK = repA[2]; auto numRepB = repA[0]; - int kBase = maybeWmmaIntrinsic->kBase; + // If kDim > kDimTensor, we need add zeros to the kBase vector. The amount of + // zeros is determined by kBase * (1 - kDimTensor / kDim) + auto kBase = maybeWmmaIntrinsic->kBase; + auto kDimTensor = aTensorTy.getShape().back(); + auto paddingFactor = kDim > kDimTensor ? (kDim / kDimTensor) : 1; + auto kPadding = kBase - kBase / paddingFactor; + ValueTable ha = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedA, numRepB, numRepM, numRepK, - kBase, aTensorTy.getElementType(), loc); + kBase, kPadding, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedB, numRepB, numRepN, numRepK, - kBase, aTensorTy.getElementType(), loc); + kBase, kPadding, aTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); @@ -373,6 +388,13 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, int kBaseB = isFp4B ? kBase / 2 : kBase; int kDimB = isFp4B ? 
kDim / 2 : kDim; + bool isFp6A = (op.getAElemType() == triton::ScaleDotElemType::E2M3) || + (op.getAElemType() == triton::ScaleDotElemType::E3M2); + bool isFp6B = (op.getBElemType() == triton::ScaleDotElemType::E2M3) || + (op.getBElemType() == triton::ScaleDotElemType::E3M2); + if (isFp6A || isFp6B) + return op.emitError("NYI: FP6 scaled dot"); + auto repA = wmmaLayout.getRepForOperand(aTensorTy.getShape(), kDimA, 0); auto repB = wmmaLayout.getRepForOperand(bTensorTy.getShape(), kDimB, 1); @@ -388,23 +410,26 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, auto numRepK = repA[2]; auto numRepB = repA[0]; - auto scaleShapeA = aScaleTensorTy.getShape(); - constexpr int scaleKWidthA = 4; - auto scaleShapeB = bScaleTensorTy.getShape(); - constexpr int scaleKWidthB = 4; + // If kDim > kDimTensor, we need add zeros to the kBase vector. The amount of + // zeros is determined by kBase * (1 - kDimTensor / kDim) + auto kDimTensorA = aTensorTy.getShape().back(); + auto paddingFactor = kDimA > kDimTensorA ? (kDimA / kDimTensorA) : 1; + auto kPaddingA = kBaseA - kBaseA / paddingFactor; + auto kPaddingB = kBaseB - kBaseB / paddingFactor; + auto KBaseScale = 4; ValueTable ha = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedA, numRepB, numRepM, numRepK, - kBaseA, aTensorTy.getElementType(), loc); + kBaseA, kPaddingA, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedB, numRepB, numRepN, numRepK, - kBaseB, bTensorTy.getElementType(), loc); + kBaseB, kPaddingB, bTensorTy.getElementType(), loc); ValueTable sa = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedAScale, numRepB, numRepM, numRepK, - scaleKWidthA, aScaleTensorTy.getElementType(), loc); + KBaseScale, 0, aScaleTensorTy.getElementType(), loc); ValueTable sb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedBScale, numRepB, numRepN, numRepK, - scaleKWidthB, bScaleTensorTy.getElementType(), loc); + KBaseScale, 0, bScaleTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); @@ -438,11 +463,11 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, ? 
generateScaledWMMAIntrinsic( rewriter, loc, hb[{b, n, k}], sb[{b, n, k}], ha[{b, m, k}], sa[{b, m, k}], acc, scaledBElemType, - scaledAElemType, dstElemTy, scaleKWidthA) + scaledAElemType, dstElemTy, KBaseScale) : generateScaledWMMAIntrinsic( rewriter, loc, ha[{b, m, k}], sa[{b, m, k}], hb[{b, n, k}], sb[{b, n, k}], acc, scaledAElemType, - scaledBElemType, dstElemTy, scaleKWidthB); + scaledBElemType, dstElemTy, KBaseScale); } for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { fc[fcThreadOffIdx + v] = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 4e76cf1dbb..25c51c99d6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -304,7 +304,11 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { LogicalResult canWriteCoalesced(RewriterBase &rewriter, Operation *op, RankedTensorType srcTy, MemDescType dstTy, unsigned vectorSize, - bool hasSwizzling) const { + bool requiresSrcPtrSwizzling) const { + if (targetInfo.supportsDirectToLDSScattering()) { + return success(); + } + int vecBits = vectorSize * dstTy.getElementTypeBitWidth(); if (!targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)) { LDBG(*op << " results in unsupported load bitwidth: " << vecBits); @@ -322,15 +326,16 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { LinearLayout srcToSharedLayout = srcLayout.invertAndCompose(sharedLayout); unsigned threadsPerWarp = lookupThreadsPerWarp(rewriter); - if (!hasSwizzling && + if (!requiresSrcPtrSwizzling && !LLVM::AMD::canCoalesceWriteIntoSharedMemory( rewriter, srcToSharedLayout, threadsPerWarp, vectorSize)) { LDBG(*op << " does not write coalesced into LDS and is not swizzled"); return failure(); } - if (hasSwizzling && !LLVM::AMD::doesSwizzleInsideWarp( - rewriter, srcToSharedLayout, threadsPerWarp)) { + if (requiresSrcPtrSwizzling && + !LLVM::AMD::doesSwizzleInsideWarp(rewriter, srcToSharedLayout, + threadsPerWarp)) { LDBG(*op << " does swizzle across warp boundaries"); return failure(); } @@ -506,7 +511,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { LLVM::getSharedMemoryObjectFromStruct(loc, llDst, resElemTy, rewriter); auto affineOffset = smemObj.getShmemOffset(loc, rewriter, dstTy); auto maskSpanAffineOffset = SharedMemoryObject::getMaskSpanOffsets(dstTy); - auto [_, warpId] = getLaneAndWarpId(rewriter, loc); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); auto calcPaddedOffset = [&](Value smemOffset) { TritonLLVMOpBuilder b(loc, rewriter); auto bitwidth = dstTy.getElementTypeBitWidth(); @@ -519,26 +524,33 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { } return smemOffset; }; - // We pass laneId==0 because GFX9 requires a scalar base pointer into LDS + // If we do not support scattering (GFX9) the address should be the start + // address (scalar) of the warp + laneId = targetInfo.supportsDirectToLDSScattering() ? 
laneId : b.i32_val(0); lowerLdSt(loc, ctx, cvt, loadVals, resElemTy, smemObj.getBase(), - calcPaddedOffset, affineOffset, maskSpanAffineOffset, - b.i32_val(0), warpId, rewriter, targetInfo, vec, lowerInst); + calcPaddedOffset, affineOffset, maskSpanAffineOffset, laneId, + warpId, rewriter, targetInfo, vec, lowerInst); } void emitOtherStore(RewriterBase &rewriter, Location loc, const LLVMTypeConverter *typeConverter, VectorType vecTy, Value mask, ArrayRef otherElems, Value shmemAddr, - Value laneId, bool hasSwizzling, + Value laneId, bool requiresSrcPtrSwizzling, Value swizzleLaneOffset) const { TritonLLVMOpBuilder b(loc, rewriter); Value storeVal = packElementRangeIntoVector(rewriter, typeConverter, loc, vecTy, otherElems, 0); Type ptrTy = shmemAddr.getType(); - Value ldsAddr = b.gep(ptrTy, vecTy, shmemAddr, laneId); - if (hasSwizzling) - ldsAddr = b.gep(ptrTy, vecTy, ldsAddr, swizzleLaneOffset); + Value ldsAddr = shmemAddr; + // When scattering is unsupported, shmemAddr is the warp base address. + // Use shmemAddr + lane_id [+ swizzleOffset] to compute each lane's address. + if (!targetInfo.supportsDirectToLDSScattering()) { + ldsAddr = b.gep(ptrTy, vecTy, shmemAddr, laneId); + if (requiresSrcPtrSwizzling) + ldsAddr = b.gep(ptrTy, vecTy, ldsAddr, swizzleLaneOffset); + } llStore(rewriter, loc, ldsAddr, storeVal, b.icmp_ne(mask, b.true_val()), - CacheModifier::NONE, /*forceNoAliasAsyncLoads=*/true); + CacheModifier::NONE, targetInfo.requiresAliasInfoForAsyncOps()); } }; @@ -775,9 +787,11 @@ struct BufferLoadToLocalOpConversion } auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool hasSwizzling = maybeSwizzledEnc && maybeSwizzledEnc.getMaxPhase() != 1; + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; if (failed(canWriteCoalesced(rewriter, op, ptrType, dstTy, vec, - hasSwizzling))) { + requiresSrcPtrSwizzling))) { return failure(); } @@ -786,7 +800,7 @@ struct BufferLoadToLocalOpConversion auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; - if (hasSwizzling) { + if (requiresSrcPtrSwizzling) { // TODO (alex): this is only correct as long as the lds view is a // contiguous block. So this can break if we slice along the 2 minor // dimensions. @@ -817,10 +831,10 @@ struct BufferLoadToLocalOpConversion auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); auto emitBufferLoadLds = [this, &op, &b, &bufferEmitter, &rsrcDesc, laneId = laneId, threadPred, - offsetTy, otherTy, hasOther, - hasSwizzling](RewriterBase &rewriter, Location loc, - ArrayRef loadVals, Value shmemAddr, int startIdx, - VectorType vecTy) -> SmallVector { + offsetTy, otherTy, hasOther, requiresSrcPtrSwizzling]( + RewriterBase &rewriter, Location loc, ArrayRef loadVals, + Value shmemAddr, int startIdx, + VectorType vecTy) -> SmallVector { auto [offsetElem, maskElem, otherElems, swizzleLaneOffset] = unzipLoadValues(rewriter, loc, startIdx, loadVals, offsetTy, otherTy, hasOther, vecTy.getNumElements()); @@ -829,7 +843,7 @@ struct BufferLoadToLocalOpConversion Value vecBytesVal = b.i32_val(vecBits / 8); Value maybeSwizzledMaskElem = maskElem; - if (hasSwizzling) + if (requiresSrcPtrSwizzling) applySwizzling(rewriter, loc, offsetElem, maybeSwizzledMaskElem, laneId, swizzleLaneOffset); @@ -839,16 +853,17 @@ struct BufferLoadToLocalOpConversion Value cond = hasOther ? 
b.and_(threadPred, maybeSwizzledMaskElem) : threadPred; - auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, threadPred); + auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond); auto bufferLoadToLds = bufferEmitter.emitLoadToLds( vecTy, vecBytesVal, rsrcDesc, offsetElem, shmemAddr, hasOther ? b.true_val() : maybeSwizzledMaskElem, op.getCache()); - AMD::addAsyncCopyAliasScope(bufferLoadToLds); + if (targetInfo.requiresAliasInfoForAsyncOps()) + AMD::addAsyncCopyAliasScope(bufferLoadToLds); if (hasOther) { emitOtherStore(rewriter, loc, this->getTypeConverter(), vecTy, maskElem, - otherElems, shmemAddr, laneId, hasSwizzling, + otherElems, shmemAddr, laneId, requiresSrcPtrSwizzling, swizzleLaneOffset); } @@ -913,9 +928,12 @@ struct AsyncCopyGlobalToLocalOpConversion } auto maybeSwizzledEnc = dyn_cast(dstEnc); - bool hasSwizzling = maybeSwizzledEnc && maybeSwizzledEnc.getMaxPhase() != 1; - if (failed( - canWriteCoalesced(rewriter, op, srcTy, dstTy, vec, hasSwizzling))) { + bool requiresSrcPtrSwizzling = + !targetInfo.supportsDirectToLDSScattering() && maybeSwizzledEnc && + maybeSwizzledEnc.getMaxPhase() != 1; + + if (failed(canWriteCoalesced(rewriter, op, srcTy, dstTy, vec, + requiresSrcPtrSwizzling))) { return failure(); } @@ -923,7 +941,7 @@ struct AsyncCopyGlobalToLocalOpConversion // the LDS addresses since we gather into LDS auto flatDstTy = dstTy; SmallVector swizzledLaneOffsets; - if (hasSwizzling) { + if (requiresSrcPtrSwizzling) { auto flatSharedEnc = SwizzledSharedEncodingAttr::get( op->getContext(), maybeSwizzledEnc.getVec(), 1, 1, maybeSwizzledEnc.getOrder(), maybeSwizzledEnc.getCTALayout()); @@ -947,10 +965,10 @@ struct AsyncCopyGlobalToLocalOpConversion auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); auto emitGlobalLoadLds = [this, &op, &b, laneId = laneId, threadPred, srcPtrTy, otherTy, - hasOther, hasSwizzling](RewriterBase &rewriter, Location loc, - ArrayRef loadValues, Value shmemAddr, - int startIdx, - VectorType vecTy) -> SmallVector { + hasOther, requiresSrcPtrSwizzling]( + RewriterBase &rewriter, Location loc, ArrayRef loadValues, + Value shmemAddr, int startIdx, + VectorType vecTy) -> SmallVector { auto [srcElem, maskElem, otherElems, swizzleLaneOffset] = unzipLoadValues(rewriter, loc, startIdx, loadValues, srcPtrTy, otherTy, hasOther, vecTy.getNumElements()); @@ -958,7 +976,7 @@ struct AsyncCopyGlobalToLocalOpConversion assert(targetInfo.supportsDirectToLdsLoadBitWidth(vecBits)); Value maybeSwizzledMaskElem = maskElem; - if (hasSwizzling) + if (requiresSrcPtrSwizzling) applySwizzling(rewriter, loc, srcElem, maybeSwizzledMaskElem, laneId, swizzleLaneOffset); @@ -966,19 +984,14 @@ struct AsyncCopyGlobalToLocalOpConversion auto cond = b.and_(threadPred, maybeSwizzledMaskElem); auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond); - int32_t cacheModifiers = - mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget( - op.getCache(), /*isLoad=*/true, targetInfo); - auto globalLoadLdsOp = rewriter.create( - loc, srcElem, shmemAddr, vecBits / 8, - /*offset=*/0, cacheModifiers, nullptr, nullptr, nullptr); - AMD::addAsyncCopyAliasScope(globalLoadLdsOp); + emitAsyncLoad(rewriter, loc, targetInfo, vecBits, srcElem, shmemAddr, + op.getCache()); rewriter.setInsertionPointToStart(afterLoadBlock); if (hasOther) { emitOtherStore(rewriter, loc, this->getTypeConverter(), vecTy, maskElem, - otherElems, shmemAddr, laneId, hasSwizzling, + otherElems, shmemAddr, laneId, requiresSrcPtrSwizzling, swizzleLaneOffset); } @@ -995,6 +1008,33 @@ struct 
AsyncCopyGlobalToLocalOpConversion rewriter.replaceOp(op, zero); return success(); } + + void emitAsyncLoad(RewriterBase &rewriter, Location loc, + AMD::TargetInfo targetInfo, int vecBits, Value srcPtr, + Value shmemAddr, triton::CacheModifier cacheMod) const { + auto b = TritonLLVMOpBuilder(loc, rewriter); + int32_t cacheModifiers = + mlir::LLVM::AMD::getCtrlBitsForCacheModifierOnTarget( + cacheMod, /*isLoad=*/true, targetInfo); + + if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, + targetInfo.getISAFamily())) { + auto globalLoadLdsOp = rewriter.create( + loc, srcPtr, shmemAddr, vecBits / 8, + /*offset=*/0, cacheModifiers, nullptr, nullptr, nullptr); + if (targetInfo.requiresAliasInfoForAsyncOps()) + AMD::addAsyncCopyAliasScope(globalLoadLdsOp); + } else if (targetInfo.getISAFamily() == ISAFamily::GFX1250) { + if (cacheMod != triton::CacheModifier::NONE) { + emitRemark(loc) << "cache modifiers not yet implemented on gfx1250"; + } + std::string intrinsic = + "llvm.amdgcn.global.load.async.to.lds.b" + std::to_string(vecBits); + auto globalLoadLdsOp = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, intrinsic, {}, + {srcPtr, shmemAddr, b.i32_val(0), b.i32_val(cacheModifiers)}); + } + } }; struct AsyncTDMCopyGlobalToLocalOpConversion @@ -1856,38 +1896,46 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(AsyncWaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + switch (targetInfo.getISAFamily()) { case ISAFamily::CDNA1: case ISAFamily::CDNA2: case ISAFamily::CDNA3: - case ISAFamily::CDNA4: + case ISAFamily::CDNA4: { + // global.load.lds uses vmcnt to synchronize + // The rocdl op stores all available counters in a single int32 value (v). + // The vmcnt (6 bits) is split into a lower 3:0 and higher 5:4 parts. + // The lower part is stored in bits 3:0 of v and the higher part in bits + // 15:14. We have to set all other bits in v to 1 to signal we are not + // interested in those. + + // Clamp vmcnt to 6bits; a lower vmcnt will produce a conservative wait + unsigned vmCnt = std::min(63u, op.getNum()); + + // Extract low and high bits and combine while setting all other bits to 1 + unsigned lowBits = vmCnt & 0xF; + unsigned highBits = vmCnt >> 4 << 14; + unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set + unsigned waitValue = lowBits | highBits | otherCnts; + + rewriter.create(loc, waitValue); break; + } + case ISAFamily::GFX1250: { + // Clamp asyncCnt to 6 bits (hw limit); lower means conservative + unsigned asyncCnt = std::min(63u, op.getNum()); + LLVM::createLLVMIntrinsicCallOp(rewriter, loc, + "llvm.amdgcn.s.wait.asynccnt", {}, + {b.i16_val(asyncCnt)}); + break; + } default: return rewriter.notifyMatchFailure( op, "Only supported on CDNA target architecture"); } - auto loc = op->getLoc(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - // global.load.lds uses vmcnt to synchronize - // The rocdl op stores all available counters in a single int32 value (v). - // The vmcnt (6 bits) is split into a lower 3:0 and higher 5:4 parts. - // The lower part is stored in bits 3:0 of v and the higher part in bits - // 15:14. We have to set all other bits in v to 1 to signal we are not - // interested in those.
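// [Editor's note: illustrative sketch, not part of the patch.] To make the
// waitcnt packing above concrete: vmcnt[3:0] lands in bits 3:0, vmcnt[5:4]
// in bits 15:14, and every other counter field is left saturated so it is
// ignored. The helper name `encodeVmcnt` is invented for this sketch.
#include <cstdint>
constexpr uint32_t encodeVmcnt(uint32_t vmCnt) {
  vmCnt = vmCnt > 63u ? 63u : vmCnt;   // clamp to the 6-bit hardware limit
  uint32_t low = vmCnt & 0xF;          // vmcnt[3:0] -> bits 3:0
  uint32_t high = (vmCnt >> 4) << 14;  // vmcnt[5:4] -> bits 15:14
  return low | high | ~0xC00Fu;        // saturate all remaining counters
}
static_assert(encodeVmcnt(0) == 0xFFFF3FF0u, "vmcnt=0: wait for all vm loads");
static_assert(encodeVmcnt(2) == 0xFFFF3FF2u, "vmcnt=2: allow 2 loads in flight");
static_assert(encodeVmcnt(63) == 0xFFFFFFFFu, "vmcnt=63: effectively no wait");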
- - // Clamp vmcnt to 6bits; a lower vmcnt will produce a conservative wait - unsigned vmCnt = std::min(63u, op.getNum()); - - // Extract low and high bits and combine while setting all other bits to 1 - unsigned lowBits = vmCnt & 0xF; - unsigned highBits = vmCnt >> 4 << 14; - unsigned otherCnts = ~0xC00F; // C00F has bits 15:14 and 3:0 set - unsigned waitValue = lowBits | highBits | otherCnts; - - rewriter.create(loc, waitValue); - // Drop the result AsyncToken rewriter.replaceOp(op, b.i32_val(0)); return success(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 2bd2e7c267..783ba0c355 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -81,8 +81,15 @@ int TargetInfo::getWarpSize() const { } int TargetInfo::getSharedMemorySize() const { - int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64; - return kbytes * 1024; + // Should return the maximum capacity in kbyte + switch (getISAFamily()) { + case ISAFamily::GFX1250: + return 320 * 1024; + case ISAFamily::CDNA4: + return 160 * 1024; + default: + return 64 * 1024; + } } bool TargetInfo::supportMaximumMinimum() const { @@ -601,6 +608,30 @@ bool TargetInfo::supportVectorizedAtomics() const { return true; } +bool TargetInfo::supportsDirectToLDSScattering() const { + switch (getISAFamily()) { + case ISAFamily::GFX1250: + return true; + case ISAFamily::CDNA3: + case ISAFamily::CDNA4: + return false; + default: + llvm::report_fatal_error( + "Unsupported architecture for direct to lds loads"); + return false; + } +} + +bool TargetInfo::requiresAliasInfoForAsyncOps() const { + switch (getISAFamily()) { + case ISAFamily::CDNA3: + case ISAFamily::CDNA4: + return true; + default: + return false; + } +} + bool TargetInfo::supportsDirectToLdsLoadBitWidth(int bitWidth) const { switch (getISAFamily()) { case ISAFamily::CDNA3: @@ -609,6 +640,10 @@ bool TargetInfo::supportsDirectToLdsLoadBitWidth(int bitWidth) const { case ISAFamily::CDNA4: // Disable 8, 16, 96 bits because they get extended to 32/128 bit. return llvm::is_contained({128, /*96, */ 32, /*16, 8*/}, bitWidth); + case ISAFamily::GFX1250: + // Disable 8, 16 bits because they get extended to 32 bit and therefore + // overwrite. 96 is not a pow2 and generally not useful in Triton + return llvm::is_contained({128, 64, /*96, */ 32, /*16, 8*/}, bitWidth); default: break; } @@ -618,7 +653,8 @@ bool TargetInfo::supportsDirectToLdsLoadBitWidth(int bitWidth) const { void TargetInfo::localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp, Operation *llLoadOp) const { - AMD::addLocalLoadNoAliasScope(localLoadOp, cast(llLoadOp)); + if (requiresAliasInfoForAsyncOps()) + AMD::addLocalLoadNoAliasScope(localLoadOp, cast(llLoadOp)); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 3172248cbf..8772506fd0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -86,6 +86,16 @@ class TargetInfo : public mlir::triton::TargetInfoBase { bool supportVectorizedAtomics() const override; + // Returns true if the target supports per lane addresses into LDS for + // direct-to-lds loads. Some architectures (e.g. 
GFX9) do not support + // scattering and instead have to write warp coalesced into LDS + bool supportsDirectToLDSScattering() const; + + // Some architectures (GFX9) require alias information on direct-to-lds loads + // and loads from LDS so LLVM does not add conservative waits between those + // ops. For such cases we ensure synchronization for data hazards via + // ttg.async_wait + bool requiresAliasInfoForAsyncOps() const; bool supportsDirectToLdsLoadBitWidth(int bitWidth) const; void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index ff0590597a..8f04f3f17d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -119,7 +119,8 @@ struct ConvertTritonAMDGPUToLLVM // Allocate shared memory and set barrier ModuleAllocation allocation(mod); - AMD::annotateLocalLoadsSyncedViaAsyncWait(mod); + if (targetInfo.requiresAliasInfoForAsyncOps()) + AMD::annotateLocalLoadsSyncedViaAsyncWait(mod); ModuleMembarAnalysis membarPass(&allocation, mlir::triton::AMD::membarFilter); membarPass.run(); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f17a02dd53..fdeb6a825e 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -453,43 +453,36 @@ Value findScaleAsDecompositionSource(Value v) { return {}; } -// Figure out a best tilesPerWarp parameter that gives largest vector size for -// global load for the given |scale| tensor feeding into dot_scaled op. Returns -// the largest vector size and writes the choice to |result|. -int deduceTilesPerWarp(TypedValue scale, unsigned opIdx, - unsigned nonKDim, ArrayRef warpsPerCTA, - SmallVectorImpl *result) { - std::array chosen{1, 1}; - int vecSize = 1; - if (!scale) { - result->assign(chosen.begin(), chosen.end()); - return vecSize; - } - +// Figure out the best tilesPerWarp that gives largest vector size for |scale| +// tensors feeding into dot_scaled op. +SmallVector deduceTilesPerWarpForScale( + TypedValue scaleA, TypedValue scaleB, + unsigned nonKDim, unsigned m, unsigned n, ArrayRef warpsPerCTA) { // Source code have flexibility to preshuffle scale tensor to achieve better // global load vectorization. That preshuffle scheme is conveyed via some // tl.reshape and tl.trans op combinations. Instead of hardcoding one case or // pattern match the op chain here, we try certain scale tensor layouts and // see which one gives us better vectorization when pushed upwards to the // global load. - // - // For 16x16x128 scaled MFMA intrinsic, each thread only reads one i8 value. - // For better vectorization, we prefer to stick 2x2 such intrinsic together so - // each thread can read 4xi8 values.
- SmallVector, 2> choices{{2, 2}, {1, 1}}; - for (const auto &choice : choices) { - LLVM_DEBUG(llvm::dbgs() - << "choice: [" << choice[0] << ", " << choice[1] << "]\n"); + auto inferScaleSrcVecSize = + [&](unsigned opIdx, TypedValue scale, + SmallVector tilesPerWarp) -> unsigned { + if (!scale) + return 1; + LinearLayout layout = ttg::chooseScaledMfmaScaleLayout( - scale.getContext(), opIdx, scale.getType().getShape(), nonKDim, choice, - warpsPerCTA); + scale.getContext(), opIdx, scale.getType().getShape(), nonKDim, + tilesPerWarp, warpsPerCTA); LLVM_DEBUG(llvm::dbgs() << "trying scale layout: " << layout << "\n"); + auto scaleDef = scale.getDefiningOp(); + // assume vec=4 for constant scale + if (isa_and_nonnull(scaleDef)) + return 4; // Infer source layout used for global load using the current scale layout. - auto loadLayoutPair = - ttg::inferSourceLoadLayout(layout, scale.getDefiningOp()); + auto loadLayoutPair = ttg::inferSourceLoadLayout(layout, scaleDef); if (!loadLayoutPair) - continue; + return 1; tt::LoadOp loadOp = loadLayoutPair->first; const LinearLayout &inferredLayout = loadLayoutPair->second; LLVM_DEBUG(llvm::dbgs() @@ -507,18 +500,46 @@ int deduceTilesPerWarp(TypedValue scale, unsigned opIdx, auto sharedLL = triton::gpu::toLinearLayout(loadType.getShape(), passThruShared); auto composedLL = inferredLayout.invertAndCompose(sharedLL).flattenOuts(); + LLVM_DEBUG(llvm::dbgs() + << "inferred composed layout: " << composedLL << "\n"); auto [v, _] = largestVectorisation(context, composedLL, /*bitwidth=*/8, std::nullopt); + return v; + }; - if (v > vecSize) { - LLVM_DEBUG(llvm::dbgs() << "found vector size: " << v << "\n"); - chosen = choice; - vecSize = v; - break; + unsigned largest = 2; + SmallVector chosen{1, 1}; + // For scaled MFMA intrinsic, each thread only reads one i8 value. + // For better vectorization, we prefer to stick tilesPerWarp 2x2 for 16x16x128 + // and 1x1 for 32x32x64 so that each thread can read 4xi8 values. 
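// [Editor's note: illustrative only, not part of the patch.] The search loop
// that follows scores each candidate tilesPerWarp by the sum of the inferred
// global-load vector widths for the two scale operands; with made-up widths:
//   {1, 1}: vecSizeA = 1, vecSizeB = 1 -> score 2 (equals the baseline
//           `largest = 2`, so the default {1, 1} is kept)
//   {2, 2}: vecSizeA = 4, vecSizeB = 4 -> score 8 -> chosen
// Only a strictly larger score replaces the current choice, so candidates
// that do not improve vectorization leave tilesPerWarp at {1, 1}.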
+ // limit tilesPerWarp to block boundary + for (unsigned mDimTiles = 1; mDimTiles <= std::min(2u, m / nonKDim); + mDimTiles++) { + for (unsigned nDimTiles = 1; nDimTiles <= std::min(2u, n / nonKDim); + nDimTiles++) { + SmallVector tilesPerWarp{mDimTiles, nDimTiles}; + unsigned vecSizeA = inferScaleSrcVecSize(0, scaleA, tilesPerWarp); + unsigned vecSizeB = inferScaleSrcVecSize(1, scaleB, tilesPerWarp); + LLVM_DEBUG(llvm::dbgs() << "when tilesPerWarp: " << tilesPerWarp[0] + << ", " << tilesPerWarp[1] << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "inferred scaleA vecSize: " << vecSizeA << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "inferred scaleB vecSize: " << vecSizeB << "\n"); + unsigned score = vecSizeA + vecSizeB; + if (score > largest) { + largest = score; + chosen = tilesPerWarp; + } } } - result->assign(chosen.begin(), chosen.end()); - return vecSize; + assert(largest <= 8 && "at most pack 4 scales for scale a & b respectively"); + // fixup: align with dimension that has scale + if (!scaleA && scaleB) + chosen[0] = std::min(ceil(m, nonKDim), chosen[1]); + if (!scaleB && scaleA) + chosen[1] = std::min(ceil(n, nonKDim), chosen[0]); + return chosen; } class BlockedToMFMA : public OpRewritePattern { @@ -611,11 +632,10 @@ class BlockedToMFMA : public OpRewritePattern { SmallVector tilesA{1, 1}, tilesB{1, 1}; Value scaleA = findScaleAsDecompositionSource(a); Value scaleB = findScaleAsDecompositionSource(b); - int vecA = deduceTilesPerWarp(dyn_cast_if_present(scaleA), 0, - mDim, warpsPerTile, &tilesA); - int vecB = deduceTilesPerWarp(dyn_cast_if_present(scaleB), 1, - mDim, warpsPerTile, &tilesB); - tilesPerWarp = vecA > vecB ? tilesA : tilesB; + tilesPerWarp = deduceTilesPerWarpForScale( + dyn_cast_if_present(scaleA), + dyn_cast_if_present(scaleB), mDim, retShape[0], + retShape[1], warpsPerTile); LLVM_DEBUG(llvm::dbgs() << "chosen tilesPerWarp: [" << tilesPerWarp[0] << ", " << tilesPerWarp[1] << "]\n"); } @@ -1066,10 +1086,8 @@ class ScaledBlockedToScaledMFMAF8F6F4 final auto warpsPerTile = warpsPerTileMFMA(dotOp, oldShape, numWarps, {mDim, nDim}); - SmallVector tilesA{1, 1}, tilesB{1, 1}, tilesPerWarp; - int vecA = deduceTilesPerWarp(aScale, 0, mDim, warpsPerTile, &tilesA); - int vecB = deduceTilesPerWarp(bScale, 1, mDim, warpsPerTile, &tilesB); - tilesPerWarp = vecA > vecB ? tilesA : tilesB; + SmallVector tilesPerWarp = deduceTilesPerWarpForScale( + aScale, bScale, mDim, oldShape[0], oldShape[1], warpsPerTile); LLVM_DEBUG(llvm::dbgs() << "chosen tilesPerWarp: [" << tilesPerWarp[0] << ", " << tilesPerWarp[1] << "]\n"); @@ -1249,9 +1267,10 @@ class ScaledBlockedToScaledWMMAF8F6F4 final ScaleDotElemType aElemType = dotOp.getAElemType(); ScaleDotElemType bElemType = dotOp.getBElemType(); - // TODO: Add more supported types auto supportsTypes = [](ScaleDotElemType elemType) { - return elemType == ScaleDotElemType::E2M1; + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2; }; if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) { @@ -1315,9 +1334,6 @@ class ScaledBlockedToScaledWMMAF8F6F4 final auto convertScaleLayout = [&](TensorValue scale, llvm::ArrayRef valShape, LinearLayout dotLL, int idx) -> Value { - LinearLayout::BasesT scaleBases = dotLL.getBases(); - auto &warpBases = scaleBases[kWarp]; - SmallVector shape; if (!scale) { int64_t nonKDim = idx == 0 ? 
valShape[0] : valShape[1]; @@ -1330,7 +1346,7 @@ class ScaledBlockedToScaledWMMAF8F6F4 final } LinearLayout newLL = - ttg::chooseScaledWmmaScaleLayout(ctx, idx, warpBases, shape); + ttg::chooseScaledWmmaScaleLayout(ctx, idx, warpsPerTile, shape); Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL); // Scale's data type is always i8 auto newScaleType = RankedTensorType::get(shape, i8_ty, newScaleEncoding); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index f2347e269a..8a3237740b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -590,6 +590,8 @@ struct TritonAMDGPUConvertToBufferOpsPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); ModuleOp mod = getOperation(); + auto arch = getAMDArch(mod); + triton::AMD::TargetInfo targetInfo(arch ? arch->str() : ""); // Collect assumptions in the function DenseMap> assumptions = @@ -605,9 +607,15 @@ struct TritonAMDGPUConvertToBufferOpsPass AMD::ModuleAxisInfoAnalysis axisInfoAnalysis(mod); patterns.add, - ConvertTritonLoadToBufferLoad, ConvertTritonStoreToBufferStore>(context, assumptions, solver, this->analyzeSmallTensorOfst); + // BufferLoadToLds is only supported on CDNA3 and CDNA4 + if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4}, + targetInfo.getISAFamily())) { + patterns + .add>( + context, assumptions, solver, this->analyzeSmallTensorOfst); + } // Gate buffer atomics behind CDNA3 for now // GFX942-specific assumptions regarding cache coherence are made when diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp index fc82ae7df2..ff8459a9f4 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/UpdateAsyncWaitCount.cpp @@ -57,7 +57,8 @@ int getNumberOfLoadInstructions(RankedTensorType srcTy, // [token] -> ttg.async_commit_group -> [token] -> ttg.async_wait. So here we // scan the operands of ttg.async_commit_group to count the number of issued // async load intrinsics. -int getNumberOfLoadInstructions(Operation *op) { +int getNumOfAsyncLoadInstructionsForOp(Operation *op, + bool emitRemarkOnNonAsyncOp) { if (isa(op)) { int count = 0; for (auto token : op->getOperands()) { @@ -76,7 +77,8 @@ int getNumberOfLoadInstructions(Operation *op) { } return count; } - if (isa(op)) { op->emitRemark("Global memory operation between async wait and " @@ -91,7 +93,10 @@ int getNumberOfLoadInstructions(Operation *op) { // waitcnt to represent the number of hardware instructions we are // interleaving with. This allows us to manually emit the waitcnt during // lowering. 
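// [Editor's note: hypothetical usage, not part of the patch.] The templated
// form below takes the per-op cost as a callback, so the same def-chain
// traversal could be reused for a wait op that drains a different hardware
// counter (see the TDM note further down); all names here are invented:
//   updateWaitCount(tdmWaitOp,
//                   [&](Operation *op) { return countTDMInstructions(op); },
//                   rewriter);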
-void updateWaitCount(ttg::AsyncWaitOp waitOp, RewriterBase &rewriter) { +template +void updateWaitCount(WaitType waitOp, + llvm::function_ref computeCountForOp, + RewriterBase &rewriter) { int waitCnt = std::numeric_limits::max(); // AsyncWait can await multiple tokens so we get the minimum from all @@ -100,9 +105,7 @@ void updateWaitCount(ttg::AsyncWaitOp waitOp, RewriterBase &rewriter) { // Traverse def chain from waitOp to the producer of the token and count // the minumum number of vmcnt instructions auto tokenWaitCnt = - deduceMinCountOnDefChain(token, waitOp, [](Operation *op) { - return getNumberOfLoadInstructions(op); - }); + deduceMinCountOnDefChain(token, waitOp, computeCountForOp); waitCnt = std::min(waitCnt, tokenWaitCnt); } @@ -125,15 +128,36 @@ struct TritonAMDGPUUpdateAsyncWaitCountPass return; } + // For HW which does not support async loads (GFX9) but only direct-to-lds, + // we still use the waitcnt to support interleaving of direct-to-lds loads + // when pipelining. The flag is used to emit warnings in case we find + // tt.loads/store which make the computed count conservative and hinder + // performance. + bool supportsAsyncLoads = true; + switch (targetInfo.getISAFamily()) { + case triton::AMD::ISAFamily::CDNA3: + case triton::AMD::ISAFamily::CDNA4: + supportsAsyncLoads = false; + break; + default: + break; + } + ModuleOp m = getOperation(); SmallVector waitOps; getOperation()->walk( [&](ttg::AsyncWaitOp waitOp) { waitOps.push_back(waitOp); }); + // Note: AsyncWaits should ignore TDM ops; different HW counter for (auto waitOp : waitOps) { IRRewriter builder(waitOp->getContext()); - updateWaitCount(waitOp, builder); + updateWaitCount( + waitOp, + [&](Operation *op) { + return getNumOfAsyncLoadInstructionsForOp(op, !supportsAsyncLoads); + }, + builder); } } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 5fb49e6985..2bb12ec65f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -1,6 +1,7 @@ #include "Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Tools/LayoutUtils.h" #include @@ -159,7 +160,7 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4( return {}; } - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + unsigned bitWidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType()); unsigned elemByteWidth = std::max(bitWidth / 8u, 1u); auto loadBytes = shape[0] * shape[1] * elemByteWidth; if (loadBytes < 16384) { diff --git a/third_party/amd/python/example/gluon/f16_gemm_gfx1250.py b/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py similarity index 99% rename from third_party/amd/python/example/gluon/f16_gemm_gfx1250.py rename to third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py index 7c7a64c34a..21ef1b0814 100644 --- a/third_party/amd/python/example/gluon/f16_gemm_gfx1250.py +++ b/third_party/amd/python/examples/gluon/f16_gemm_gfx1250.py @@ -19,6 +19,7 @@ class PersistentTileScheduler: pid_end: ttgl.tensor num_pid_m: ttgl.tensor + @gluon.constexpr_function def __init__(self, pid_start, pid_end, num_pid_m): self.pid_start = pid_start self.pid_end = pid_end diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index ceb896c10f..56e4ebf17d 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -333,6 +333,7 @@ 
def gluon_to_ttgir(self, src, metadata, options, capability): passes.gluon.add_inliner(pm) passes.gluon.add_resolve_auto_encodings(pm) + nvidia.passes.ttnvgpuir.add_tma_lowering(pm) passes.gluon.add_canonicalizer(pm) passes.common.add_sccp(pm) passes.ttir.add_loop_aware_cse(pm) diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index 37571a0066..dd01b60aff 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -157,8 +157,7 @@ static Value createBufferView(OpBuilderWithAsyncTaskIds &builder, Value alloc, allocDescType.getShape().end()); auto viewDescType = triton::gpu::MemDescType::get( shape, allocDescType.getElementType(), allocDescType.getEncoding(), - allocDescType.getMemorySpace(), allocDescType.getMutableMemory(), - /*allocShape=*/allocDescType.getAllocShape()); + allocDescType.getMemorySpace(), allocDescType.getMutableMemory()); return builder.create(alloc.getLoc(), viewDescType, alloc, idx); } diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp index 2889a78ca0..b5c4e518b5 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp @@ -145,7 +145,7 @@ int getTxCount(Operation *descOp) { auto encoding = getEncodingFromDescriptor(descOp, tensorType, desc); auto shapePerCTA = getShapePerCTA(encoding, tensorType.getShape()); return product(shapePerCTA) * - tensorType.getElementType().getIntOrFloatBitWidth() / 8; + getIntOrFloatOrPtrBitWidth(tensorType.getElementType()) / 8; } void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp, diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8c7dfc7771..22f82ba6a9 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -256,7 +256,8 @@ class LoadAcquireOpPattern : public OpRewritePattern { auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); Type valueTy = op.getType(); - const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth()); + const unsigned valueNBits = + std::max(8u, (unsigned)getIntOrFloatOrPtrBitWidth(valueTy)); const size_t maxWordWidth = std::max(32, valueNBits); const size_t width = std::min((size_t)valueNBits, maxWordWidth); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 270a2fc91a..f1ed5161a3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -559,7 +559,7 @@ struct FDivOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned bitwidth = getIntOrFloatOrPtrBitWidth(elemTy); StringRef name; Type resultTy; if (32 == bitwidth) { @@ -643,7 +643,7 @@ struct ExpOpConversionApprox Location loc) const { auto b = TritonLLVMOpBuilder(loc, rewriter); // For non-FP32 input, call __nv_expf for higher-precision calculation - if (elemTy.getIntOrFloatBitWidth() != 32) + if (getIntOrFloatOrPtrBitWidth(elemTy) != 
32) return {}; const double log2e = 1.4426950408889634; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index c860a05fac..0611453915 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1717,7 +1717,7 @@ LogicalResult AsyncTMAGatherOpConversion::matchAndRewrite( auto callback = [&](Value pred, Value shMemPtr, Value yOffset, ArrayRef xOffsets) { std::string tmaInst = "@$0 cp.async.bulk.tensor.2d.tile::gather4.shared" - "::cluster.global.mbarrier::complete_tx::bytes " + "::cta.global.mbarrier::complete_tx::bytes " "[$1], [$2, {$3, $4, $5, $6, $7}], [$8];"; PTXBuilder ptxBuilder; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index bd661f26a7..3b3008bae7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -195,7 +195,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, auto vecTy = cast(val.getType()); Type elemTy = vecTy.getElementType(); unsigned vec = vecTy.getNumElements(); - unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); assert(llvm::isPowerOf2_32(vec)); if (elemBitwidth < 8) { @@ -213,7 +213,11 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, if (!elemTy.isInteger()) { SmallVector vals = unpackLLVector(loc, val, rewriter); for (Value &v : vals) { - v = b.bitcast(v, int_ty(elemBitwidth)); + if (isa(v.getType())) { + v = b.ptrtoint(int_ty(elemBitwidth), v); + } else { + v = b.bitcast(v, int_ty(elemBitwidth)); + } } storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), pred); @@ -316,7 +320,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, auto vecTy = cast(loadTy); Type elemTy = vecTy.getElementType(); unsigned vec = vecTy.getNumElements(); - unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); assert(llvm::isPowerOf2_32(vec)); if (elemBitwidth < 8) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 593adbc750..dd253e27ec 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -163,7 +163,7 @@ LogicalResult lowerLdStMatrix( auto kOffset = S("offset"); auto kAddr = S("addr"); auto smemPtrTy = ptr_ty(ctx, 3); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); // In the contiguous case we can pack elements <= 32 bits // In the transpose case we just have the b8 and b16 cases if ((!transpose && bitwidth > 32) ||