@@ -2131,6 +2131,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
21312131/// Generate the signature part of the mma intrinsic name.
21322132class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
21332133 list<WMMA_REGS> id_frags = !cond(
2134+ // FP8/F8F6F4 ops are identified by A,B inputs & accomulator & result type.
2135+ !or(!eq(A.ptx_elt_type, "e4m3"),
2136+ !eq(A.ptx_elt_type, "e5m2"),
2137+ !eq(A.ptx_elt_type, "e3m2"),
2138+ !eq(A.ptx_elt_type, "e2m3"),
2139+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
21342140 // FP16 ops are identified by accumulator & result type.
21352141 !eq(A.ptx_elt_type, "f16") : [D, C],
21362142 // other ops are identified by input types.
@@ -2257,6 +2263,31 @@ class NVVM_MMA_OPS {
22572263 list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
22582264 tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
22592265 fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
2266+
2267+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
2268+ [GEOM<16,8,16>, GEOM<16,8,32>],
2269+ ["bf16"], [], ["f32"], []>.ret;
2270+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
2271+ [GEOM<16,8,8>, GEOM<16,8,16>],
2272+ ["tf32"], [], ["f32"], []>.ret;
2273+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
2274+ [GEOM<16,8,16>, GEOM<16,8,32>],
2275+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
2276+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
2277+ [GEOM<16,8,64>],
2278+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
2279+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
2280+ ["f16", "f32"], ["f16", "f32"]>.ret;
2281+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
2282+ [GEOM<16,8,64>, GEOM<16,8,128>],
2283+ ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
2284+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
2285+ [GEOM<16,8,32>, GEOM<16,8,64>],
2286+ ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
2287+ list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
2288+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
2289+ subint_mma_sp_ops, int_mma_sp_ops);
2290+
22602291}
22612292
22622293def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -2362,6 +2393,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
23622393def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
23632394 let assemblyFormat = "`<` $value `>`";
23642395}
2396+ /// MMA kind types (for mixed-precision FP8 operations)
2397+ def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
2398+ def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
2399+ [MMAKindF8F6F4]> {
2400+ let genSpecializedAttr = 0;
2401+ let cppNamespace = "::mlir::NVVM";
2402+ }
2403+ def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
2404+ let assemblyFormat = "`<` $value `>`";
2405+ }
23652406
23662407/// Attribute to hold the MMA shape
23672408def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
@@ -2506,12 +2547,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
25062547def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
25072548def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
25082549def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
2550+ def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
2551+ def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
2552+ def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
2553+ def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
2554+ def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
25092555
25102556def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
25112557 [MMATypeF16, MMATypeF32, MMATypeTF32,
25122558 MMATypeBF16, MMATypeS8, MMATypeU8,
25132559 MMATypeS32, MMATypeS4, MMATypeU4,
2514- MMATypeB1, MMATypeF64]> {
2560+ MMATypeB1, MMATypeF64,
2561+ MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
25152562 let genSpecializedAttr = 0;
25162563 let cppNamespace = "::mlir::NVVM";
25172564}
@@ -2948,6 +2995,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
29482995 let hasVerifier = 1;
29492996}
29502997
2998+ /// Generate enum value of the mma.sync intrinsic.
2999+ class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
3000+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
3001+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
3002+ string id = "llvm::Intrinsic::nvvm_mma"
3003+ # "_" # !subst("::", "_", Metadata)
3004+ # "_" # A.geom
3005+ # "_row_col"
3006+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
3007+ # !if(Satfinite, "_satfinite", "")
3008+ # signature;
3009+ }
3010+
3011+ // Returns true if this combination of layout/kind/satf for MMA.SP ops is supported;
3012+ // false otherwise.
3013+ // E.g.
3014+ // if NVVM_MMA_SP_SUPPORTED<...>.ret then
3015+ // def : FOO<>; // The record will only be defined for supported ops.
3016+ //
3017+ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
3018+ string kind, int satf> {
3019+ // MMA.SP ops check both layouts.
3020+ string a_type = frags[0].ptx_elt_type;
3021+ string b_type = frags[1].ptx_elt_type;
3022+ string c_type = frags[2].ptx_elt_type;
3023+ string d_type = frags[3].ptx_elt_type;
3024+ string geom = frags[0].geom;
3025+
3026+ bit is_int = !or(!eq(a_type, "s8"),
3027+ !eq(a_type, "u8"),
3028+ !eq(a_type, "s4"),
3029+ !eq(a_type, "u4"));
3030+
3031+ bit ret = !cond(
3032+
3033+ // Limit satf to valid types
3034+ !and(!eq(satf, 1),
3035+ !eq(is_int, 0)): false,
3036+
3037+ // f16/bf16/tf32 requires A and B to be the same type.
3038+ !and(!or(!eq(a_type, "f16"),
3039+ !eq(a_type, "bf16"),
3040+ !eq(a_type, "tf32")),
3041+ !ne(a_type, b_type)): false,
3042+
3043+ // m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
3044+ !and(!or(!eq(geom, "m16n8k16"),
3045+ !eq(geom, "m16n8k32"),
3046+ !eq(geom, "m16n8k64")),
3047+ !ne(c_type, d_type)): false,
3048+
3049+ !and(!eq(kind, ""),
3050+ !or(!eq(a_type, "e3m2"),
3051+ !eq(a_type, "e2m3"),
3052+ !eq(a_type, "e2m1"),
3053+ !eq(b_type, "e3m2"),
3054+ !eq(b_type, "e2m3"),
3055+ !eq(b_type, "e2m1"))): false,
3056+
3057+ !and(!eq(kind, ""),
3058+ !eq(geom, "m16n8k64"),
3059+ !or(!eq(c_type, "f16"),
3060+ !eq(d_type, "f16"))): false,
3061+
3062+ !and(!ne(kind, ""),
3063+ !or(!eq(metadata, "sp"),
3064+ !ne(geom, "m16n8k64"),
3065+ !eq(is_int, 1))): false,
3066+
3067+ // All other are OK.
3068+ true: true
3069+ );
3070+ }
3071+
3072+ /// Helper to create the mapping between the configuration and the mma.sp.sync
3073+ /// intrinsic enum value.
3074+ class MMA_SP_SYNC_INTR {
3075+ list<list<list<list<string>>>> cond0 =
3076+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
3077+ !foreach(metadata, ["sp", "sp::ordered_metadata"],
3078+ !foreach(kind, ["", "kind::f8f6f4"],
3079+ !foreach (satf, [0, 1],
3080+ !if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
3081+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
3082+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
3083+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
3084+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
3085+ # " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
3086+ # " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
3087+ # " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata") # ")\n"
3088+ # " return " #
3089+ MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
3090+ "") // if supported
3091+ ) // satf
3092+ ) // kind
3093+ ) // metadata
3094+ ); // all_mma_sp_sync_ops
3095+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
3096+ !listconcat(acc, el));
3097+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
3098+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
3099+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
3100+ }
3101+
3102+ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
3103+
3104+ let summary = "cooperative sparse matrix-multiply and accumulate";
3105+
3106+ let description = [{
3107+ The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
3108+ `D = matmul(A_sparse, B) + C` using all threads in a warp.
3109+
3110+ This operation is similar to `nvvm.mma.sync` but with structured sparsity
3111+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
3112+ where 2 out of every 4 elements are non-zero.
3113+
3114+ All the threads in the warp must execute the same `mma.sp.sync` operation.
3115+
3116+ The `sparseMetadata` operand provides the sparsity indices that indicate
3117+ which elements in the A operand are non-zero. The `sparsitySelector`
3118+ controls how the indices are distributed among threads in the warp and
3119+ should typically be 0 or 1.
3120+
3121+ The optional `orderedMetadata` attribute specifies the metadata ordering:
3122+ - Absence (default): Uses standard sparse metadata ordering
3123+ - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
3124+
3125+ The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
3126+ - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
3127+ - Only valid with ordered metadata and m16n8k64 shape
3128+
3129+ The shapes, layouts, and data types follow the same constraints as the
3130+ regular `nvvm.mma.sync` operation, but the A operand contains only the
3131+ non-zero elements in compressed format.
3132+
3133+ Example:
3134+ ```mlir
3135+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
3136+ sparseMetadata[%meta] selector[%sel]
3137+ {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
3138+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
3139+
3140+ // With ordered metadata:
3141+ %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
3142+ sparseMetadata[%meta] selector[%sel]
3143+ {orderedMetadata, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
3144+ : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
3145+ ```
3146+ }];
3147+
3148+ let results = (outs LLVM_AnyStruct:$res);
3149+ let arguments = (ins NVVM_MMAShapeAttr:$shape,
3150+ OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
3151+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
3152+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
3153+ UnitAttr:$orderedMetadata,
3154+ OptionalAttr<MMAKindAttr>:$kind,
3155+ Variadic<LLVM_Type>:$operandA,
3156+ Variadic<LLVM_Type>:$operandB,
3157+ Variadic<LLVM_Type>:$operandC,
3158+ I32:$sparseMetadata,
3159+ I32:$sparsitySelector);
3160+
3161+ let extraClassDeclaration = !strconcat([{
3162+ static llvm::Intrinsic::ID getIntrinsicID(
3163+ int64_t m, int64_t n, uint64_t k,
3164+ std::optional<MMAIntOverflow> satf,
3165+ bool orderedMetadata,
3166+ std::optional<MMAKind> kind,
3167+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
3168+ mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
3169+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
3170+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
3171+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
3172+ llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
3173+ }],
3174+ MMA_SP_SYNC_INTR<>.id, [{
3175+ return 0;
3176+ }
3177+
3178+ static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
3179+ bool isAccumulator);
3180+
3181+ MMATypes accumPtxType();
3182+ MMATypes resultPtxType();
3183+
3184+ static mlir::NVVM::IDArgPair
3185+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3186+ llvm::IRBuilderBase& builder);
3187+ }]);
3188+
3189+ let builders = [
3190+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
3191+ "ValueRange":$operandB, "ValueRange":$operandC,
3192+ "Value":$sparseMetadata, "Value":$sparsitySelector,
3193+ "ArrayRef<int64_t>":$shape,
3194+ "std::optional<MMAIntOverflow>":$intOverflow,
3195+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
3196+ ];
3197+
3198+ string llvmBuilder = [{
3199+ auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
3200+ *op, moduleTranslation, builder);
3201+ $res = createIntrinsicCall(builder, id, args);
3202+ }];
3203+
3204+ let hasCustomAssemblyFormat = 1;
3205+ let hasVerifier = 1;
3206+ }
3207+
29513208//===----------------------------------------------------------------------===//
29523209// NVVM TMA Ops
29533210//===----------------------------------------------------------------------===//
0 commit comments