@@ -296,6 +296,9 @@ enum class TensorCoreType : uint8_t {
296296 FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X,
297297 FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X,
298298 //
299+ FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X,
300+ FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X,
301+ //
299302 NOT_APPLICABLE,
300303};
301304
@@ -339,6 +342,8 @@ static Type getMmaRetType(TensorCoreType mmaType, MLIRContext *ctx) {
339342 case TensorCoreType::FP32_FP8E5M2_FP8E4M3FN_FP32_SCALE_VEC_1X:
340343 case TensorCoreType::FP32_FP8E4M3FN_FP8E5M2_FP32_SCALE_VEC_1X:
341344 case TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X:
345+ case TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X:
346+ case TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X:
342347 return fp32x4Ty;
343348 default :
344349 llvm::report_fatal_error (" Unsupported mma type found" );
@@ -367,6 +372,15 @@ static TensorCoreType getMmaTypeDotScaled(DotScaledOp op, RankedTensorType aTy,
367372 llvm::isa<Float8E4M3FNType>(bTy.getElementType ())) {
368373 return TensorCoreType::FP32_FP8E4M3FN_FP8E4M3FN_FP32_SCALE_VEC_1X;
369374 }
375+ if (op.getBElemType () == ScaleDotElemType::E2M1 &&
376+ op.getAElemType () == ScaleDotElemType::E2M1) {
377+ if (isa<mlir::Float8E4M3FNType>(
378+ op.getBScale ().getType ().getElementType ())) {
379+ return TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X;
380+ } else {
381+ return TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X;
382+ }
383+ }
370384 }
371385 return TensorCoreType::NOT_APPLICABLE;
372386}
@@ -493,6 +507,14 @@ inline static const std::map<TensorCoreType, std::string> mmaInstrPtxScaled = {
493507 " mma.sync.aligned.m16n8k32.row.col."
494508 " kind::mxf8f6f4.block_scale.scale_vec::"
495509 " 1X.f32.e4m3.e4m3.f32.ue8m0" },
510+ {TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X,
511+ " mma.sync.aligned.m16n8k64.row.col."
512+ " kind::mxf4nvf4.block_scale.scale_vec::"
513+ " 2X.f32.e2m1.e2m1.f32.ue8m0" },
514+ {TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X,
515+ " mma.sync.aligned.m16n8k64.row.col."
516+ " kind::mxf4nvf4.block_scale.scale_vec::"
517+ " 4X.f32.e2m1.e2m1.f32.ue4m3" },
496518};
497519
498520static void callMmaTuringInt8 (PTXBuilder &builder, int b,
@@ -890,13 +912,12 @@ LogicalResult convertMMADotScaled(triton::DotScaledOp op,
890912 TensorCoreType mmaType =
891913 getMmaTypeDotScaled (op, aTensorTy, bTensorTy, dTensorTy);
892914
893- NumRegisters numRegisters = {2 , 1 , 2 };
894-
895915 SmallVector<Value> unpackedAScale =
896916 unpackLLElements (op.getLoc (), adaptor.getAScale (), rewriter);
897917 SmallVector<Value> unpackedBScale =
898918 unpackLLElements (op.getLoc (), adaptor.getBScale (), rewriter);
899919
920+ NumRegisters numRegisters = {2 , 1 , 2 };
900921 EmitMmaCallback emit = [&](PTXBuilder &builder, int b, int m, int n, int k,
901922 mlir::triton::PTXInstr &mma, unsigned numMmaRets,
902923 unsigned colsPerThread, unsigned batchOffset,
@@ -906,8 +927,34 @@ LogicalResult convertMMADotScaled(triton::DotScaledOp op,
906927 auto tb = TritonLLVMOpBuilder (op.getLoc (), rewriter);
907928 auto i32 = IntegerType::get (op->getContext (), 32 );
908929
909- Value aScaleValue = tb.zext (i32 , unpackedAScale[m * repK + k]);
910- Value bScaleValue = tb.zext (i32 , unpackedBScale[n * repK + k]);
// Packs `numBytes` consecutive entries of `bytes`, starting at index `loc`,
// into a single i32 value. Entry `loc + i` is zero-extended to i32 and
// shifted left by `i * 8` bits before being OR-ed into the result, so the
// entry at `loc` ends up in the lowest bits.
// NOTE(review): presumably each entry is an 8-bit scale value and
// numBytes <= 4, so that the shifted fields do not overlap or leave the
// 32-bit result — confirm against the callers.
// This lambda does no explicit range check on `loc`/`numBytes`; the caller
// is responsible for keeping `loc + numBytes - 1` a valid index.
930+ auto packElements = [&](ArrayRef<Value> bytes, int loc,
931+ int numBytes) -> Value {
// The first entry is always read, regardless of `numBytes`.
932+ Value packed = tb.zext (i32 , bytes[loc]);
933+ for (int i = 1 ; i < numBytes; ++i) {
934+ Value byte = tb.zext (i32 , bytes[loc + i]);
// Move entry i to bit offset 8 * i before merging it in.
935+ Value shifted = tb.shl (byte, tb.i32_val (i * 8 ));
936+ packed = tb.or_ (packed, shifted);
937+ }
938+ return packed;
939+ };
940+
941+ int scaleVecMode;
942+ if (mmaInstrPtxScaled.at (mmaType).find (" 1X" ) != std::string::npos) {
943+ scaleVecMode = 1 ;
944+ } else if (mmaType ==
945+ TensorCoreType::FP32_FP4E2M1_FP4E2M1_FP32_SCALE_VEC_2X) {
946+ scaleVecMode = 2 ;
947+ } else if (mmaType == TensorCoreType::FP32_NVFP4_NVFP4_FP32_SCALE_VEC_4X) {
948+ scaleVecMode = 4 ;
949+ } else {
950+ llvm_unreachable (" Unsupported scale vector mode!" );
951+ }
952+ Value aScaleValue =
953+ packElements (unpackedAScale, m * repK * scaleVecMode + k * scaleVecMode,
954+ scaleVecMode);
955+ Value bScaleValue =
956+ packElements (unpackedBScale, n * repK * scaleVecMode + k * scaleVecMode,
957+ scaleVecMode);
911958
912959 BaseOffset base{numRegisters.m * m, numRegisters.n * n, numRegisters.k * k};
913960 callMmaScaled (builder, b, base, mma, numMmaRets, colsPerThread, aTable,
0 commit comments