|
| 1 | +#pragma once |
| 2 | +#include "mlir/IR/BuiltinTypes.h" |
| 3 | +#include "mlir/IR/PatternMatch.h" |
| 4 | +#include "triton/Dialect/Triton/IR/Dialect.h" |
| 5 | + |
| 6 | +namespace mlir::triton::nvidia_gpu { |
| 7 | + |
| 8 | +constexpr inline int TMA_SIZE_BYTES = 128; |
| 9 | +constexpr inline int TMA_ALIGN = 128; |
| 10 | + |
| 11 | +template <typename BuilderT> |
| 12 | +mlir::LogicalResult createTMADesc(mlir::Value tmaPtr, |
| 13 | + mlir::triton::MakeTensorDescOp op, |
| 14 | + BuilderT &builder) { |
| 15 | + using namespace mlir; |
| 16 | + MLIRContext *ctx = op.getContext(); |
| 17 | + auto loc = op.getLoc(); |
| 18 | + auto mkI32Constant = [&](int32_t val) { |
| 19 | + return builder.template create<arith::ConstantOp>( |
| 20 | + loc, builder.getI32Type(), builder.getI32IntegerAttr(val)); |
| 21 | + }; |
| 22 | + |
| 23 | + auto elemType = op.getBase().getType().getPointeeType(); |
| 24 | + auto elemSize = elemType.getIntOrFloatBitWidth() / 8; |
| 25 | + |
| 26 | + int32_t contig_dim_size = op.getTensorShape().back(); |
| 27 | + int32_t contig_dim_size_in_bytes = contig_dim_size * elemSize; |
| 28 | + if (contig_dim_size_in_bytes > 128) { |
| 29 | + contig_dim_size = 128 / elemSize; |
| 30 | + } |
| 31 | + llvm::SmallVector<Value> boxDim; |
| 32 | + boxDim.push_back(mkI32Constant(contig_dim_size)); |
| 33 | + for (int k = op.getTensorShape().size() - 2; k >= 0; --k) { |
| 34 | + boxDim.push_back(mkI32Constant(op.getTensorShape()[k])); |
| 35 | + } |
| 36 | + |
| 37 | + int32_t swizzle_mode; |
| 38 | + if (contig_dim_size_in_bytes >= 128) { |
| 39 | + swizzle_mode = 3; |
| 40 | + } else if (contig_dim_size_in_bytes == 64) { |
| 41 | + swizzle_mode = 2; |
| 42 | + } else if (contig_dim_size_in_bytes == 32) { |
| 43 | + swizzle_mode = 1; |
| 44 | + } else { |
| 45 | + op->emitError() |
| 46 | + << "contiguous box dimension must be at least 32 bytes but got " |
| 47 | + << contig_dim_size_in_bytes; |
| 48 | + return failure(); |
| 49 | + } |
| 50 | + |
| 51 | + Value elemSizeVal = builder.template create<arith::ConstantOp>( |
| 52 | + loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize)); |
| 53 | + Value globalStride = builder.template create<arith::MulIOp>( |
| 54 | + loc, op.getStrides()[0], elemSizeVal); |
| 55 | + // TODO: Workaround for ptxas bug, remove when we update ptxas |
| 56 | + Value four = builder.template create<arith::ConstantOp>( |
| 57 | + loc, builder.getI64Type(), builder.getI64IntegerAttr(4)); |
| 58 | + globalStride = |
| 59 | + builder.template create<arith::ShRSIOp>(loc, globalStride, four); |
| 60 | + |
| 61 | + int elemTypeEnum; |
| 62 | + switch (elemSize) { |
| 63 | + case 1: { |
| 64 | + elemTypeEnum = 0; |
| 65 | + break; |
| 66 | + } |
| 67 | + case 2: { |
| 68 | + elemTypeEnum = 1; |
| 69 | + break; |
| 70 | + } |
| 71 | + case 4: { |
| 72 | + elemTypeEnum = 2; |
| 73 | + break; |
| 74 | + } |
| 75 | + default: { |
| 76 | + op->emitError() |
| 77 | + << "Tensor descriptor element type must have size 1, 2, or 4 but got " |
| 78 | + << elemSize; |
| 79 | + return failure(); |
| 80 | + } |
| 81 | + } |
| 82 | + |
| 83 | + auto one = mkI32Constant(1); |
| 84 | + builder.template create<triton::ExperimentalTensormapCreateOp>( |
| 85 | + loc, |
| 86 | + /*desc_ptr=*/tmaPtr, |
| 87 | + /*global_address=*/op.getBase(), |
| 88 | + /*box_dim=*/boxDim, |
| 89 | + /*global_dim=*/ValueRange{op.getShape()[1], op.getShape()[0]}, |
| 90 | + /*global_stride=*/ValueRange{globalStride}, |
| 91 | + /*element_strides=*/ValueRange{one, one}, |
| 92 | + /*elem_type*/ builder.getI32IntegerAttr(elemTypeEnum), |
| 93 | + /*interleave_layout*/ builder.getI32IntegerAttr(0), |
| 94 | + /*swizzle_mode=*/builder.getI32IntegerAttr(swizzle_mode), |
| 95 | + /*fill_mode=*/builder.getI32IntegerAttr(0)); |
| 96 | + return success(); |
| 97 | +} |
| 98 | + |
| 99 | +} // namespace mlir::triton::nvidia_gpu |
0 commit comments