+#include "mlir/IR/BuiltinTypes.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
+#include "llvm/Support/raw_ostream.h"
+
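+// Pull in the TableGen-generated op class definitions; the hand-written
+// hooks below (UpcastMXFPOp::verify and UpcastMXFPOp::inferReturnTypes)
+// supply the methods those generated declarations leave to the implementation.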
+#define GET_OP_CLASSES
+#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
+
+namespace mlir::triton::gpu {
+
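+// A sketch of an operand pair this verifier accepts (illustrative only; the
+// exact assembly syntax and layout attribute spellings are elided):
+//
+//   values: tensor<32x64xbf16, #dot_operand>  with fp_type = E5M2
+//   scales: tensor<32x2xi8, #blocked>         (one scale per 32 values)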
+LogicalResult UpcastMXFPOp::verify() {
+  auto fpType = getFpType();
+
+  auto xTy = getSrc().getType();
+  auto scaleTy = getScale().getType();
+
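+  // The first operand carries the (packed) mxfp values and the second the
+  // per-block scales: in the MX formats, one 8-bit scale covers a block of
+  // 32 values.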
+  if (xTy.getElementType() != FloatType::getBF16(getContext())) {
+    return emitOpError("element type of the first operand must be bf16");
+  }
+
+  if (scaleTy.getElementType() != IntegerType::get(getContext(), 8)) {
+    return emitOpError("element type of the second operand must be uint8");
+  }
+
+  auto xShape = xTy.getShape();
+  auto scaleShape = scaleTy.getShape();
+
+  if (xShape.size() != scaleShape.size() || xShape.size() < 2) {
+    return emitOpError(
+        "operands must have the same rank, which must be at least 2");
+  }
+
+  if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
+        fpType == F8F6F4Type::E5M2)) {
+    return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
+  }
+
+  // E2M1 packs two fp4 values per element; the fp8 types hold one value per
+  // element.
+  const auto elemsPacked = fpType == F8F6F4Type::E2M1 ? 2 : 1;
+
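+  // E.g. for E2M1 a trailing dimension of 32 holds 16 packed elements
+  // (32 fp4 values) per scale, so it pairs with a scale trailing dimension
+  // of 2; for the fp8 types the expected ratio is 32.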
+  if (xShape.back() != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("last dimension of the first operand must be ")
+           << (32 / elemsPacked) << " times that of the second operand";
+  }
+
+  if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
+    return emitOpError(
+        "all dimensions except the last must match between operands");
+  }
+
+  auto layoutX = xTy.getEncoding();
+  if (!isa_and_nonnull<DotOperandEncodingAttr>(layoutX)) {
+    return emitOpError("expected a DotOperandEncodingAttr for values");
+  }
+  auto layoutScale = scaleTy.getEncoding();
+  if (!isa_and_nonnull<BlockedEncodingAttr>(layoutScale)) {
+    return emitOpError("expected a BlockedEncodingAttr for scales");
+  }
+  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
+
+  // Necessary to keep all of the scales of a given block of values in the
+  // same warp.
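+  // threadsPerWarp = {16, 2} lays a warp's 32 threads out as 16 rows by 2
+  // columns over the scale tensor, so a block's scale never straddles warps.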
+  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+    return emitOpError("expected threads per warp to be {16, 2}");
+  }
+
+  return success();
+}
+
+LogicalResult UpcastMXFPOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, OpaqueProperties opaqueProperties,
+    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto xTy = cast<RankedTensorType>(operands[0].getType());
+  auto properties = opaqueProperties.as<const Properties *>();
+  auto typeEncoded = properties->fp_type.getValue();
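+  // fp_type is read through the properties storage rather than the attribute
+  // dictionary because inferReturnTypes runs before the op object itself is
+  // constructed.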
+  auto xShape = xTy.getShape();
+
+  auto encoding = xTy.getEncoding();
+  if (!encoding) {
+    return emitOptionalError(location, "expected an encoding");
+  }
+  if (!mlir::isa<DotOperandEncodingAttr>(encoding)) {
+    return emitOptionalError(location, "expected a dot operand encoding");
+  }
+  if (xShape.size() < 2) {
+    return emitOptionalError(location, "tensor rank must be at least 2");
+  }
+
+  // For now we just return the input encoding. For fp4 we will need to cast
+  // from the tf32 to the fp16 encoding and double the last dimension of the
+  // shape.
+  assert((typeEncoded == F8F6F4Type::E4M3 || typeEncoded == F8F6F4Type::E5M2) &&
+         "NYI: only fp8e4m3 and fp8e5m2 are supported");
+
+  inferredReturnTypes.push_back(xTy);
+  return success();
+}
+
+} // namespace mlir::triton::gpu