16 changes: 7 additions & 9 deletions include/triton/Conversion/MLIRTypes.h
@@ -26,17 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
-inline bool isFloat(Type type) {
-  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
 }
 
-inline bool isFloat8(Type type) {
-  return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
-         type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
-         type.isFloat8E5M2FNUZ();
+inline bool isFloat(Type type) {
+  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
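
The rewrite applied throughout these hunks is mechanical: the deprecated per-format predicates on mlir::Type (type.isFloat8E5M2(), type.isFloat8E4M3FN(), ...) are replaced with a single variadic llvm::isa<> over the builtin FP8 type classes. A minimal standalone sketch of the before/after pattern; the isAnyFp8* helper names are illustrative and do not appear in this PR, and the include paths are the assumed MLIR ones:

#include "mlir/IR/BuiltinTypes.h" // assumed home of the Float8*Type classes
#include "mlir/IR/Types.h"

// Old style: one deprecated member-function call per FP8 flavor.
static bool isAnyFp8Old(mlir::Type ty) {
  return ty.isFloat8E5M2() || ty.isFloat8E4M3FN() ||
         ty.isFloat8E5M2FNUZ() || ty.isFloat8E4M3FNUZ();
}

// New style: a single variadic llvm::isa<> over the type classes, which is
// what every call site below is switched to.
static bool isAnyFp8New(mlir::Type ty) {
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType,
                   mlir::Float8E5M2FNUZType, mlir::Float8E4M3FNUZType>(ty);
}
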
8 changes: 4 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -756,14 +756,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN() ||
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
            aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
            aElemTy.isF32()))) {
       return false;
     }
     // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
     if (op.getMaxNumImpreciseAcc() < 32 &&
-        (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FN()) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
         cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
       return false;
     }
@@ -784,8 +784,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
-               elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -632,7 +632,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||
6 changes: 3 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -45,9 +45,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
     SmallVector<unsigned> validN;
 
     // MMAv3 with larger instruction shape is preferred.
-    if (eltType.isFloat8E5M2() || eltType.isFloat8E4M3FN() ||
-        eltType.isFloat8E4M3FNUZ() || eltType.isF16() || eltType.isBF16() ||
-        eltType.isF32()) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+            eltType) ||
+        eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
       validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                      168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                      80, 72, 64, 56, 48, 40, 32, 24, 16, 8});
4 changes: 2 additions & 2 deletions lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = aElTy.isFloat8E5M2() || aElTy.isFloat8E4M3FN() ||
-               aElTy.isFloat8E5M2FNUZ() || aElTy.isFloat8E4M3FNUZ();
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();
15 changes: 7 additions & 8 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1019,17 +1019,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (srcElementType.isFloat8E4M3FN() || dstElementType.isFloat8E4M3FN() ||
-        srcElementType.isFloat8E4M3FNUZ() ||
-        dstElementType.isFloat8E4M3FNUZ() ||
-        srcElementType.isFloat8E5M2FNUZ() ||
-        dstElementType.isFloat8E5M2FNUZ()) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
-        srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (dstElementType.isFloat8E4M3FNUZ() ||
-                                     dstElementType.isFloat8E5M2FNUZ()));
+        srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA3 &&
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;
@@ -416,7 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ();
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),
15 changes: 10 additions & 5 deletions third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp
@@ -20,19 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
    return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (dataTypeA.isFloat8E5M2() && dataTypeB.isFloat8E5M2()) {
+  if (llvm::isa<Float8E5M2Type>(dataTypeA) &&
+      llvm::isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");
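
The MfmaGroup.cpp hunk above keeps one if per (A, B) combination. Assuming both operands are already known to be Float8E4M3FNUZType ("fp8") or Float8E5M2FNUZType ("bf8"), the same table can be read as a two-way dispatch per operand; a compact equivalent, shown purely as an illustrative sketch (chooseFp8MfmaId is a hypothetical helper, the MfmaTypeId values are taken from the diff):

// Illustrative only: assumes a and b are each Float8E4M3FNUZType ("fp8") or
// Float8E5M2FNUZType ("bf8"), i.e. one of the four cases handled above.
static MfmaTypeId chooseFp8MfmaId(mlir::Type a, mlir::Type b) {
  bool aIsFp8 = llvm::isa<mlir::Float8E4M3FNUZType>(a); // otherwise bf8
  bool bIsFp8 = llvm::isa<mlir::Float8E4M3FNUZType>(b);
  if (aIsFp8)
    return bIsFp8 ? MfmaTypeId::Fp8Fp8TyId : MfmaTypeId::Fp8Bf8TyId;
  return bIsFp8 ? MfmaTypeId::Bf8Fp8TyId : MfmaTypeId::Bf8Bf8TyId;
}
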
24 changes: 10 additions & 14 deletions third_party/intel/lib/Analysis/DPAS.cpp
@@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) {
       if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
         return DPASEngineType::FP32_FP32_TF32_TF32;
       // For FP8XFP8->FP32, upcast to FP16
-      if (aElemTy.isFloat8E5M2())
+      if (isa<Float8E5M2Type>(aElemTy))
        return DPASEngineType::FP32_FP32_FP16_FP16;
-      if (aElemTy.isFloat8E4M3FN())
+      if (isa<Float8E4M3FNType>(aElemTy))
        return DPASEngineType::FP32_FP32_FP16_FP16;
    } else if (dElemTy.isF16()) {
      if (aElemTy.isF16())
@@ -147,36 +147,32 @@
 
   if (isa<FloatType>(dElemTy)) {
     if (dElemTy.isF32()) {
-      if (aElemTy.isBF16() &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (aElemTy.isBF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_BF16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isBF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_BF16_FP4;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          bElemTy.isBF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP8_BF16;
-      if (aElemTy.isF16() &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (aElemTy.isF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP16_FP4;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          bElemTy.isF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isF16())
         return DPASEngineType::FP32_FP32_FP8_FP16;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP8_FP8;
-      if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
           bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP8_FP4;
       if (aElemTy.isInteger(8) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP4_BF16;
       if (aElemTy.isInteger(8) && bElemTy.isF16())
         return DPASEngineType::FP32_FP32_FP4_FP16;
       if (aElemTy.isInteger(8) &&
-          (bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP4_FP8;
     }
   }
@@ -405,7 +405,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
   assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
 
   unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
-  if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(elemType))
     dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
 
   return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
@@ -960,7 +960,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for conversions to fp8");
 
@@ -132,8 +132,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
         oldAType.getElementType().getIntOrFloatBitWidth();
 
     // We are upcasting FP8 to FP16
-    if (oldAType.getElementType().isFloat8E5M2() ||
-        oldAType.getElementType().isFloat8E4M3FN())
+    if (isa<Float8E5M2Type, Float8E4M3FNType>(oldAType.getElementType()))
       dpasElemBitWidths = 2 * dpasElemBitWidths;
 
     // Enlarge the repCluster size to use the large 2D load for A and B
@@ -488,7 +487,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
 
     Type promoteType;
     if (dpasLayout) {
-      bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
+      bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // fp8 is not natively supported by the the DPAS instruction, promote it
       // to fp16.
       if (!isNativeFP8)
@@ -299,17 +299,17 @@ TensorCoreType getMmaType(triton::DotOp op) {
       return TensorCoreType::FP32_FP16_FP16_FP32;
     if (aTy.getElementType().isBF16() && bTy.getElementType().isBF16())
       return TensorCoreType::FP32_BF16_BF16_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E5M2_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E5M2() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E5M2Type>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
      return TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E5M2())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E5M2Type>(bTy.getElementType()))
      return TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32;
-    if (aTy.getElementType().isFloat8E4M3FN() &&
-        bTy.getElementType().isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(aTy.getElementType()) &&
+        llvm::isa<Float8E4M3FNType>(bTy.getElementType()))
       return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32;
     if (aTy.getElementType().isF32() && bTy.getElementType().isF32() &&
         op.getInputPrecision() == InputPrecision::TF32)
@@ -60,7 +60,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 };
 inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB,
                             Type scaleAType, Type scaleBType) {
   if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) {
-    if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E4M3FNType>(scaleAType) &&
+        llvm::isa<Float8E4M3FNType>(scaleBType)) {
       return mxfpKind::mxf4nvf4;
     }
     return mxfpKind::mxf4;
@@ -100,9 +101,9 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
       return 1;
     if (type.isF32())
       return 2;
-    if (type.isFloat8E4M3FN())
+    if (llvm::isa<Float8E4M3FNType>(type))
       return 0;
-    if (type.isFloat8E5M2())
+    if (llvm::isa<Float8E5M2Type>(type))
      return 1;
    llvm_unreachable("Unsupported type.");
  };
@@ -224,7 +225,7 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
     opcode += "f16";
   else if (srcElementTy.isF32())
     opcode += "tf32";
-  else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2())
+  else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
     opcode += "f8f6f4";
   else
     assert(0 && "Unsupported type.");
@@ -59,9 +59,9 @@ triton::nvgpu::WGMMAEltType getMmaOperandType(Value a, bool allowTF32) {
     return triton::nvgpu::WGMMAEltType::tf32;
   } else if (aTy.isInteger(8)) {
     return triton::nvgpu::WGMMAEltType::s8;
-  } else if (aTy.isFloat8E5M2()) {
+  } else if (llvm::isa<Float8E5M2Type>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e5m2;
-  } else if (aTy.isFloat8E4M3FN()) {
+  } else if (llvm::isa<Float8E4M3FNType>(aTy)) {
     return triton::nvgpu::WGMMAEltType::e4m3;
   } else {
     llvm::report_fatal_error("Unsupported mma operand type found");
@@ -466,8 +466,8 @@ struct FpToFpOpConversion
       llvm::errs() << "\n";
       llvm::report_fatal_error("Unsupported rounding mode for conversion.");
     }
-    if (computeCapability < 89 &&
-        (srcTy.isFloat8E4M3FN() || dstTy.isFloat8E4M3FN())) {
+    if (computeCapability < 89 && (llvm::isa<Float8E4M3FNType>(srcTy) ||
+                                   llvm::isa<Float8E4M3FNType>(dstTy))) {
       llvm::errs() << "Conversion from/to f8e4m3nv is only supported on "
                       "compute capability >= 89"
                    << "\n";
@@ -489,7 +489,7 @@ struct FpToFpOpConversion
     auto dstElementType = getElementType(op.getResult());
     auto roundingMode = op.getRounding();
 
-    if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
+    if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
       assert(roundingMode.has_value() &&
              "Rounding mode must be specified for convertsions to fp8");
 
@@ -526,8 +526,8 @@
 
     bool useFP16IntermediateSrc =
         srcElementType.isF32() &&
-        (!(computeCapability >= 90 && (dstElementType.isFloat8E4M3FN() ||
-                                       dstElementType.isFloat8E5M2())) ||
+        (!(computeCapability >= 90 &&
+           (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(dstElementType))) ||
          roundingMode.value() == RoundingMode::RTZ);
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;