#include "Dialect/TritonIntelGPU/Transforms/DecomposeScaledBlocked.h"

#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

namespace {

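// Build a permutation that swaps the last two dimensions and leaves any
// leading (batch) dimensions in place.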
SmallVector<int, 2> getTransposeOrder(int rank) {
  assert(rank >= 2);
  auto transOrder = llvm::to_vector<2>(llvm::seq<int>(rank - 2));
  transOrder.push_back(rank - 1);
  transOrder.push_back(rank - 2);
  return transOrder;
}

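// Decomposes tt.dot_scaled into a regular tt.dot: both operands are upcast to
// a common 16-bit float type, multiplied element-wise by their broadcast
// scales, converted to the dot-operand layout of the result encoding, and fed
// into a plain DotOp whose result is converted back to the original layout.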
class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {

public:
  DecomposeScaledBlocked(MLIRContext *context, int benefit)
      : OpRewritePattern<DotScaledOp>(context, benefit) {}

  LogicalResult matchAndRewrite(DotScaledOp scaledDotOp,
                                PatternRewriter &rewriter) const override {
    // Types
    auto computeType = getComputeType(scaledDotOp.getAElemType(),
                                      scaledDotOp.getBElemType(), rewriter);
    auto loc = scaledDotOp.getLoc();

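    // Wrap a scaled operand in the dot-operand layout implied by the result
    // encoding so it can feed the regular DotOp created below.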
    auto cvtDotOperand = [&](TypedValue<RankedTensorType> v,
                             int opIdx) -> TypedValue<RankedTensorType> {
      auto *ctx = rewriter.getContext();
      auto retEnc = scaledDotOp.getType().getEncoding();
      auto vType = v.getType();
      auto encoding = DotOperandEncodingAttr::get(ctx, opIdx, retEnc,
                                                  vType.getElementType());
      auto retTy = RankedTensorType::get(vType.getShape(),
                                         vType.getElementType(), encoding);
      return rewriter.create<ConvertLayoutOp>(loc, retTy, v);
    };

    auto scaledA = scaleArg(rewriter, scaledDotOp, 0, computeType);
    scaledA = cvtDotOperand(scaledA, 0);
    auto scaledB = scaleArg(rewriter, scaledDotOp, 1, computeType);
    scaledB = cvtDotOperand(scaledB, 1);
    auto newDot = rewriter.create<DotOp>(scaledDotOp.getLoc(), scaledA, scaledB,
                                         scaledDotOp.getC());

    rewriter.replaceOpWithNewOp<ConvertLayoutOp>(scaledDotOp,
                                                 scaledDotOp.getType(), newDot);
    return success();
  }

private:
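  // Pick the 16-bit float type both operands are upcast to: f16 if either
  // input is fp16, otherwise bf16.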
  FloatType getComputeType(ScaleDotElemType aType, ScaleDotElemType bType,
                           PatternRewriter &rewriter) const {
    if (aType == ScaleDotElemType::FP16 || bType == ScaleDotElemType::FP16)
      return rewriter.getF16Type();
    return rewriter.getBF16Type();
  }

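  // Convert an E8M0 scale (a bare 8-bit exponent) to computeType: zero-extend
  // the byte, shift it into the exponent field of a float type wide enough to
  // hold it (f32 when computeType is f16, since f16's 5-bit exponent cannot
  // represent the full E8M0 range), bitcast, and truncate back to computeType
  // if needed. With a zero mantissa, a normal scale value s comes out as
  // exactly 2^(s - 127).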
  TypedValue<RankedTensorType> scaleTo16(PatternRewriter &rewriter,
                                         TypedValue<RankedTensorType> scale,
                                         FloatType computeType) const {
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    assert(computeType == rewriter.getBF16Type() ||
           computeType == rewriter.getF16Type());

    // Choose an fp type that can fit the scale value.
    FloatType largeFpType = computeType == rewriter.getF16Type()
                                ? rewriter.getF32Type()
                                : computeType;
    int intWidth = largeFpType.getIntOrFloatBitWidth();
    auto intType = rewriter.getIntegerType(intWidth);

    auto zexted =
        rewriter.create<arith::ExtUIOp>(loc, scaleTy.clone(intType), scale);
    // getFPMantissaWidth() returns the number of bits in the mantissa plus
    // the implicit leading bit, so subtract one to shift past the mantissa
    // field only.
    int shiftValue = largeFpType.getFPMantissaWidth() - 1;
    auto shiftConst =
        rewriter.create<arith::ConstantIntOp>(loc, shiftValue, intWidth);
    auto shift =
        rewriter.create<SplatOp>(loc, scaleTy.clone(intType), shiftConst);
    auto shlRes = rewriter.create<arith::ShLIOp>(loc, zexted, shift);
    Value scaleFP =
        rewriter.create<BitcastOp>(loc, scaleTy.clone(largeFpType), shlRes);
    if (largeFpType != computeType) {
      scaleFP = rewriter.create<arith::TruncFOp>(
          loc, scaleTy.clone(computeType), scaleFP);
    }
    return cast<TypedValue<RankedTensorType>>(scaleFP);
  }

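  // Replicate each scale value 32 times along `dim` so the scale tensor takes
  // the shape of the operand it multiplies (one scale per group of 32
  // elements): add a unit dimension at the end, broadcast it to 32, transpose
  // it next to `dim`, then fold it back in with a reshape.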
  TypedValue<RankedTensorType>
  broadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
                 ModuleOp mod, TypedValue<RankedTensorType> scale,
                 int dim) const {
    auto *ctx = rewriter.getContext();
    auto loc = scale.getLoc();
    auto scaleTy = scale.getType();
    auto rank = scaleTy.getRank();
    // 2.1) Expand dims along the last dimension
    {
      // 2.1.1) Find default encoding for ExpandDims
      auto shape = to_vector(scaleTy.getShape());
      shape.insert(shape.end(), 1);
      auto nWarps = lookupNumWarps(scaledDotOp);
      auto threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
      auto numCTAs = TritonGPUDialect::getNumCTAs(mod);
      auto blockedEnc = getDefaultBlockedEncoding(ctx, shape, nWarps,
                                                  threadsPerWarp, numCTAs);
      // 2.1.2) Cast the scale to the matching SliceEncoding
      auto sliceEnc = SliceEncodingAttr::get(ctx, rank, blockedEnc);
      auto sliceType = RankedTensorType::get(
          scaleTy.getShape(), scaleTy.getElementType(), sliceEnc);
      scale = rewriter.create<ConvertLayoutOp>(loc, sliceType, scale);
    }
    auto expandScale = rewriter.create<ExpandDimsOp>(loc, scale, rank);
    // 2.2) Broadcast the dimension to size 32
    auto scaleShape = to_vector(scaleTy.getShape());
    scaleShape.push_back(32);
    auto broadcastScale = rewriter.create<BroadcastOp>(
        loc, expandScale.getType().clone(scaleShape), expandScale);
    // 2.3) Transpose the dimension to the scaled dimension
    auto transposeOrder = llvm::to_vector(llvm::seq<int32_t>(rank));
    transposeOrder.insert(transposeOrder.begin() + dim + 1, rank);
    auto transposedScale =
        rewriter.create<TransOp>(loc, broadcastScale, transposeOrder);
    // 2.4) Reshape to the shape of v
    scaleShape.pop_back();
    scaleShape[dim] *= 32;
    auto reshapeScale =
        rewriter.create<ReshapeOp>(loc, scaleShape, transposedScale);
    return reshapeScale;
  }

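  // An E8M0 scale of 0xFF encodes NaN, so every element scaled by it must be
  // forced to NaN regardless of the operand value.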
  TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
                                       DotScaledOp scaledDotOp, ModuleOp mod,
                                       TypedValue<RankedTensorType> mxfp,
                                       TypedValue<RankedTensorType> scale,
                                       FloatType computeType, int dim) const {
    // Implement tl.where(scale == 0xFF, float("nan"), mxfp)
    auto loc = scale.getLoc();

    // FIXME: use a large int type (int32) for the comparison with 0xFF to
    // avoid accidentally masking non-NaN values to NaN.
    // This piece of code will be removed after
    // https://github.com/intel/intel-xpu-backend-for-triton/issues/3605
    FloatType largeFpType = computeType == rewriter.getF16Type()
                                ? rewriter.getF32Type()
                                : computeType;
    int intWidth = largeFpType.getIntOrFloatBitWidth();
    auto intType = rewriter.getIntegerType(intWidth);
    // Use the wider int type for the scale so the comparison cannot turn
    // non-NaN values into NaN.
    auto scaleTy = scale.getType().clone(intType);
    auto zexted = rewriter.create<arith::ExtUIOp>(loc, scaleTy, scale);

    // Scale is NaN
    auto constFF = rewriter.create<arith::ConstantOp>(
        loc, scaleTy,
        DenseElementsAttr::get(scaleTy,
                               APInt(scaleTy.getElementTypeBitWidth(), 0xff)));
    auto scaleIsNan = cast<TypedValue<RankedTensorType>>(
        rewriter
            .create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, zexted,
                                   constFF)
            .getResult());
    auto cond = broadcastScale(rewriter, scaledDotOp, mod, scaleIsNan, dim);
    // Make the scale-is-NaN mask layout-compatible with mxfp
    auto condTy = cond.getType();
    condTy = RankedTensorType::get(condTy.getShape(), condTy.getElementType(),
                                   mxfp.getType().getEncoding());
    cond = rewriter.create<ConvertLayoutOp>(loc, condTy, cond);

    // Create NaN
    auto mxfpTy = mxfp.getType();
    auto nan = APFloat::getNaN(
        cast<FloatType>(mxfpTy.getElementType()).getFloatSemantics());
    auto constNan = rewriter.create<arith::ConstantOp>(
        loc, mxfpTy, DenseElementsAttr::get(mxfpTy, nan));

    auto result = rewriter.create<arith::SelectOp>(loc, cond, constNan, mxfp);
    return cast<TypedValue<RankedTensorType>>(result.getResult());
  }

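  // Upcast one operand of the scaled dot to computeType and apply its scale:
  // decode fp4 (E2M1) values or cast other element types, broadcast the
  // per-32-element scale to the operand shape, multiply, and (unless fastMath
  // is set) propagate NaN for positions whose scale is 0xFF.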
  TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
                                        DotScaledOp scaledDotOp, int opIdx,
                                        FloatType computeType) const {
    auto v = opIdx == 0 ? scaledDotOp.getA() : scaledDotOp.getB();
    auto scale = opIdx == 0 ? scaledDotOp.getAScale() : scaledDotOp.getBScale();
    auto isFp4 =
        ScaleDotElemType::E2M1 ==
        (opIdx == 0 ? scaledDotOp.getAElemType() : scaledDotOp.getBElemType());
    auto fastMath = scaledDotOp.getFastMath();

    auto *ctx = rewriter.getContext();
    auto loc = v.getLoc();
    auto mod = scaledDotOp->getParentOfType<ModuleOp>();
    auto rank = v.getType().getRank();
    auto kDim = opIdx == 0 ? rank - 1 : rank - 2;

    // 0) Upcast value to computeType (fp16/bf16)
    if (isFp4) {
      // We always pack along the fastest moving dimension, kDim
      v = rewriter.create<Fp4ToFpOp>(loc, v, computeType, kDim);
    } else {
      auto vType16 = v.getType().clone(computeType);
      v = cast<TypedValue<RankedTensorType>>(
          rewriter.create<FpToFpOp>(loc, vType16, v).getResult());
    }
    if (!scale)
      return v;

    // For some weird reason, the scale is passed with the shape it would have
    // if it came from the lhs, even when it belongs to the rhs. Ideally we
    // would accept this parameter transposed, as we do with the mxfp.
    if (opIdx == 1) {
      auto order = getTransposeOrder(rank);
      scale = rewriter.create<TransOp>(loc, scale, order);
    }

    // 1) Cast scale to compute type (fp16/bf16)
    auto scale16 = scaleTo16(rewriter, scale, computeType);

    // 2) Broadcast scale to the same shape and layout as v
    auto reshapeScale =
        broadcastScale(rewriter, scaledDotOp, mod, scale16, kDim);
    reshapeScale =
        rewriter.create<ConvertLayoutOp>(loc, v.getType(), reshapeScale);

    // 3) Multiply
    auto mxfp = cast<TypedValue<RankedTensorType>>(
        rewriter.create<arith::MulFOp>(loc, v, reshapeScale).getResult());

    // Skip the NaN checks when fastMath is set
    if (fastMath)
      return mxfp;

    // 4) If the scale is NaN, return NaN; otherwise return the scaled value.
    return maskNan(rewriter, scaledDotOp, mod, mxfp, scale, computeType, kDim);
  }
};

} // namespace

namespace mlir::triton::gpu::intel {

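// Exposes the pattern so passes can add it to their rewrite sets with the
// desired benefit.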
void populateDecomposeScaledBlockedPatterns(RewritePatternSet &patterns,
                                            int benefit) {
  patterns.add<DecomposeScaledBlocked>(patterns.getContext(), benefit);
}

} // namespace mlir::triton::gpu::intel