diff --git a/include/triton/Dialect/Triton/IR/Utility.h b/include/triton/Dialect/Triton/IR/Utility.h index 896bf7316a..28b5ac824f 100644 --- a/include/triton/Dialect/Triton/IR/Utility.h +++ b/include/triton/Dialect/Triton/IR/Utility.h @@ -1,6 +1,8 @@ #ifndef TRITON_IR_UTILITY_H_ #define TRITON_IR_UTILITY_H_ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include #include @@ -10,6 +12,14 @@ namespace mlir { // Bitwidth of pointers constexpr int kPtrBitWidth = 64; +// Returns the bit width of a type, treating pointer-like types as 64-bit. +// This handles LLVM dialect pointer types. +inline int getIntOrFloatOrPtrBitWidth(Type type) { + if (isa(type)) + return kPtrBitWidth; + return type.getIntOrFloatBitWidth(); +} + template SmallVector convertType(ArrayRef in) { SmallVector out; for (const auto &i : in) diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index bedee8e604..fed4ded91f 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -137,10 +137,9 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, ArrayRef tilesPerWarp, ArrayRef warpsPerCTA); -LinearLayout chooseScaledWmmaScaleLayout( - MLIRContext *ctx, int dotOperandIdx, - const std::vector> &dotOperandWarpBasis, - ArrayRef dotOperandShape); +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef warpsPerCTA, + ArrayRef dotOperandShape); LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, ArrayRef shape, int opIdx, diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 313dbd41b6..983b6645b8 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -152,7 +152,8 @@ class AllocationAnalysis { auto shapePerCTA = gpu::getAllocationShapePerCTA(allocType); numElems = product(shapePerCTA); } - int64_t bytes = numElems * allocType.getElementTypeBitWidth() / 8; + int64_t bytes = + numElems * getIntOrFloatOrPtrBitWidth(allocType.getElementType()) / 8; auto alignment = alloc.getAlignmentOrDefault(); allocation->addBuffer(alloc, bytes, diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 336667d129..f5106b11a7 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { auto lhsInfo = operands[0]->getValue(); auto rhsInfo = operands[1]->getValue(); auto rank = lhsInfo.getRank(); + assert(isa(op.getType()) || + rank == 1 && "Expected ranked tensor or scalar"); assert(operands.size() == 2 && "Expected two operands"); + auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + AxisInfo::DimVectorT constancy = + resTy ? 
to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT contiguity(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); + return AxisInfo(contiguity, divisibility, constancy, constantValue); + } AxisInfo::DimVectorT contiguity; AxisInfo::DimVectorT divisibility; AxisInfo::DimVectorT constancy; - auto constantValue = getConstantValue(op, lhsInfo, rhsInfo); for (auto d = 0; d < rank; ++d) { - if (constantValue.has_value()) { - contiguity.push_back(1); - constancy.push_back( - std::max(lhsInfo.getConstancy(d), rhsInfo.getConstancy(d))); - divisibility.push_back( - highestPowOf2Divisor(constantValue.value())); - } else { - contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); - constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); - divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); - } + contiguity.push_back(getContiguity(op, lhsInfo, rhsInfo, d)); + constancy.push_back(getConstancy(op, lhsInfo, rhsInfo, d)); + divisibility.push_back(getDivisibility(op, lhsInfo, rhsInfo, d)); } return AxisInfo(contiguity, divisibility, constancy, constantValue); } @@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl { virtual int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) { - return 1; + return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); } - virtual std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) { return {}; @@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl { } }; +class UnrealizedConversionCastOpAxisInfoVisitor final + : public AxisInfoVisitorImpl { +public: + using AxisInfoVisitorImpl< + mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl; + + AxisInfo + getAxisInfo(mlir::UnrealizedConversionCastOp op, + ArrayRef *> operands) override { + auto tensorType = dyn_cast(op.getResultTypes()[0]); + if (tensorType && + tensorType.getRank() != operands[0]->getValue().getRank()) { + // Do not propagate AxisInfo with incorrect rank. This can cause a crash + // in future visitor applications. 
+ return AxisInfo::getPessimisticValueState(op->getResult(0)); + } + return operands[0]->getValue(); + } +}; + class MakeRangeOpAxisInfoVisitor final : public AxisInfoVisitorImpl { public: @@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return gcd(lhs.getDivisibility(dim), rhsDivisibility); } - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return std::max(lhsContiguity, rhsContiguity); } - int64_t getConstancy(arith::MulIOp op, const AxisInfo &lhs, - const AxisInfo &rhs, int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { auto lhsDivisibility = lhs.getDivisibility(dim); @@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs) override { - if (lhs.getConstantValue().has_value() && - rhs.getConstantValue().has_value()) - return {lhs.getConstantValue().value() * rhs.getConstantValue().value()}; + auto lhsConst = lhs.getConstantValue(); + auto rhsConst = rhs.getConstantValue(); + if (lhsConst.has_value() && rhsConst.has_value()) + return {lhsConst.value() * rhsConst.value()}; + if ((lhsConst.has_value() && lhsConst.value() == 0) || + (rhsConst.has_value() && rhsConst.value() == 0)) + return 0; return {}; } }; @@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { auto resTy = dyn_cast(op.getType()); + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); if (!resTy) - return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); + return constancy; auto shape = resTy.getShape(); - // Case 1: both lhs and rhs are constants. - auto constancy = gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - // Case 2: lhs contiguous, rhs constant. + // Case: lhs contiguous, rhs constant. // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p), @@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl { int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { + auto constancy = BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); auto resTy = dyn_cast(op.getType()); if (!resTy) - return BinaryOpVisitorImpl::getConstancy(op, lhs, rhs, dim); - auto shape = resTy.getShape(); - // lhs % 1 = 0 - return rhs.getConstantValue().has_value() && - rhs.getConstantValue().value() == 1 - ? 
shape[dim] - : gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); + return constancy; + // Case: lhs % 1 = 0 + if (rhs.getConstantValue().has_value() && + rhs.getConstantValue().value() == 1) + return resTy.getDimSize(dim); + return constancy; } std::optional getConstantValue(OpTy op, const AxisInfo &lhs, @@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl { int64_t constHint = 1; if (lhsInfo.getConstantValue().has_value() && rhsInfo.getConstantValue().has_value()) { - constHint = lhsInfo.getConstancy(d); + constHint = shape[d]; constantValue = compare(getPredicate(op), lhsInfo.getConstantValue().value(), rhsInfo.getConstantValue().value()) @@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl { rhsInfo.getConstantValue().has_value() && lhsInfo.getConstantValue() == rhsInfo.getConstantValue()) constantValue = lhsInfo.getConstantValue(); + + if (constantValue.has_value()) { + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + constancy = + resTy ? to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + } } return AxisInfo(contiguity, divisibility, constancy, constantValue); @@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl { using BinaryOpVisitorImpl::BinaryOpVisitorImpl; private: - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl { return multiplyDivisor(lhsDivisibility, 1ll << shift); } - int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, - const AxisInfo &rhs, int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(arith::ShLIOp op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl { return std::max(1, lhsDivisibility / (int64_t(1) << shift)); } - int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, - int dim) override { - return gcd(lhs.getConstancy(dim), rhs.getConstancy(dim)); - } - std::optional getConstantValue(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && @@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl { constantValue = {std::min(lhsInfo.getConstantValue().value(), rhsInfo.getConstantValue().value())}; } + auto resTy = dyn_cast(op.getType()); + assert(resTy || rank == 1); + AxisInfo::DimVectorT constancy = + resTy ? 
to_vector(resTy.getShape()) : AxisInfo::DimVectorT(rank, 1); + AxisInfo::DimVectorT divisibility( + rank, highestPowOf2Divisor(constantValue.value())); return AxisInfo(/*knownContiguity=*/AxisInfo::DimVectorT(rank, 1), - /*knownDivisibility=*/AxisInfo::DimVectorT(rank, 1), - /*knownConstancy=*/AxisInfo::DimVectorT(rank, 1), + /*knownDivisibility=*/divisibility, + /*knownConstancy=*/constancy, /*constantValue=*/constantValue); } else { AxisInfo::DimVectorT contiguity, divisibility, constancy; @@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver, // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is // in the process of a PartialConversion, where UnrealizedConversionCast // may exist + visitors.append(); visitors.append, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, - CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); visitors.append(); visitors.append(); @@ -1384,7 +1397,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, callee.setArgAttr(index, attrName, attr); }; auto axisInfo = axisInfoMap->lookup(value); - assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + // Only scalar arguments are supported. Do not forward multi-dimensional + // AxisInfo to the callee. + if (axisInfo.getRank() != 1) + continue; setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); setAttrFn("tt.constancy", axisInfo.getConstancy(0)); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 24f7a444a7..73c913b176 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -26,8 +26,6 @@ struct ConvertLayoutOpConversion : public ConvertOpToLLVMPattern { const TargetInfoBase &targetInfo; - // Set benefit to 2 so that this pattern applies before other convert-layout - // conversions. TODO(jlebar): Eventually we want this to be the only pattern. explicit ConvertLayoutOpConversion(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit = 1) @@ -277,8 +275,7 @@ struct ConvertLayoutOpConversion StringAttr kReg = str_attr("register"); StringAttr kLane = str_attr("lane"); auto elemTy = getTypeConverter()->convertType(srcTy.getElementType()); - int bitwidth = - elemTy.isIntOrFloat() ? elemTy.getIntOrFloatBitWidth() : kPtrBitWidth; + int bitwidth = getIntOrFloatOrPtrBitWidth(elemTy); auto factors = getWarpLayoutConvertDecomposition(srcTy, dstTy, bitwidth); auto &[pReg, pLane, mixedTranspositions, nPack] = factors; diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 37f20547c2..e1d03bc1ca 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -276,7 +276,7 @@ struct ElementwiseInlineAsmOpConversion auto ty = getTypeConverter()->convertType(getElementType(result)); // Pack return elements into 32-bits. - unsigned bitWidth = ty.isIntOrFloat() ? 
ty.getIntOrFloatBitWidth() : 64; + unsigned bitWidth = getIntOrFloatOrPtrBitWidth(ty); unsigned numElemsPerReg = std::min(std::max(32 / bitWidth, 1u), op.getPackedElement()); assert(op.getPackedElement() % numElemsPerReg == 0); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index b545cff52c..97bd6d4cb0 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -540,7 +540,7 @@ SmallVector lowerLdSt( auto kLane = str_attr("lane"); auto kWarp = str_attr("warp"); auto kOffset = str_attr("offset"); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); auto [elemsPerVec, permutation] = largestVectorisation(ctx, cvt, bitwidth, maybeMaxVecElems); @@ -625,7 +625,7 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx, assert(*cvt.getOutDimNames().begin() == str_attr("offset")); auto calcPaddedOffset = [&](Value smemOffset) { TritonLLVMOpBuilder b(loc, rewriter); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); if (auto paddedEnc = dyn_cast( srcTy.getEncoding())) { // Apply the offset needed for padding. diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 16f3215b70..ab5d809bdf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -11,6 +11,16 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; using ::mlir::LLVM::getSharedMemoryObjectFromStruct; namespace { + +Value bitOrPtrCast(Value val, Type type, TritonLLVMOpBuilder &b) { + if (isa(val.getType()) && + !isa(type)) { + return b.ptrtoint(type, val); + } else { + return b.bitcast(val, type); + } +} + struct SplatOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert SplatOp or arith::ConstantOp with SplatElementsAttr to a @@ -39,13 +49,13 @@ struct SplatOpConversion : public ConvertOpToLLVMPattern { unsigned ratio = srcBitWidth / cstBitWidth; Type intTy = IntegerType::get(elemType.getContext(), cstBitWidth); VectorType vecType = VectorType::get(ratio, intTy); - Value intCst = b.bitcast(constVal, intTy); + Value intCst = bitOrPtrCast(constVal, intTy, b); Value vec = b.undef(vecType); for (unsigned i = 0; i < ratio; ++i) vec = b.insert_element(vecType, vec, intCst, b.int_val(32, i)); constVal = vec; } - auto llSrc = b.bitcast(constVal, srcType); + Value llSrc = bitOrPtrCast(constVal, srcType, b); size_t elemsPerThread = getTotalElemsPerThread(tensorTy); llvm::SmallVector elems(elemsPerThread, llSrc); return packLLElements(loc, typeConverter, elems, rewriter, resType); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index ac6866157a..c2e89dd8ea 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1438,64 +1438,54 @@ chooseDsReadTrLayout(Attribute enc, ArrayRef shape, } } -LinearLayout chooseScaledWmmaScaleLayout( - MLIRContext *ctx, int dotOperandIdx, - const std::vector> &dotOperandWarpBasis, - ArrayRef dotOperandShape) { +LinearLayout chooseScaledWmmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, + ArrayRef warpsPerCTA, + ArrayRef dotOperandShape) { using basisT = std::vector>; unsigned rank = dotOperandShape.size(); auto order = mlir::triton::gpu::getMatrixOrder(rank, 
/*rowMajor=*/true); - auto standardOutDims = standardOutDimNames(ctx, rank); + auto outDimNames = standardOutDimNames(ctx, rank); + StringAttr kRegister = StringAttr::get(ctx, "register"); StringAttr kLane = StringAttr::get(ctx, "lane"); StringAttr kWarp = StringAttr::get(ctx, "warp"); StringAttr kBlock = StringAttr::get(ctx, "block"); - unsigned int scaleKWidth = dotOperandShape[1]; - // Init register layout. Will be adjusted later - auto regs = - mlir::triton::identityStandardND(kRegister, {1, scaleKWidth}, order); - LinearLayout lanes = LinearLayout::empty(); + // In scaled dot, the shapes of operands(without batch dimension) are, // respectively: // - A: [M, K] // - B: [K, N] // - aScale: [M, K / 32 or 16] // - bScale: [N, K / 32 or 16] - // - // To correctly feed A/B and its scale into instruction, we need to - // distribute aScale/bScale among warps in the same way as A/B. But bScale - // is not transposed like B. So we need to transpose the warp layout of - // bScale. - // - // The tricky part is, our desired outputs are [dim0, dim1], but - // at this position, the layouts are transposed to [dim1, dim0]. So - // instead of reverse bScale's layout, we need to reverse aScale's. There - // will be a transpose in the end to correct everything. - basisT warps = dotOperandWarpBasis; - if (dotOperandIdx == 0) { - for (auto &basis : warps) { - std::reverse(basis.begin(), basis.end()); - } - } + auto dimK = outDimNames[order[0]]; + auto dimNonK = outDimNames[order[1]]; - lanes = LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 0}}}, - {kWarp, warps}, - {kBlock, {}}}, - {standardOutDims[order[0]], standardOutDims[order[1]]}); - LinearLayout newLL = regs * lanes; + // Each lane holds kWidth=4 consecutive values along the k dim. + // The first 16 lanes are distributed along the non-k dim. We are not using + // the remaining 16 lanes, so just let them duplicate values of the first 16 + // lanes. If the shape along the k dim is larger than kWidth, repeat this + // pattern to fill the k dim. + unsigned scaleKWidth = 4; + auto kSize = dotOperandShape[1]; + LinearLayout tileLayout = + LinearLayout::identity1D(scaleKWidth, kRegister, dimK) * + LinearLayout::identity1D(16, kLane, dimNonK) * + LinearLayout::zeros1D(2, kLane, dimK) * + LinearLayout::identity1D(kSize / scaleKWidth, kRegister, dimK); - // Adjust register-level layout to fill the shape, at this level, both - // aScale and bScale should align with A operand. - SmallVector repOrder = {1, 0}; - for (auto d : repOrder) { - auto outDim = standardOutDims[d]; - auto dimSize = newLL.getOutDimSize(outDim); - newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister, - outDim); - } - newLL = newLL.transposeOuts(standardOutDims); + auto warpsPerCTANew = (dotOperandIdx == 1) + ? SmallVector{warpsPerCTA[1], warpsPerCTA[0]} + : SmallVector{warpsPerCTA[0], warpsPerCTA[1]}; + + auto warpOrder = (dotOperandIdx == 1) ? 
SmallVector{0, 1} + : SmallVector{1, 0}; + LinearLayout warpLayout = + identityStandardND(kWarp, warpsPerCTANew, warpOrder); + LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * + warpLayout.transposeOuts(outDimNames); - return newLL; + return combineCtaCgaWithShape( + ctaLayout, CTALayoutAttr::getDefault(ctx, /*rank=*/2), dotOperandShape); } // PTX ISA - Warp-level MMA Block Scaling diff --git a/python/test/gluon/test_core.py b/python/test/gluon/test_core.py index 361940e02c..bf5fc4b4ae 100644 --- a/python/test/gluon/test_core.py +++ b/python/test/gluon/test_core.py @@ -140,6 +140,92 @@ def test_async_copy_mbarrier(device): torch.testing.assert_close(out[20:], torch.zeros((12, 32), **tensor_opts)) +@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper", run=False) +def test_device_tma_load(): + + @gluon.jit + def tma_device_load_kernel(input_ptr, output_ptr, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr): + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + + smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + + mbarrier.expect(bar, input_desc.block_type.nbytes) + tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + mbarrier.wait(bar, 0) + mbarrier.invalidate(bar) + + block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + xindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(1, block_layout))[:, None] + yindex = ttgl.arange(0, XBLOCK, ttgl.SliceLayout(0, block_layout))[None, :] + val = smem.load(block_layout) + ttgl.store(output_ptr + yindex + xindex * XBLOCK, val) + + XBLOCK = 16 + input = torch.zeros((XBLOCK, XBLOCK), device="cuda", dtype=torch.float16) + output = torch.ones_like(input) + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + tma_device_load_kernel[(1, )](input, output, XBLOCK, smem_layout) + torch.testing.assert_close(input, output) + + +@pytest.mark.xfail(not is_hopper_or_newer(), reason="Requires Hopper", run=False) +def test_device_tma_store(): + + @gluon.jit + def tma_device_store_kernel(out_ptr, XBLOCK: ttgl.constexpr, smem_layout: ttgl.constexpr): + layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0]) + value = ttgl.full([XBLOCK, XBLOCK], 0, ttgl.float16, layout) + alloc = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], smem_layout, value) + out_desc = tma.make_tensor_descriptor( + out_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + tma.async_copy_shared_to_global(out_desc, [0, 0], alloc) + tma.store_wait(0) + alloc._keep_alive() + + XBLOCK = 16 + out = torch.ones((XBLOCK, XBLOCK), dtype=torch.float16, device="cuda") + smem_layout = ttgl.NVMMASharedLayout( + swizzle_byte_width=32, + element_bitwidth=16, + rank=2, + transposed=False, + fp4_padded=False, + ) + + def alloc_fn(size: int, alignment: int, stream: int): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + tma_device_store_kernel[(1, )](out, XBLOCK, smem_layout) + torch.testing.assert_close(out, 
torch.zeros_like(out)) + + @gluon.jit def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr, diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py index 77cb8707d8..844592eca4 100644 --- a/python/test/gluon/test_frontend.py +++ b/python/test/gluon/test_frontend.py @@ -2952,3 +2952,94 @@ def test_amd_tdm_store(target): } } """) + + +@pytest.mark.parametrize("target", [BLACKWELL_TARGET, HOPPER_TARGET]) +def test_nv_tma_descriptor_load_kernel(target): + + @gluon.jit + def nv_tma_descriptor_load_kernel(input_ptr): + XBLOCK: ttgl.constexpr = 128 + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2) + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) + bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout()) + mbarrier.init(bar, count=1) + mbarrier.expect(bar, XBLOCK * XBLOCK * ttgl.float32.primitive_bitwidth // 8) + tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem) + + ptr = MockTensor(ttgl.float32) + module = run_parser(nv_tma_descriptor_load_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0]}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @nv_tma_descriptor_load_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c128_i32_0 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + %2 = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared1, #smem, mutable> + ttng.init_barrier %2, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %true = arith.constant true + ttng.barrier_expect %2, 65536, %true : !ttg.memdesc<1xi64, #shared1, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c0_i32_1 = arith.constant 0 : i32 + %true_2 = arith.constant true + ttng.async_tma_copy_global_to_local %0[%c0_i32, %c0_i32_1] %1, %2, %true_2 : !tt.tensordesc>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + tt.return + } +} +""") + + +@pytest.mark.parametrize("target", [BLACKWELL_TARGET, HOPPER_TARGET]) +def test_nv_tma_descriptor_store_kernel(target): + + @gluon.jit + def nv_tma_descriptor_store_kernel(input_ptr): + XBLOCK: ttgl.constexpr = 128 + smem_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2) + input_desc = tma.make_tensor_descriptor( + input_ptr, + shape=[XBLOCK, XBLOCK], + strides=[XBLOCK, 1], + block_shape=[XBLOCK, XBLOCK], + layout=smem_layout, + ) + smem = ttgl.allocate_shared_memory(ttgl.float32, [XBLOCK, XBLOCK], smem_layout) + tma.async_copy_shared_to_global(input_desc, [0, 0], smem) + 
tma.store_wait(0) + + ptr = MockTensor(ttgl.float32) + module = run_parser(nv_tma_descriptor_store_kernel, *make_args(ptr), target) + expecttest.assert_expected_inline( + anonymize_ir(module.str_nodebug()), """\ +#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 32}> +#smem = #ttg.shared_memory +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @nv_tma_descriptor_store_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c128_i32_0 = arith.constant 128 : i32 + %c128_i64 = arith.constant 128 : i64 + %c1_i64 = arith.constant 1 : i64 + %0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32_0], [%c128_i64, %c1_i64] : , > + %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + %c0_i32 = arith.constant 0 : i32 + %c0_i32_1 = arith.constant 0 : i32 + ttng.async_tma_copy_local_to_global %0[%c0_i32, %c0_i32_1] %1 : !tt.tensordesc>, !ttg.memdesc<128x128xf32, #shared, #smem, mutable> + ttng.async_tma_store_wait {pendings = 0 : i32} + tt.return + } +} +""") diff --git a/python/test/unit/language/test_tensor_descriptor.py b/python/test/unit/language/test_tensor_descriptor.py index 40a5d71996..521b4dbab3 100644 --- a/python/test/unit/language/test_tensor_descriptor.py +++ b/python/test/unit/language/test_tensor_descriptor.py @@ -384,8 +384,8 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): @pytest.mark.interpreter def test_tensor_descriptor_padding(device): - if not is_cuda(): - pytest.xfail("padding is unsupported") + if is_xpu(): + pytest.skip("padding is unsupported") @triton.jit def device_tma_load(in_ptr, out_ptr, IM, IN, YM, YN, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr, diff --git a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py index 60a0724c55..717331e53c 100644 --- a/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/blackwell/tma.py @@ -5,6 +5,7 @@ store_wait, tensor_descriptor, tensor_descriptor_type, + make_tensor_descriptor, ) __all__ = [ @@ -15,6 +16,7 @@ "store_wait", "tensor_descriptor", "tensor_descriptor_type", + "make_tensor_descriptor", ] diff --git a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py index ea26e20bd2..a7d3cd4e50 100644 --- a/python/triton/experimental/gluon/language/nvidia/hopper/tma.py +++ b/python/triton/experimental/gluon/language/nvidia/hopper/tma.py @@ -22,6 +22,14 @@ class tensor_descriptor_type(base_type): def __str__(self) -> str: return f"tensor_descriptor<{self.block_type}, {self.layout}>" + def _to_ir(self, builder: ir.builder) -> ir.type: + is_signed = self.block_type.element_ty.is_int_signed() + return builder.get_tensor_descriptor_layout_type( + self.block_type.to_ir(builder), + is_signed, + self.layout._to_ir(builder), + ) + def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]: handle = handles[cursor] cursor += 1 @@ -95,3 +103,66 @@ def async_copy_shared_to_global(tensor_desc, coord, src, _semantic=None): def store_wait(pendings, _semantic=None): pendings = _unwrap_if_constexpr(pendings) _semantic.builder.create_async_tma_store_wait(pendings) + + +@builtin +def make_tensor_descriptor( + base: ttgl.tensor, + shape: 
List[ttgl.tensor], + strides: List[ttgl.tensor], + block_shape: List[ttgl.constexpr], + layout: NVMMASharedLayout, + padding_option="zero", + _semantic=None, +) -> tensor_descriptor: + padding_option = _unwrap_if_constexpr(padding_option) + + ndim = len(shape) + if not (1 <= ndim <= 5): + raise ValueError(f"Expected 1 <= ndim <= 5 but got {ndim} dimensions") + if len(strides) != ndim: + raise ValueError(f"Expected {ndim} strides but got {len(strides)}") + if len(block_shape) != ndim: + raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") + assert isinstance(base.dtype, ttgl.pointer_type) + elem_size = base.dtype.element_ty.primitive_bitwidth // 8 + contig_dim_size = ttgl._unwrap_if_constexpr(block_shape[-1]) + if contig_dim_size * elem_size < 16: + raise ValueError( + f"Descriptor block shape must have at least 16 bytes in the last dimension, but got {contig_dim_size} * {elem_size} = {contig_dim_size * elem_size} bytes" + ) + + last_stride = ttgl._unwrap_if_constexpr(strides[-1]) + if last_stride != 1: + raise ValueError(f"Tensor descriptor last dim must be 1 but got {last_stride}") + + shape = [_semantic.make_scalar(x, ttgl.int32) for x in shape] + strides = [_semantic.make_scalar(ttgl._unwrap_if_constexpr(x), ttgl.int64) for x in strides] + + # Check whether `block_shape` is static + block_shape = ttgl._unwrap_shape(block_shape) + + assert isinstance(base.type, ttgl.pointer_type) + block_type = ttgl.block_type(base.type.element_ty, block_shape) + base_handle = base.handle + + padding = _semantic._str_to_padding_option(padding_option) + + layout = _unwrap_if_constexpr(layout) + assert isinstance(layout, NVMMASharedLayout), \ + "Expected layout to be a NVMMASharedLayout" + + shape_type = ttgl.tuple(shape).type + strides_type = ttgl.tuple(strides).type + ty = tensor_descriptor_type(block_type, shape_type, strides_type, layout) + + if base.type.element_ty.is_int() and padding == ttgl.ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer blocks") + handle = _semantic.builder.create_make_tensor_descriptor( + ty._to_ir(_semantic.builder), + base_handle, + [s.handle for s in shape], + [s.handle for s in strides], + padding, + ) + return tensor_descriptor(handle, shape, strides, block_type, layout) diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py index a640e07d71..11c50716ad 100644 --- a/python/triton_kernels/bench/bench_mlp.py +++ b/python/triton_kernels/bench/bench_mlp.py @@ -69,6 +69,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x_dtype = torch.float8_e4m3fnuz input_x = torch.randn((batch // DP, dim1), device=dev) + expt_assignment = triton_dist.create_expt_assignment(EP, n_expts_tot, torch.device(dev)) # run layer fpath = Path(tempfile.mktemp()) proton.start(str(fpath), hook="triton") @@ -78,7 +79,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d if n_expts_tot > 1: # sparse logits = matmul_ogs(xg, wg, bg, precision_config=pcg) x, rdata, gather_indx, scatter_indx, metadata = triton_dist.routing(input_x, logits, n_expts_act, EP=EP, - TP=TP) + TP=TP, expt_assignment=expt_assignment) else: # dense x = triton_dist.all_gather(input_x, dim=0) rdata, gather_indx, scatter_indx, metadata = None, None, None, None @@ -86,7 +87,7 @@ def bench_mlp(batch_per_expt, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, 
fused_activation=act) x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=scatter_indx, precision_config=pc2) - x = triton_dist.reduce_scatter(x, metadata=metadata, dim=0) + x = triton_dist.reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) proton.finalize() return roofline.parse_profile(fpath.with_suffix(".hatchet"), useful_op_regex=".*matmul.*") @@ -136,6 +137,8 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d parser.add_argument("--name", type=str, choices=["dense", "gpt-oss-x2"]) parser.add_argument("--quantized", action="store_true", default=False) args = parser.parse_args() + if args.tp > 1: + raise NotImplementedError("TP>1 is not supported yet in distributed mode.") dtypes = quantized_dtypes if args.quantized else dense_dtypes if args.name == "dense": assert args.ep == 1, "EP must be 1 for dense" diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py index 75099b2949..924952a83b 100644 --- a/python/triton_kernels/bench/distributed.py +++ b/python/triton_kernels/bench/distributed.py @@ -5,56 +5,41 @@ import torch.multiprocessing as mp from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, Optional -import triton -import triton.language as tl import triton_kernels import triton_kernels.swiglu +from triton_kernels.reduce import reduce from triton_kernels.matmul_ogs import RoutingData, GatherIndx, ScatterIndx -from triton_kernels.topk import topk_torch from triton_kernels.topk import topk from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq from triton_kernels.tensor_details import layout -from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata +from triton_kernels.tensor import make_ragged_tensor_metadata, remap_ragged_tensor_metadata +from triton_kernels.distributed import make_expt_dict_uniform, make_expt_assignment, convert_dp_to_ep, convert_ep_to_dp, ExptAssignment from bench_utils import quantize_weight -def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act): - sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix) - dispatch_indx = sparse_logits.mask_metadata.col_sorted_indx - combine_indx = sparse_logits.mask_metadata.row_sorted_indx - ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0]) - gate_scal = sparse_logits.vals.flatten()[combine_indx] - routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, - ragged_batch_metadata) - gather_idx = GatherIndx(combine_indx, dispatch_indx) - scatter_idx = ScatterIndx(dispatch_indx, combine_indx) - return routing_data, gather_idx, scatter_idx - - -def legacy_routing(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None): - if sm_first: - logits = torch.softmax(logits, dim=-1) - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) - return legacy_routing_from_bitmatrix(sparse_logits.mask, sparse_logits.vals, sparse_logits.indx, logits.shape[-1], - n_expts_act) - - @dataclass class ReduceScatterMetadata: - input_split_sizes: list[int] - ep_indx: torch.Tensor - EP: int = 1 - TP: int = 1 + mode: str + active_indx: Optional[torch.Tensor] = None + dispatch_indx: 
Optional[torch.Tensor] = None + combine_indx: Optional[torch.Tensor] = None def _is_distributed_launch() -> bool: return int(os.environ.get("WORLD_SIZE", "1")) > 1 +def create_expt_assignment(EP: int, n_expts_tot: int, device: torch.device) -> Optional[ExptAssignment]: + if not _is_distributed_launch(): + return None + expt_dict = make_expt_dict_uniform(EP, n_expts_tot) + return make_expt_assignment(EP, n_expts_tot, expt_dict, device) + + def setup() -> Tuple[int, int]: if _is_distributed_launch(): world_size = int(os.environ["WORLD_SIZE"]) @@ -102,482 +87,79 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor: def reduce_scatter( input_tensor: torch.Tensor, - metadata: ReduceScatterMetadata = None, + n_expts_act: int, + metadata: ReduceScatterMetadata, + expt_assignment: Optional[ExptAssignment] = None, dim: int = 0, op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM, ) -> torch.Tensor: if _is_distributed_launch(): - - def dtype_cast(dtype: torch.dtype) -> torch.dtype: - # check if dtype is fp8, then convert it to float16 before reducing - if dtype in [torch.float16, torch.bfloat16, torch.float32]: - return dtype - else: - return torch.float16 - - world_size = dist.get_world_size() - original_dtype = input_tensor.dtype - intermediate_dtype = dtype_cast(original_dtype) - if metadata and metadata.input_split_sizes: - assert dim == 0, "metadata only works with dim=0" - input_list = list(input_tensor.split(metadata.input_split_sizes, dim=0)) - output_list = all_to_all(input_list, dim=0) - n_tokens = metadata.ep_indx.size(dim) - other_dims = input_tensor.shape[1:] - output_tensor = input_tensor.new_zeros((n_tokens, ) + other_dims, dtype=intermediate_dtype) - for i in range(world_size): - ep_rank = i // metadata.TP - mask = torch.any(metadata.ep_indx == ep_rank, dim=1) - if op == dist.ReduceOp.SUM: - output_tensor[mask] += output_list[i].to(intermediate_dtype) - else: - raise NotImplementedError(f"Reduce operation {op} is not implemented.") - return output_tensor.to(original_dtype) + if metadata.mode and metadata.mode == "ep_sharding": + if dim != 0 or op != dist.ReduceOp.SUM: + raise NotImplementedError("Only dim=0 and op=SUM are supported for MoE reduce_scatter.") + output = convert_ep_to_dp(input_tensor, expt_assignment, metadata.active_indx, metadata.combine_indx) + # weighted average of the output token from experts + output = output.view(-1, n_expts_act, output.shape[-1]) + output, _ = reduce(output, dim=1) + return output else: - input_list = list(input_tensor.chunk(world_size, dim=dim)) - shape = input_list[0].shape - input_list = [x.to(intermediate_dtype) for x in input_list] - output_tensor = input_tensor.new_empty(shape, dtype=intermediate_dtype) - dist.reduce_scatter(output_tensor, input_list, op=op) - return output_tensor.to(original_dtype) + raise NotImplementedError(f"Distributed reduce_scatter mode {metadata.mode} is not implemented yet.") else: return input_tensor -def all_to_all(input_list: list[torch.Tensor], dim: int = 0) -> list[torch.Tensor]: - if _is_distributed_launch(): - # Check if all tensors have only one dimension with different sizes - for t in input_list: - for d in range(t.dim()): - if d != dim and t.size(d) != input_list[0].size(d): - raise ValueError("All tensors must have the same size in all dimensions except the specified one.") - input_sizes = [t.size(dim) for t in input_list] - input_sizes = torch.tensor(input_sizes, device=input_list[0].device).unsqueeze(0) - input_sizes = all_gather(input_sizes, dim=0) - output_split_sizes = input_sizes[:, 
dist.get_rank()].tolist() - other_dims = list(input_list[0].shape[:dim] + input_list[0].shape[dim + 1:]) - output_list = [ - torch.empty([size] + other_dims, dtype=input_list[0].dtype, device=input_list[0].device) - for size in output_split_sizes - ] - dist.all_to_all(output_list, input_list) - return output_list - else: - return input_list - - -def _apply_parallelism( - expt_scal: torch.Tensor, - expt_indx: torch.Tensor, - x: torch.Tensor, - chunk_size: int, - EP: int = 1, - TP: int = 1, -): - if EP > 1: - # Distributed Expert Parallelism - ep_indx = expt_indx // chunk_size - - # Partition tokens by expert parallelism rank - expt_scal_list = [] - expt_indx_list = [] - x_list = [] - - for i in range(EP): - mask = torch.any(ep_indx == i, dim=1) - expt_scal_masked = expt_scal[mask] - expt_indx_masked = expt_indx[mask] - x_masked = x[mask] - - for _ in range(TP): - expt_scal_list.append(expt_scal_masked) - expt_indx_list.append(expt_indx_masked) - x_list.append(x_masked) - - # Exchange data across processes - expt_scal_list = all_to_all(expt_scal_list, dim=0) - expt_indx_list = all_to_all(expt_indx_list, dim=0) - x_list = all_to_all(x_list, dim=0) - - output_split_sizes = [x.size(0) for x in expt_scal_list] - expt_scal = torch.cat(expt_scal_list, dim=0) - expt_indx = torch.cat(expt_indx_list, dim=0) - x = torch.cat(x_list, dim=0) - - # Filter for local experts only - ep_rank = dist.get_rank() // TP - mask = (expt_indx // chunk_size) == ep_rank - expt_indx -= ep_rank * chunk_size - expt_scal = expt_scal.masked_fill(~mask, 0) - expt_indx = expt_indx.masked_fill(~mask, chunk_size) - else: - # Distributed Data Parallelism - ep_indx = None - output_split_sizes = None - x = all_gather(x, dim=0) - expt_scal = all_gather(expt_scal, dim=0) - expt_indx = all_gather(expt_indx, dim=0) - - return expt_scal, expt_indx, ep_indx, x, output_split_sizes - - -def routing_torch(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1): - _, n_expts_tot = logits.shape - - if n_rows: - logits = logits[:n_rows] - if sm_first: - logits = torch.softmax(logits, dim=-1) - - expt_scal, expt_indx = topk_torch(logits, n_expts_act, expt_indx, has_user_provided_indx=expt_indx is not None) - expt_indx = expt_indx.int() - if not sm_first: - expt_scal = torch.softmax(expt_scal, dim=-1) - - # Sort each token's selections by expert - expt_indx, sort_indices = torch.sort(expt_indx, dim=1, stable=True) - expt_scal = torch.gather(expt_scal, 1, sort_indices) - - chunk_size = n_expts_tot // EP - - expt_scal, expt_indx, ep_indx, x, output_split_sizes = _apply_parallelism(expt_scal, expt_indx, x, chunk_size, - EP=EP, TP=TP) - - # Flatten topk data - expt_scal = expt_scal.reshape(-1) - expt_indx = expt_indx.reshape(-1).to(torch.int32) - - # Sort by expert_id for contiguous experts in matmul - expt_indx, topk_indx = torch.sort(expt_indx, stable=True) - gate_indx = torch.argsort(topk_indx, stable=True) - - mask = expt_indx != chunk_size - topk_indx[~mask] = -1 - gate_indx[gate_indx >= mask.sum()] = -1 - gate_scal = expt_scal[topk_indx] - hist = torch.histc(expt_indx[mask], bins=chunk_size, min=0, max=chunk_size - 1) - - # Pack the matmul data structures - gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) - scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) - n_gates = mask.sum().item() - expt_data = make_ragged_tensor_metadata(hist, n_gates) - - return ( - x, - RoutingData(gate_scal, hist, chunk_size, n_expts_act, expt_data=expt_data), - gather_indx, - 
scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP), - ) - - -@triton.jit -def pack_bitmatrix( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - sentinel: tl.constexpr, -): - """ - Packs expt_indx into a bitmatrix. - """ - pid_m = tl.program_id(0) - offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offsets_k = tl.arange(0, BLOCK_SIZE_K) - offsets = offsets_m[:, None] * n_expts_act + offsets_k[None, :] - mask = (offsets_m < n_rows)[:, None] & (offsets_k < n_expts_act)[None, :] - indices = tl.load(expt_indx + offsets, mask=mask, other=sentinel) - div = indices // 32 - rem = indices % 32 - iters = tl.cdiv(sentinel, BLOCK_SIZE_K) - for i in range(iters): - offs = tl.arange(0, BLOCK_SIZE_K // 32) + i * (BLOCK_SIZE_K // 32) - x = tl.where(div[:, :, None] == offs[None, None, :], (1 << rem)[:, :, None], 0) - y = tl.reduce_or(x, axis=1) - bitmatrix_ptrs = bitmatrix + offsets_m[:, None] * n_cols + offs[None, :] - tl.store(bitmatrix_ptrs, y, mask=offsets_m[:, None] < n_rows) - - -@triton.jit -def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr): - pid_m = tl.program_id(0) - cutoff_word = cutoff // 32 - cutoff_bit = cutoff % 32 - cutoff_mask = (1 << (cutoff_bit)) - 1 - for start_n in range(0, shape_bn, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn) - values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values) - values = tl.where(offs_n > cutoff_word, 0, values) - tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn) - - -class PruneRouting(torch.autograd.Function): - - @staticmethod - def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): - from triton_kernels.compaction import compaction - n_tokens_pad = expt_scal.shape[0] - assert n_expts_tot % simulated_ep == 0 - _routing_clear_bitmatrix[(n_tokens_pad, )]( - bitmatrix.storage.data, - bitmatrix.storage.data.stride(0), - bitmatrix.storage.data.stride(1), - bitmatrix.storage.data.shape[1], - n_expts_tot // simulated_ep, - BLOCK_N=512, - ) - # perform compaction to update expt_scal / expt_indx - expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) - n_expts_tot = n_expts_tot // simulated_ep - bitmatrix.shape[-1] = n_expts_tot - return expt_scal, expt_indx, bitmatrix - - -def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): - return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep) - - -def routing_triton(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1): - _, n_expts_tot = logits.shape - - if sm_first: - logits = torch.softmax(logits, dim=-1) - - sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) - expt_scal = sparse_logits.vals - expt_indx = sparse_logits.indx - expt_indx = expt_indx.int() - - chunk_size = n_expts_tot // EP - - expt_scal, expt_indx, ep_indx, x, output_split_sizes = _apply_parallelism(expt_scal, expt_indx, x, chunk_size, - EP=EP, TP=TP) - - # TODO: Skip all the following if `EP == 1` - # Recover bitmatrix for local experts - BLOCK_SIZE_M = 512 - BLOCK_SIZE_K = 32 - # The sentinel value is chunk_size + 1 instead of chunk_size to ensure the bitmatrix buffer - # doesn't overflow. 
For example, cdiv(32, 32) is 1, while the 32th bit is on the second column. - sentinel = chunk_size + 1 - n_cols = triton.cdiv(sentinel, BLOCK_SIZE_K) - n_rows = expt_indx.size(0) - bitmatrix = torch.zeros((n_rows, n_cols), dtype=torch.uint32, device=expt_indx.device) - grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) - - pack_bitmatrix[grid]( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_K=BLOCK_SIZE_K, - sentinel=sentinel, - ) - bitmatrix_shape = [n_rows, triton.cdiv(chunk_size, BLOCK_SIZE_K) * 32] - bitmatrix_shape_max = [n_rows, None] - bitmatrix = Bitmatrix(bitmatrix, dtype=BIT, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max) - expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, EP) - routing_data, gather_indx, scatter_indx = legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, - n_expts_tot // EP, n_expts_act) - - return ( - x, - routing_data, - gather_indx, - scatter_indx, - ReduceScatterMetadata(input_split_sizes=output_split_sizes, ep_indx=ep_indx, EP=EP, TP=TP), - ) - - -def routing(x, logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None, EP=1, TP=1, - backend="triton") -> Tuple[RoutingData, GatherIndx, ScatterIndx, ReduceScatterMetadata]: +# TODO: support TP > 1 +# TODO: clean up duplicate code with triton_kernels.test_distributed.py +# TODO: Support nonuniform expert assignment +def routing( + x, logits, n_expts_act, sm_first: bool = False, y_indx: Optional[torch.Tensor] = None, EP: int = 1, TP: int = 1, + expt_assignment: Optional[ExptAssignment] = None, mode: str = "ep_sharding" +) -> Tuple[torch.Tensor, RoutingData, GatherIndx, ScatterIndx, Optional[ReduceScatterMetadata]]: + n_expts_tot = logits.shape[-1] if _is_distributed_launch(): - assert backend in ["torch", "triton"], "backend must be either 'torch' or 'triton'" - if backend == "torch": - return routing_torch(x, logits, n_expts_act, sm_first, expt_indx, n_rows, EP, TP) - elif backend == "triton": - return routing_triton(x, logits, n_expts_act, sm_first, expt_indx, n_rows, EP, TP) + if mode == "ep_sharding": + if not expt_assignment: + raise ValueError("expt_assignment must be provided for distributed routing.") + if TP > 1: + raise NotImplementedError("TP > 1 is not supported in distributed MoE benchmark yet.") + rank = dist.get_rank() + expt_map = expt_assignment.expt_map[rank, :] + logits_global = topk( + logits, + n_expts_act, + apply_softmax=not sm_first, + y_indx=y_indx, + all_gather=True, + ) + active_indx = logits_global.indx + expt_sizes = logits_global.mask_metadata.col_sum + dispatch_indx = logits_global.mask_metadata.col_sorted_indx + combine_indx = logits_global.mask_metadata.row_sorted_indx + logits_global_metadata = make_ragged_tensor_metadata(expt_sizes, dispatch_indx.shape[0]) + x = convert_dp_to_ep(x, expt_assignment, active_indx, dispatch_indx) + logits_local_metadata = remap_ragged_tensor_metadata(logits_global_metadata, expt_map) + gate_scal = logits_global.vals.flatten()[combine_indx] + rdata = RoutingData(gate_scal, expt_sizes, n_expts_tot // EP, n_expts_act, logits_local_metadata) + reduce_scatter_metadata = ReduceScatterMetadata( + mode=mode, + active_indx=active_indx, + dispatch_indx=dispatch_indx, + combine_indx=combine_indx, + ) + return x, rdata, None, None, reduce_scatter_metadata else: - raise ValueError(f"Unknown backend: {backend}") + raise NotImplementedError(f"Distributed routing mode {mode} is not implemented yet.") else: - return x, *legacy_routing(logits, 
n_expts_act, sm_first, expt_indx, n_rows), None - - -# The following dummy methods simulate the behavior of distributed operations -# in a non-distributed environment for testing purposes. -# Assuming each rank has the same data for simplicity. - - -def dummy_all_gather(out, x): - out[0].copy_(x) - out[1].copy_(x) - - -def dummy_all_to_all(output_list, input_list): - output_list[0].copy_(input_list[0]) - output_list[1].copy_(input_list[0]) - - -def dummy_reduce_scatter(out, x_list, op): - out.copy_(x_list[0] * 2) - - -def test_all_gather_non_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "1") - x = torch.randn(4, 5) - result = all_gather(x, dim=0) - torch.testing.assert_close(result, x) - - -@pytest.mark.parametrize("dim", [0, 1]) -def test_all_gather_distributed(monkeypatch, dim): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - - x = torch.randn(4, 4) - result = all_gather(x, dim=dim) - expected = torch.cat([x, x], dim=dim) - torch.testing.assert_close(result, expected) - - -def test_reduce_scatter_non_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "1") - x = torch.randn(4, 6) - result = reduce_scatter(x, dim=0) - torch.testing.assert_close(result, x) - - -def test_reduce_scatter_distributed(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "reduce_scatter", dummy_reduce_scatter) - - x = torch.randn(4, 6) - expected = x.chunk(2, dim=0)[0] * 2 - - result = reduce_scatter(x, dim=0) - torch.testing.assert_close(result, expected) - - -def test_reduce_scatter_distributed_with_metadata(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - - input_split_sizes = [1, 1] - ep_indx = torch.tensor([[0], [1]]) - metadata = ReduceScatterMetadata(input_split_sizes=input_split_sizes, ep_indx=ep_indx, EP=2) - # Assume the current ep rank is 0. - # [1, 2] comes from rank 0 - # [3, 4] comes from rank 1. - x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) - - result = reduce_scatter(x, metadata=metadata, dim=0) - torch.testing.assert_close(result, torch.tensor([[1, 2], [1, 2]], dtype=torch.float32)) - - -def test_routing_distributed_EP(monkeypatch): - # Test distributed routing with EP=1 (token_mask should be None) - monkeypatch.setenv("WORLD_SIZE", "2") - # Set environment for local rank and distributed group - monkeypatch.setenv("LOCAL_RANK", "0") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - - # NOTE: must set `device="cuda"` since `routing` expects CUDA tensors. 
- logits = torch.tensor([[0.1, 0.2, 0.4, 0.3], [0.5, 0.4, 0.3, 0.1]], device="cuda", dtype=torch.float16) - x = torch.randn_like(logits, device="cuda", dtype=torch.float16) - n_expts_act = 2 - EP = 2 - expt_indx = torch.tensor([[0, 1], [0, 1]], device="cuda").reshape(-1) - topk_indx = torch.argsort(expt_indx, stable=True) - gate_indx = torch.argsort(topk_indx, stable=True) - _, rdata, gather_indx, scatter_indx, metadata = routing(x, logits, n_expts_act, EP=EP) - assert torch.equal(gather_indx.src_indx, topk_indx.int()) - assert torch.equal(gather_indx.dst_indx, gate_indx.int()) - assert torch.equal(scatter_indx.src_indx, gate_indx.int()) - assert torch.equal(scatter_indx.dst_indx, topk_indx.int()) - - -def test_all_to_all(monkeypatch): - monkeypatch.setenv("WORLD_SIZE", "2") - monkeypatch.setenv("LOCAL_RANK", "0") - monkeypatch.setattr(dist, "is_initialized", lambda: True) - monkeypatch.setattr(dist, "get_world_size", lambda: 2) - monkeypatch.setattr(dist, "get_rank", lambda: 0) - monkeypatch.setattr(dist, "all_to_all", dummy_all_to_all) - monkeypatch.setattr(dist, "all_gather", dummy_all_gather) - - input_list = [torch.tensor([1, 2], dtype=torch.float32), torch.tensor([3, 4], dtype=torch.float32)] - output_list = all_to_all(input_list) - assert torch.equal(output_list[0], torch.tensor([1, 2], dtype=torch.float32)) - assert torch.equal(output_list[1], torch.tensor([1, 2], dtype=torch.float32)) - assert len(output_list) == 2 - - -def test_pack_bitmatrix(): - # Test parameters - n_rows, n_expts_act = 4, 3 - sentinel = 63 # We have experts 0-62, and 63 is a dummy value - - expt_indx = torch.tensor([[0, 33, 63], [31, 32, 33], [5, 10, 15], [0, 62, 63]], dtype=torch.int32, device="cuda") - n_cols = triton.cdiv(sentinel, 32) - bitmatrix = torch.zeros((n_rows, n_cols), dtype=torch.uint32, device="cuda") - - BLOCK_SIZE_M = 128 - BLOCK_SIZE_K = 32 - grid = (triton.cdiv(n_rows, BLOCK_SIZE_M), ) - - pack_bitmatrix[grid]( - bitmatrix, - expt_indx, - n_rows, - n_cols, - n_expts_act, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_K=BLOCK_SIZE_K, - sentinel=sentinel, - ) - # Prune the bitmatrix to remove dummy values - _routing_clear_bitmatrix[(n_rows, )]( - bitmatrix, - bitmatrix.stride(0), - bitmatrix.stride(1), - bitmatrix.shape[1], - sentinel, - BLOCK_N=128, - ) - - # Old pytorch version do not have "bitwise_and_cpu" not implemented for 'UInt32' - bitmatrix = bitmatrix.cpu().numpy() - - # Verify bit packing - def is_bit_set(row, expert_id): - word_idx, bit_idx = expert_id // 32, expert_id % 32 - return (bitmatrix[row, word_idx] & (1 << bit_idx)) != 0 - - # Check specific cases - assert is_bit_set(0, 0) and is_bit_set(0, 33) and not is_bit_set(0, 63) # Token 0 - assert is_bit_set(1, 31) and is_bit_set(1, 32) and is_bit_set(1, 33) # Token 1 - assert is_bit_set(2, 5) and is_bit_set(2, 10) and is_bit_set(2, 15) # Token 2 - assert is_bit_set(3, 0) and not is_bit_set(3, 63) and is_bit_set(3, 62) # Token 3 + logits = topk(logits, n_expts_act, y_indx=y_indx, apply_softmax=not sm_first) + dispatch_indx = logits.mask_metadata.col_sorted_indx + combine_indx = logits.mask_metadata.row_sorted_indx + ragged_batch_metadata = make_ragged_tensor_metadata(logits.mask_metadata.col_sum, dispatch_indx.shape[0]) + gate_scal = logits.vals.flatten()[combine_indx] + routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act, + ragged_batch_metadata) + gather_indx = GatherIndx(combine_indx, dispatch_indx) + scatter_indx = ScatterIndx(dispatch_indx, combine_indx) + return x, routing_data, 
gather_indx, scatter_indx, None def gather_ep(rank, world_size, param, TP, EP): @@ -679,13 +261,14 @@ def distributed_run(rank, world_size, batch, dim1, dim2, n_expts_tot, n_expts_ac } xd = torch.randn((batch // world_size, dim1), device=dev).to(dtype_map[x_dtype]) x0 = all_gather(xd, dim=0) + expt_assignment = create_expt_assignment(EP, n_expts_tot, torch.device(dev)) # single-GPU pass def single(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: logits = matmul_ogs(xg, wg, bg, precision_config=pcg) - rdata, gi, si = legacy_routing(logits, n_expts_act) + x, rdata, gi, si, _ = routing(x, logits, n_expts_act) else: rdata = gi = si = None x = matmul_ogs(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act) @@ -696,13 +279,13 @@ def distributed(x): xg = x.to(wg.dtype if n_expts_tot > 1 else x.dtype) if n_expts_tot > 1: # sparse logits = matmul_ogs(xg, wg, bg, precision_config=pcg) - x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP) + x, rdata, gi, si, metadata = routing(x, logits, n_expts_act, EP=EP, TP=TP, expt_assignment=expt_assignment) else: # dense x = all_gather(x, dim=0) rdata = gi = si = metadata = None x = matmul_ogs(x, w1, b1, rdata, gather_indx=gi, precision_config=pc1, fused_activation=act) x = matmul_ogs(x, w2, b2 if rank % TP == 0 else None, rdata, scatter_indx=si, precision_config=pc2) - x = reduce_scatter(x, metadata=metadata, dim=0) + x = reduce_scatter(x, n_expts_act, metadata=metadata, expt_assignment=expt_assignment) # gather the result from all GPUs, just for verification return all_gather(x, dim=0) @@ -721,31 +304,20 @@ def distributed(x): @pytest.mark.parametrize( "batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP", - [ - # dense cases - test parallelism - (1024, 1024, 1024, 1, 1, "bf16", "bf16", 1, 1), - (1024, 1024, 1024, 1, 1, "bf16", "bf16", 4, 1), - ] + - # dense cases - test precision - [(1024, 1024, 1024, 1, 1, "fp8", "fp8", 1, 1), (1024, 1024, 1024, 1, 1, "fp8", "fp8", 4, 1)] + # dense cases + [(1024, 1024, 1024, 1, 1, "bf16", "bf16", 1, 1), (1024, 1024, 1024, 1, 1, "fp8", "fp8", 1, 1)] # moe cases - test parallelism + [ (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1), - (1024, 1024, 1024, 128, 2, "bf16", "bf16", 4, 1), (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 4), - (1024, 1024, 1024, 128, 2, "bf16", "bf16", 2, 2), ] + # moe cases - test precision ([ (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 1), - (1024, 1024, 1024, 128, 2, "fp8", "mx4", 4, 1), (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4), - (1024, 1024, 1024, 128, 2, "fp8", "mx4", 2, 2), ] if has_native_mx4 else [ (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1), - (1024, 1024, 1024, 128, 2, "bf16", "mx4", 4, 1), (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4), - (1024, 1024, 1024, 128, 2, "bf16", "mx4", 2, 2), ]), ) def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, monkeypatch): @@ -756,8 +328,8 @@ def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, T pytest.skip("Test requires CUDA compute capability >= 9.0.") if is_hip() and get_cdna_version() == 4 and EP > 1: pytest.skip("[TODO] Unknown issue with CDNA 4 and EP > 1") - if TP > 1 and x_dtype == "fp8": - pytest.skip("[TODO] Testing FP8 is not supported for TP > 1.") + if TP > 1: + pytest.skip("[TODO] TP > 1 is not supported yet in distributed mode.") monkeypatch.setenv("WORLD_SIZE", f"{parallelism}") monkeypatch.setenv("MASTER_ADDR", "127.0.0.1") diff --git 
a/python/triton_kernels/reduce.py b/python/triton_kernels/reduce.py index e408ff5d76..c712a13536 100644 --- a/python/triton_kernels/reduce.py +++ b/python/triton_kernels/reduce.py @@ -147,6 +147,8 @@ def reduce( Returns: - output: torch.Tensor The reduced tensor with `dim` removed. + - output_mxscale: Optional[torch.Tensor] + The output mx scale if input is micro-scaled, else None. """ if x.ndim != 3: raise NotImplementedError("reduce only supports 3D inputs in this implementation") diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 0e93214c31..a5a9b6e4ba 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -458,7 +458,7 @@ tt.func @max_min() { %4 = arith.constant dense<8> : tensor<128xi32> // expected-remark @below {{contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4}} %5 = arith.constant dense<4> : tensor<128xi32> - // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8}} + // expected-remark @below {{contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8}} %6 = arith.maxsi %4, %5 : tensor<128xi32> tt.return } @@ -974,3 +974,86 @@ tt.func public @trans_4d_tensor_kernel(%arg0: tensor<32x32x32x32xi32> {tt.contig %102 = tt.trans %arg0 {order = array} : tensor<32x32x32x32xi32> -> tensor<32x32x32x32xi32> tt.return } + +// ----- + +tt.func @unrealized_conversion_cast(%arg0: tensor<128x128xi32> {tt.contiguity = dense<[16, 32]> : tensor<2xi32>}) { + // Case 1: AxisInfo is propagated through a sequence of + // unrealized_conversion_cast ops. + // expected-remark @below {{contiguity = [16, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %0 = builtin.unrealized_conversion_cast %arg0 : tensor<128x128xi32> to !llvm.struct<(i32, i32, i32, i32)> + // expected-remark @below {{contiguity = [16, 32], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %1 = builtin.unrealized_conversion_cast %0 : !llvm.struct<(i32, i32, i32, i32)> to tensor<128x128xi32> + + // Case 2: AxisInfo is falling back to the pessimistic state if the + // propagated AxisInfo would be invalid. + // expected-remark @below {{contiguity = [1], divisibility = [1], constancy = [1], constant_value = }} + %2 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + // expected-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i32, i32, i32, i32)> to tensor<128x128xi32> + // expected-remark @below {{contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = }} + %4 = tt.trans %3 {order = array} : tensor<128x128xi32> -> tensor<128x128xi32> + tt.return +} + +// ----- + +// Axis analysis does not support multi-dimensional function arguments. Make +// sure that we don't crash. 
+tt.func @callee(%arg0: tensor<128x1xi32>) { + tt.return +} + +tt.func @caller() { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // expected-remark @below {{contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = }} + %1 = tt.expand_dims %0 {axis = 1: i32} : tensor<128xi32> -> tensor<128x1xi32> + tt.call @callee(%1) : (tensor<128x1xi32>) -> () + tt.return +} + +// ----- + +tt.func @mul_zero_constancy() { + %range = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %zeros = arith.constant dense<0> : tensor<128xi32> + // expected-remark @below {{constancy = [128]}} + %product = arith.muli %zeros, %range : tensor<128xi32> + tt.return +} + +// ----- + +tt.func @max_constancy() { + %c5 = arith.constant dense<5> : tensor<4xi32> + %c7 = arith.constant dense<7> : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 7}} + %max = arith.maxsi %c5, %c7 : tensor<4xi32> + tt.return +} + +// ----- + +tt.func @select_same_value_constancy() { + %range = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32> + %two = arith.constant dense<2> : tensor<4xi32> + %mod = arith.remsi %range, %two : tensor<4xi32> + %zero = arith.constant dense<0> : tensor<4xi32> + %cond = arith.cmpi ne, %mod, %zero : tensor<4xi32> + %lhs = arith.constant dense<42> : tensor<4xi32> + %rhs = arith.constant dense<42> : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 42}} + %sel = arith.select %cond, %lhs, %rhs : tensor<4xi1>, tensor<4xi32> + tt.return +} + +// ----- + +tt.func @cmp_after_max_constancy() { + %c5 = arith.constant dense<5> : tensor<4xi32> + %c7 = arith.constant dense<7> : tensor<4xi32> + %max = arith.maxsi %c5, %c7 : tensor<4xi32> + // expected-remark @below {{constancy = [4], constant_value = 1}} + %cmp = arith.cmpi sgt, %max, %c5 : tensor<4xi32> + tt.return +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index f0c6d652ed..0140d9cb7b 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -155,6 +155,18 @@ tt.func @preallocate(%A : !tt.ptr) { tt.return } +// expected-remark @below {{memdesc_ptr}} +// expected-remark @below {{size = 6144}} +tt.func @memdesc_ptr() { + // expected-remark @below {{offset = 0, size = 4096}} + %a0 = ttg.local_alloc : () -> !ttg.memdesc<32x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + // expected-remark @below {{offset = 4096, size = 2048}} + %a1 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a0 : !ttg.memdesc<32x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + ttg.local_dealloc %a1 : !ttg.memdesc<1x16x16x!tt.ptr, #A_SHARED, #ttg.shared_memory, mutable> + tt.return +} + // Unused tensors are immediately released // expected-remark @below {{unused}} // expected-remark @below {{size = 1024}} @@ -279,9 +291,9 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) { } -// expected-remark @below {{alloc}} +// expected-remark @below {{alloc_ptr}} // expected-remark @below {{size = 512}} -tt.func @alloc(%A : !tt.ptr) { +tt.func @alloc_ptr(%A : !tt.ptr) { // expected-remark @below {{offset = 0, size = 512}} %cst0 = ttg.local_alloc : () -> !ttg.memdesc<16x16xf16, #A_SHARED, #ttg.shared_memory, mutable> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> diff --git a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir index 
ad4c215b2c..3696bf6801 100644 --- a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir +++ b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -150,6 +150,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar // COMMON: llvm.cond_br // COMMON: llvm.store + // Make sure branch condition is set properly when there is other value. + // COMMON: [[AND:%.*]] = llvm.and + // COMMON: llvm.cond_br [[AND]] + // COMMON: rocdl.raw.ptr.buffer.load.lds // COMMON: llvm.cond_br // COMMON: llvm.store diff --git a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir index bceb77fde2..6337fec57e 100644 --- a/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_wmma_dot_scaled_to_llvm.mlir @@ -7,8 +7,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { // CHECK-LABEL: wmma_scaled_dot_fp4 - tt.func @wmma_scaled_dot_fp4(%arg0: tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<16x4xi8, #linear>, %arg2: tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg3: tensor<16x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma> + tt.func @wmma_scaled_dot_fp4(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> // Matrix A // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> @@ -18,21 +18,185 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> // Scale A - // CHECK-COUNT-2: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> - // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 // Scale B - // CHECK-COUNT-2: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> - // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 // Matrix C // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<8xi32>, i32, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> - %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<16x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<16x4xi8, #linear> * tensor<64x16xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, tensor<16x4xi8, #linear1> -> tensor<16x16xf32, #mma> + %c = 
tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e2m1 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma> // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> - %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<16x16x!tt.ptr, #mma> - tt.store %ptr0, %c : tensor<16x16x!tt.ptr, #mma> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> +#mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 64]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp4_fp8 + tt.func @wmma_scaled_dot_fp4_fp8(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<32xi8> + // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> + // Matrix B + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<8xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e2m1 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, 
#linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8 + tt.func @wmma_scaled_dot_fp8(%arg0: tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x4xi8, #linear>, %arg2: tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x4xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-64: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x128xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear> * tensor<128x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x4xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], 
[4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_k64 + tt.func @wmma_scaled_dot_fp8_k64(%arg0: tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x2xi8, #linear>, %arg2: tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x2xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Adjust for acc + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i8) : i8 + // Matrix A + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK-COUNT-32: llvm.insertelement %[[ZERO]], {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-32: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK-COUNT-32: llvm.insertelement %[[ZERO]], {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-4: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x2xi8, #linear> * tensor<64x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x2xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> + tt.return + } +} + +// ----- + +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +#linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], 
[4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +#mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape=[16, 16, 128]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: wmma_scaled_dot_fp8_repeat_k + tt.func @wmma_scaled_dot_fp8_repeat_k(%arg0: tensor<32x256xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, %arg1: tensor<32x8xi8, #linear>, %arg2: tensor<256x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, %arg3: tensor<32x8xi8, #linear1>, %out0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // Matrix A + // CHECK-COUNT-128: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Matrix B + // CHECK-COUNT-128: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // CHECK-COUNT-64: llvm.insertelement {{.*}} : vector<64xi8> + // CHECK: llvm.bitcast {{.*}} : vector<64xi8> to vector<16xi32> + // Scale A + // CHECK-COUNT-8: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Scale B + // CHECK-COUNT-8: llvm.extractvalue {{.*}} : !llvm.struct<(i8, i8, i8, i8, i8, i8, i8, i8)> + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // CHECK-COUNT-4: llvm.insertelement {{.*}} : vector<4xi8> + // CHECK: llvm.bitcast {{.*}} : vector<4xi8> to i32 + // Matrix C + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<8xf32> + // CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4"{{.*}} : (i32, vector<16xi32>, i32, vector<16xi32>, i16, vector<8xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<8xf32> + %c = tt.dot_scaled %arg0 scale %arg1, %arg2 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x256xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear> * 
tensor<256x32xi8, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>, tensor<32x8xi8, #linear1> -> tensor<32x32xf32, #mma> + // CHECK-COUNT-8: llvm.extractelement {{.*}} : vector<8xf32> + // CHECK-COUNT-8: llvm.insertelement {{.*}} : vector<1xf32> + %ptr0 = tt.splat %out0 : !tt.ptr -> tensor<32x32x!tt.ptr, #mma> + tt.store %ptr0, %c : tensor<32x32x!tt.ptr, #mma> tt.return } } diff --git a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir index 2a521f7475..0816308c8d 100644 --- a/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir +++ b/test/TritonGPU/amd/accelerate-amd-matmul-wmma-gfx1250.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx1250 matrix-instruction-size=16" | FileCheck %s --check-prefixes CHECK +// RUN: triton-opt %s -split-input-file --tritonamdgpu-accelerate-matmul="arch-generation-name=gfx1250" | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> @@ -32,3 +32,165 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK{LITERAL}: #mma1 = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 64]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp4_mxfp8 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp4_mxfp8( + %arg0: tensor<32x64xi8, #blocked>, + %arg1: tensor<128x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x4xi8, #blocked2>, + %arg3: tensor<32x4xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x64xi8, #blocked> -> tensor<32x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x32xf8E4M3FN, #blocked1> -> tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], 
%[[B]] scale %[[SCALE1]], %[[C]] lhs = e2m1 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e2m1 rhs = e4m3 {fastMath = false} : tensor<32x64xi8, #blocked>, tensor<32x4xi8, #blocked2> * tensor<128x32xf8E4M3FN, #blocked1>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8( + %arg0: tensor<32x128xf8E4M3FN, #blocked>, + %arg1: tensor<128x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x4xi8, #blocked2>, + %arg3: tensor<32x4xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x128xf8E4M3FN, #blocked> -> tensor<32x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x32xf8E4M3FN, #blocked1> -> tensor<128x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x4xi8, #blocked2> -> tensor<32x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x128xf8E4M3FN, #blocked>, tensor<32x4xi8, #blocked2> * tensor<128x32xf8E4M3FN, #blocked1>, tensor<32x4xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = 
#ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_k64 +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_k64( + %arg0: tensor<32x64xf8E4M3FN, #blocked>, + %arg1: tensor<64x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x2xi8, #blocked2>, + %arg3: tensor<32x2xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, #blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x64xf8E4M3FN, #blocked> -> tensor<32x64xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<64x32xf8E4M3FN, #blocked1> -> tensor<64x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x2xi8, #blocked2> -> tensor<32x2xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x2xi8, #blocked2> -> tensor<32x2xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x64xf8E4M3FN, #blocked>, tensor<32x2xi8, #blocked2> * tensor<64x32xf8E4M3FN, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_repeat_k +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_repeat_k( + %arg0: tensor<32x256xf8E4M3FN, #blocked>, + %arg1: tensor<256x32xf8E4M3FN, #blocked1>, + %arg2: tensor<32x8xi8, #blocked2>, + %arg3: tensor<32x8xi8, #blocked2>, + %arg4: tensor<32x32x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<32x32xf32, 
#blocked3> -> tensor<32x32xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<32x256xf8E4M3FN, #blocked> -> tensor<32x256xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<256x32xf8E4M3FN, #blocked1> -> tensor<256x32xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<32x8xi8, #blocked2> -> tensor<32x8xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<32x8xi8, #blocked2> -> tensor<32x8xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<32x256xf8E4M3FN, #blocked>, tensor<32x8xi8, #blocked2> * tensor<256x32xf8E4M3FN, #blocked1>, tensor<32x8xi8, #blocked2> -> tensor<32x32xf32, #blocked3> + tt.store %arg4, %1 : tensor<32x32x!tt.ptr, #blocked3> + tt.return + } +} + + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[0, 0], [16, 0]], block = []}> +// CHECK{LITERAL}: #linear1 = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [0, 0]], warp = [[16, 0], [0, 0]], block = []}> +// CHECK{LITERAL}: #mma = #ttg.amd_wmma<{version = 3, isTranspose = true, warpsPerCTA = [2, 2], instrShape = [16, 16, 128]}> +// CHECK-LABEL: wmma_dot_scaled_mxfp8_repeat_mn +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx1250", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @wmma_dot_scaled_mxfp8_repeat_mn( + %arg0: tensor<64x128xf8E4M3FN, #blocked>, + %arg1: tensor<128x64xf8E4M3FN, #blocked1>, + %arg2: tensor<64x4xi8, #blocked2>, + %arg3: tensor<64x4xi8, #blocked2>, + %arg4: tensor<64x64x!tt.ptr, #blocked3> + ) { + // CHECK-NOT: tt.fp_to_fp + // CHECK: %[[C:.+]] = ttg.convert_layout {{.*}} : tensor<64x64xf32, #blocked3> -> tensor<64x64xf32, #mma> + // CHECK: %[[A:.+]] = ttg.convert_layout {{.*}} : tensor<64x128xf8E4M3FN, #blocked> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> + // CHECK: %[[B:.+]] = ttg.convert_layout {{.*}} : tensor<128x64xf8E4M3FN, #blocked1> -> tensor<128x64xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> + // CHECK: %[[SCALE0:.+]] = ttg.convert_layout {{.*}} : tensor<64x4xi8, #blocked2> -> tensor<64x4xi8, #linear> + // CHECK: %[[SCALE1:.+]] = ttg.convert_layout {{.*}} : tensor<64x4xi8, #blocked2> -> tensor<64x4xi8, #linear1> + // CHECK: tt.dot_scaled %[[A]] scale %[[SCALE0]], %[[B]] scale %[[SCALE1]], %[[C]] lhs = e4m3 rhs = e4m3 + %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked3> + %1 = tt.dot_scaled %arg0 scale %arg2, %arg1 scale %arg3, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<64x128xf8E4M3FN, #blocked>, tensor<64x4xi8, 
#blocked2> * tensor<128x64xf8E4M3FN, #blocked1>, tensor<64x4xi8, #blocked2> -> tensor<64x64xf32, #blocked3> + tt.store %arg4, %1 : tensor<64x64x!tt.ptr, #blocked3> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp index 2ef0816ff4..3cfd5b9bea 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp @@ -39,8 +39,8 @@ using ValueTable = std::map, Value>; ValueTable getValuesFromDotOperandLayoutStruct( ConversionPatternRewriter &rewriter, const LLVMTypeConverter *typeConverter, - int wmmaVer, Value value, int batch, int n0, int n1, int kBase, Type type, - Location loc) { + int wmmaVer, Value value, int batch, int n0, int n1, int kBase, + int kPadding, Type type, Location loc) { auto tb = TritonLLVMOpBuilder(loc, rewriter); auto elems = unpackLLElements(loc, value, rewriter); ValueTable vals; @@ -50,11 +50,18 @@ ValueTable getValuesFromDotOperandLayoutStruct( Type elemTy = typeConverter->convertType(type); Type ty = vec_ty(elemTy, kBase); Value rawElems = tb.undef(ty); + for (int k = 0; k < kBase; ++k) { - rawElems = tb.insert_element( - ty, rawElems, - elems[n0 * n1 * kBase * b + kBase * (n1 * i + j) + k], - tb.i32_val(k)); + int idx = n0 * n1 * kBase * b + kBase * (n1 * i + j) + k; + if (k < kBase - kPadding) { + rawElems = + tb.insert_element(ty, rawElems, elems[idx], tb.i32_val(k)); + } else { + // pad with zeros + Value zero = rewriter.create( + loc, elemTy, rewriter.getZeroAttr(elemTy)); + rawElems = tb.insert_element(ty, rawElems, zero, tb.i32_val(k)); + } } Value convertedElems; @@ -65,12 +72,14 @@ ValueTable getValuesFromDotOperandLayoutStruct( // Before wmma v3, bf16 is converted to i16 if (wmmaVer < 3) convertedElems = tb.bitcast(rawElems, vec_ty(i16_ty, kBase)); - } else if (kBase == 4 && type.getIntOrFloatBitWidth() == 8) { - convertedElems = tb.bitcast(rawElems, i32_ty); } else { - convertedElems = tb.bitcast( - rawElems, vec_ty(i32_ty, kBase * type.getIntOrFloatBitWidth() / - i32_ty.getIntOrFloatBitWidth())); + auto elems = kBase * type.getIntOrFloatBitWidth() / + i32_ty.getIntOrFloatBitWidth(); + assert(elems >= 1 && "unexpected number of elements"); + if (elems == 1) + convertedElems = tb.bitcast(rawElems, i32_ty); + else + convertedElems = tb.bitcast(rawElems, vec_ty(i32_ty, elems)); } vals[{b, i, j}] = convertedElems; } @@ -254,13 +263,19 @@ LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor, auto numRepK = repA[2]; auto numRepB = repA[0]; - int kBase = maybeWmmaIntrinsic->kBase; + // If kDim > kDimTensor, we need to add zeros to the kBase vector. The number of + // zeros is determined by kBase * (1 - kDimTensor / kDim) + auto kBase = maybeWmmaIntrinsic->kBase; + auto kDimTensor = aTensorTy.getShape().back(); + auto paddingFactor = kDim > kDimTensor ?
(kDim / kDimTensor) : 1; + auto kPadding = kBase - kBase / paddingFactor; + ValueTable ha = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedA, numRepB, numRepM, numRepK, - kBase, aTensorTy.getElementType(), loc); + kBase, kPadding, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedB, numRepB, numRepN, numRepK, - kBase, aTensorTy.getElementType(), loc); + kBase, kPadding, aTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); @@ -373,6 +388,13 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, int kBaseB = isFp4B ? kBase / 2 : kBase; int kDimB = isFp4B ? kDim / 2 : kDim; + bool isFp6A = (op.getAElemType() == triton::ScaleDotElemType::E2M3) || + (op.getAElemType() == triton::ScaleDotElemType::E3M2); + bool isFp6B = (op.getBElemType() == triton::ScaleDotElemType::E2M3) || + (op.getBElemType() == triton::ScaleDotElemType::E3M2); + if (isFp6A || isFp6B) + return op.emitError("NYI: FP6 scaled dot"); + auto repA = wmmaLayout.getRepForOperand(aTensorTy.getShape(), kDimA, 0); auto repB = wmmaLayout.getRepForOperand(bTensorTy.getShape(), kDimB, 1); @@ -388,23 +410,26 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, auto numRepK = repA[2]; auto numRepB = repA[0]; - auto scaleShapeA = aScaleTensorTy.getShape(); - constexpr int scaleKWidthA = 4; - auto scaleShapeB = bScaleTensorTy.getShape(); - constexpr int scaleKWidthB = 4; + // If kDim > kDimTensor, we need to add zeros to the kBase vector. The number of + // zeros is determined by kBase * (1 - kDimTensor / kDim) + auto kDimTensorA = aTensorTy.getShape().back(); + auto paddingFactor = kDimA > kDimTensorA ? (kDimA / kDimTensorA) : 1; + auto kPaddingA = kBaseA - kBaseA / paddingFactor; + auto kPaddingB = kBaseB - kBaseB / paddingFactor; + auto KBaseScale = 4; ValueTable ha = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedA, numRepB, numRepM, numRepK, - kBaseA, aTensorTy.getElementType(), loc); + kBaseA, kPaddingA, aTensorTy.getElementType(), loc); ValueTable hb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedB, numRepB, numRepN, numRepK, - kBaseB, bTensorTy.getElementType(), loc); + kBaseB, kPaddingB, bTensorTy.getElementType(), loc); ValueTable sa = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedAScale, numRepB, numRepM, numRepK, - scaleKWidthA, aScaleTensorTy.getElementType(), loc); + KBaseScale, 0, aScaleTensorTy.getElementType(), loc); ValueTable sb = getValuesFromDotOperandLayoutStruct( rewriter, typeConverter, wmmaVer, loadedBScale, numRepB, numRepN, numRepK, - scaleKWidthB, bScaleTensorTy.getElementType(), loc); + KBaseScale, 0, bScaleTensorTy.getElementType(), loc); auto dstElemTy = dTensorTy.getElementType(); auto fc = unpackLLElements(loc, loadedC, rewriter); @@ -438,11 +463,11 @@ LogicalResult convertScaledDot(triton::DotScaledOp op, ?
generateScaledWMMAIntrinsic( rewriter, loc, hb[{b, n, k}], sb[{b, n, k}], ha[{b, m, k}], sa[{b, m, k}], acc, scaledBElemType, - scaledAElemType, dstElemTy, scaleKWidthA) + scaledAElemType, dstElemTy, KBaseScale) : generateScaledWMMAIntrinsic( rewriter, loc, ha[{b, m, k}], sa[{b, m, k}], hb[{b, n, k}], sb[{b, n, k}], acc, scaledAElemType, - scaledBElemType, dstElemTy, scaleKWidthB); + scaledBElemType, dstElemTy, KBaseScale); } for (unsigned v = 0; v < dElemsToStorePerThread; ++v) { fc[fcThreadOffIdx + v] = diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 4e76cf1dbb..cfb8942628 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -839,7 +839,7 @@ struct BufferLoadToLocalOpConversion Value cond = hasOther ? b.and_(threadPred, maybeSwizzledMaskElem) : threadPred; - auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, threadPred); + auto [loadBlock, afterLoadBlock] = emitBranch(rewriter, loc, cond); auto bufferLoadToLds = bufferEmitter.emitLoadToLds( vecTy, vecBytesVal, rsrcDesc, offsetElem, shmemAddr, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 2bd2e7c267..4462b8cc9a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -81,8 +81,15 @@ int TargetInfo::getWarpSize() const { } int TargetInfo::getSharedMemorySize() const { - int kbytes = getISAFamily() == ISAFamily::CDNA4 ? 160 : 64; - return kbytes * 1024; + // Returns the maximum shared memory capacity in bytes + switch (getISAFamily()) { + case ISAFamily::GFX1250: + return 320 * 1024; + case ISAFamily::CDNA4: + return 160 * 1024; + default: + return 64 * 1024; + } } bool TargetInfo::supportMaximumMinimum() const { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f17a02dd53..2a611f2875 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -1249,9 +1249,10 @@ class ScaledBlockedToScaledWMMAF8F6F4 final ScaleDotElemType aElemType = dotOp.getAElemType(); ScaleDotElemType bElemType = dotOp.getBElemType(); - // TODO: Add more supported types auto supportsTypes = [](ScaleDotElemType elemType) { - return elemType == ScaleDotElemType::E2M1; + return elemType == ScaleDotElemType::E2M1 || + elemType == ScaleDotElemType::E4M3 || + elemType == ScaleDotElemType::E5M2; }; if (!supportsTypes(aElemType) || !supportsTypes(bElemType)) { @@ -1315,9 +1316,6 @@ class ScaledBlockedToScaledWMMAF8F6F4 final auto convertScaleLayout = [&](TensorValue scale, llvm::ArrayRef valShape, LinearLayout dotLL, int idx) -> Value { - LinearLayout::BasesT scaleBases = dotLL.getBases(); - auto &warpBases = scaleBases[kWarp]; - SmallVector shape; if (!scale) { int64_t nonKDim = idx == 0 ?
valShape[0] : valShape[1]; @@ -1330,7 +1328,7 @@ class ScaledBlockedToScaledWMMAF8F6F4 final } LinearLayout newLL = - ttg::chooseScaledWmmaScaleLayout(ctx, idx, warpBases, shape); + ttg::chooseScaledWmmaScaleLayout(ctx, idx, warpsPerTile, shape); Attribute newScaleEncoding = ttg::LinearEncodingAttr::get(ctx, newLL); // Scale's data type is always i8 auto newScaleType = RankedTensorType::get(shape, i8_ty, newScaleEncoding); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp index 5fb49e6985..2bb12ec65f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/Utility.cpp @@ -1,6 +1,7 @@ #include "Utility.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Tools/LayoutUtils.h" #include @@ -159,7 +160,7 @@ ttg::PaddedSharedEncodingAttr composePaddedLayoutForAsyncCopyCDNA4( return {}; } - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + unsigned bitWidth = getIntOrFloatOrPtrBitWidth(srcTy.getElementType()); unsigned elemByteWidth = std::max(bitWidth / 8u, 1u); auto loadBytes = shape[0] * shape[1] * elemByteWidth; if (loadBytes < 16384) { diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index ceb896c10f..56e4ebf17d 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -333,6 +333,7 @@ def gluon_to_ttgir(self, src, metadata, options, capability): passes.gluon.add_inliner(pm) passes.gluon.add_resolve_auto_encodings(pm) + nvidia.passes.ttnvgpuir.add_tma_lowering(pm) passes.gluon.add_canonicalizer(pm) passes.common.add_sccp(pm) passes.ttir.add_loop_aware_cse(pm) diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp index 2889a78ca0..b5c4e518b5 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/InsertAref.cpp @@ -145,7 +145,7 @@ int getTxCount(Operation *descOp) { auto encoding = getEncodingFromDescriptor(descOp, tensorType, desc); auto shapePerCTA = getShapePerCTA(encoding, tensorType.getShape()); return product(shapePerCTA) * - tensorType.getElementType().getIntOrFloatBitWidth() / 8; + getIntOrFloatOrPtrBitWidth(tensorType.getElementType()) / 8; } void createNVWSDescriptorLoadOp(OpBuilder &builder, Operation *ttDescLoadOp, diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8c7dfc7771..22f82ba6a9 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -256,7 +256,8 @@ class LoadAcquireOpPattern : public OpRewritePattern { auto loc = op->getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); Type valueTy = op.getType(); - const unsigned valueNBits = std::max(8u, valueTy.getIntOrFloatBitWidth()); + const unsigned valueNBits = + std::max(8u, (unsigned)getIntOrFloatOrPtrBitWidth(valueTy)); const size_t maxWordWidth = std::max(32, valueNBits); const size_t width = std::min((size_t)valueNBits, maxWordWidth); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp index 270a2fc91a..f1ed5161a3 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp +++ 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -559,7 +559,7 @@ struct FDivOpConversion ConversionPatternRewriter &rewriter, Type elemTy, MultipleOperandsRange operands, Location loc) const { - unsigned bitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned bitwidth = getIntOrFloatOrPtrBitWidth(elemTy); StringRef name; Type resultTy; if (32 == bitwidth) { @@ -643,7 +643,7 @@ struct ExpOpConversionApprox Location loc) const { auto b = TritonLLVMOpBuilder(loc, rewriter); // For non-FP32 input, call __nv_expf for higher-precision calculation - if (elemTy.getIntOrFloatBitWidth() != 32) + if (getIntOrFloatOrPtrBitWidth(elemTy) != 32) return {}; const double log2e = 1.4426950408889634; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index bd661f26a7..3b3008bae7 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -195,7 +195,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, auto vecTy = cast(val.getType()); Type elemTy = vecTy.getElementType(); unsigned vec = vecTy.getNumElements(); - unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); assert(llvm::isPowerOf2_32(vec)); if (elemBitwidth < 8) { @@ -213,7 +213,11 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, if (!elemTy.isInteger()) { SmallVector vals = unpackLLVector(loc, val, rewriter); for (Value &v : vals) { - v = b.bitcast(v, int_ty(elemBitwidth)); + if (isa(v.getType())) { + v = b.ptrtoint(int_ty(elemBitwidth), v); + } else { + v = b.bitcast(v, int_ty(elemBitwidth)); + } } storeDShared(rewriter, loc, ptr, ctaId, packLLVector(loc, vals, rewriter), pred); @@ -316,7 +320,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, auto vecTy = cast(loadTy); Type elemTy = vecTy.getElementType(); unsigned vec = vecTy.getNumElements(); - unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth(); + unsigned elemBitwidth = getIntOrFloatOrPtrBitWidth(elemTy); assert(llvm::isPowerOf2_32(vec)); if (elemBitwidth < 8) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 593adbc750..dd253e27ec 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -163,7 +163,7 @@ LogicalResult lowerLdStMatrix( auto kOffset = S("offset"); auto kAddr = S("addr"); auto smemPtrTy = ptr_ty(ctx, 3); - auto bitwidth = llvmElemTy.getIntOrFloatBitWidth(); + auto bitwidth = getIntOrFloatOrPtrBitWidth(llvmElemTy); // In the contiguous case we can pack elements <= 32 bits // In the transpose case we just have the b8 and b16 cases if ((!transpose && bitwidth > 32) ||