Skip to content

Commit a7c8505

Browse files
authored
[MLIR] Supported sparse MMA intrinsics in the MLIR->NVVM IR->NVPTX flow (#168686)
This change adds sparse MMA intrinsics to the MLIR -> NVVM IR -> NVPTX flow. NVVM and NVPTX implementation is based on PTX ISA 9.0.
1 parent 8bfca26 commit a7c8505

File tree

5 files changed

+1756
-1
lines changed

5 files changed

+1756
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 258 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
21312131
/// Generate the signature part of the mma intrinsic name.
21322132
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
21332133
list<WMMA_REGS> id_frags = !cond(
2134+
// FP8/F8F6F4 ops are identified by A,B inputs & accomulator & result type.
2135+
!or(!eq(A.ptx_elt_type, "e4m3"),
2136+
!eq(A.ptx_elt_type, "e5m2"),
2137+
!eq(A.ptx_elt_type, "e3m2"),
2138+
!eq(A.ptx_elt_type, "e2m3"),
2139+
!eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
21342140
// FP16 ops are identified by accumulator & result type.
21352141
!eq(A.ptx_elt_type, "f16") : [D, C],
21362142
// other ops are identified by input types.
@@ -2257,6 +2263,31 @@ class NVVM_MMA_OPS {
22572263
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
22582264
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
22592265
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
2266+
2267+
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
2268+
[GEOM<16,8,16>, GEOM<16,8,32>],
2269+
["bf16"], [], ["f32"], []>.ret;
2270+
list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
2271+
[GEOM<16,8,8>, GEOM<16,8,16>],
2272+
["tf32"], [], ["f32"], []>.ret;
2273+
list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
2274+
[GEOM<16,8,16>, GEOM<16,8,32>],
2275+
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
2276+
list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
2277+
[GEOM<16,8,64>],
2278+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
2279+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
2280+
["f16", "f32"], ["f16", "f32"]>.ret;
2281+
list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
2282+
[GEOM<16,8,64>, GEOM<16,8,128>],
2283+
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
2284+
list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
2285+
[GEOM<16,8,32>, GEOM<16,8,64>],
2286+
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
2287+
list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
2288+
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
2289+
subint_mma_sp_ops, int_mma_sp_ops);
2290+
22602291
}
22612292

22622293
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -2362,6 +2393,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
23622393
def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
23632394
let assemblyFormat = "`<` $value `>`";
23642395
}
2396+
/// MMA kind types (for mixed-precision FP8 operations)
2397+
def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
2398+
def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
2399+
[MMAKindF8F6F4]> {
2400+
let genSpecializedAttr = 0;
2401+
let cppNamespace = "::mlir::NVVM";
2402+
}
2403+
def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
2404+
let assemblyFormat = "`<` $value `>`";
2405+
}
23652406

23662407
/// Attribute to hold the MMA shape
23672408
def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
@@ -2506,12 +2547,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
25062547
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
25072548
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
25082549
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
2550+
def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
2551+
def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
2552+
def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
2553+
def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
2554+
def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;
25092555

25102556
def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
25112557
[MMATypeF16, MMATypeF32, MMATypeTF32,
25122558
MMATypeBF16, MMATypeS8, MMATypeU8,
25132559
MMATypeS32, MMATypeS4, MMATypeU4,
2514-
MMATypeB1, MMATypeF64]> {
2560+
MMATypeB1, MMATypeF64,
2561+
MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
25152562
let genSpecializedAttr = 0;
25162563
let cppNamespace = "::mlir::NVVM";
25172564
}
@@ -2948,6 +2995,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
29482995
let hasVerifier = 1;
29492996
}
29502997

2998+
/// Generate enum value of the mma.sync intrinsic.
2999+
class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
3000+
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
3001+
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
3002+
string id = "llvm::Intrinsic::nvvm_mma"
3003+
# "_" # !subst("::", "_", Metadata)
3004+
# "_" # A.geom
3005+
# "_row_col"
3006+
# !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
3007+
# !if(Satfinite, "_satfinite", "")
3008+
# signature;
3009+
}
3010+
3011+
// Returns true if this combination of layout/kind/satf for MMA.SP ops is supported;
3012+
// false otherwise.
3013+
// E.g.
3014+
// if NVVM_MMA_SP_SUPPORTED<...>.ret then
3015+
// def : FOO<>; // The record will only be defined for supported ops.
3016+
//
3017+
class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
3018+
string kind, int satf> {
3019+
// MMA.SP ops check both layouts.
3020+
string a_type = frags[0].ptx_elt_type;
3021+
string b_type = frags[1].ptx_elt_type;
3022+
string c_type = frags[2].ptx_elt_type;
3023+
string d_type = frags[3].ptx_elt_type;
3024+
string geom = frags[0].geom;
3025+
3026+
bit is_int = !or(!eq(a_type, "s8"),
3027+
!eq(a_type, "u8"),
3028+
!eq(a_type, "s4"),
3029+
!eq(a_type, "u4"));
3030+
3031+
bit ret = !cond(
3032+
3033+
// Limit satf to valid types
3034+
!and(!eq(satf, 1),
3035+
!eq(is_int, 0)): false,
3036+
3037+
// f16/bf16/tf32 requires A and B to be the same type.
3038+
!and(!or(!eq(a_type, "f16"),
3039+
!eq(a_type, "bf16"),
3040+
!eq(a_type, "tf32")),
3041+
!ne(a_type, b_type)): false,
3042+
3043+
// m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
3044+
!and(!or(!eq(geom, "m16n8k16"),
3045+
!eq(geom, "m16n8k32"),
3046+
!eq(geom, "m16n8k64")),
3047+
!ne(c_type, d_type)): false,
3048+
3049+
!and(!eq(kind, ""),
3050+
!or(!eq(a_type, "e3m2"),
3051+
!eq(a_type, "e2m3"),
3052+
!eq(a_type, "e2m1"),
3053+
!eq(b_type, "e3m2"),
3054+
!eq(b_type, "e2m3"),
3055+
!eq(b_type, "e2m1"))): false,
3056+
3057+
!and(!eq(kind, ""),
3058+
!eq(geom, "m16n8k64"),
3059+
!or(!eq(c_type, "f16"),
3060+
!eq(d_type, "f16"))): false,
3061+
3062+
!and(!ne(kind, ""),
3063+
!or(!eq(metadata, "sp"),
3064+
!ne(geom, "m16n8k64"),
3065+
!eq(is_int, 1))): false,
3066+
3067+
// All other are OK.
3068+
true: true
3069+
);
3070+
}
3071+
3072+
/// Helper to create the mapping between the configuration and the mma.sp.sync
3073+
/// intrinsic enum value.
3074+
class MMA_SP_SYNC_INTR {
3075+
list<list<list<list<string>>>> cond0 =
3076+
!foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
3077+
!foreach(metadata, ["sp", "sp::ordered_metadata"],
3078+
!foreach(kind, ["", "kind::f8f6f4"],
3079+
!foreach (satf, [0, 1],
3080+
!if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
3081+
"if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
3082+
# " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
3083+
# " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
3084+
# " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
3085+
# " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
3086+
# " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
3087+
# " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata") # ")\n"
3088+
# " return " #
3089+
MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
3090+
"") // if supported
3091+
) // satf
3092+
) // kind
3093+
) // metadata
3094+
); // all_mma_sp_sync_ops
3095+
list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
3096+
!listconcat(acc, el));
3097+
list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
3098+
list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
3099+
string id = !foldl("", f3, acc, el, acc # "\n" # el);
3100+
}
3101+
3102+
def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
3103+
3104+
let summary = "cooperative sparse matrix-multiply and accumulate";
3105+
3106+
let description = [{
3107+
The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
3108+
`D = matmul(A_sparse, B) + C` using all threads in a warp.
3109+
3110+
This operation is similar to `nvvm.mma.sync` but with structured sparsity
3111+
in the A operand. The sparsity follows the 2:4 structured sparse pattern
3112+
where 2 out of every 4 elements are non-zero.
3113+
3114+
All the threads in the warp must execute the same `mma.sp.sync` operation.
3115+
3116+
The `sparseMetadata` operand provides the sparsity indices that indicate
3117+
which elements in the A operand are non-zero. The `sparsitySelector`
3118+
controls how the indices are distributed among threads in the warp and
3119+
should typically be 0 or 1.
3120+
3121+
The optional `orderedMetadata` attribute specifies the metadata ordering:
3122+
- Absence (default): Uses standard sparse metadata ordering
3123+
- Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
3124+
3125+
The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
3126+
- `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
3127+
- Only valid with ordered metadata and m16n8k64 shape
3128+
3129+
The shapes, layouts, and data types follow the same constraints as the
3130+
regular `nvvm.mma.sync` operation, but the A operand contains only the
3131+
non-zero elements in compressed format.
3132+
3133+
Example:
3134+
```mlir
3135+
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
3136+
sparseMetadata[%meta] selector[%sel]
3137+
{shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
3138+
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
3139+
3140+
// With ordered metadata:
3141+
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
3142+
sparseMetadata[%meta] selector[%sel]
3143+
{orderedMetadata, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
3144+
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
3145+
```
3146+
}];
3147+
3148+
let results = (outs LLVM_AnyStruct:$res);
3149+
let arguments = (ins NVVM_MMAShapeAttr:$shape,
3150+
OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
3151+
OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
3152+
OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
3153+
UnitAttr:$orderedMetadata,
3154+
OptionalAttr<MMAKindAttr>:$kind,
3155+
Variadic<LLVM_Type>:$operandA,
3156+
Variadic<LLVM_Type>:$operandB,
3157+
Variadic<LLVM_Type>:$operandC,
3158+
I32:$sparseMetadata,
3159+
I32:$sparsitySelector);
3160+
3161+
let extraClassDeclaration = !strconcat([{
3162+
static llvm::Intrinsic::ID getIntrinsicID(
3163+
int64_t m, int64_t n, uint64_t k,
3164+
std::optional<MMAIntOverflow> satf,
3165+
bool orderedMetadata,
3166+
std::optional<MMAKind> kind,
3167+
mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
3168+
mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
3169+
llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
3170+
llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
3171+
llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
3172+
llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
3173+
}],
3174+
MMA_SP_SYNC_INTR<>.id, [{
3175+
return 0;
3176+
}
3177+
3178+
static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
3179+
bool isAccumulator);
3180+
3181+
MMATypes accumPtxType();
3182+
MMATypes resultPtxType();
3183+
3184+
static mlir::NVVM::IDArgPair
3185+
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
3186+
llvm::IRBuilderBase& builder);
3187+
}]);
3188+
3189+
let builders = [
3190+
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
3191+
"ValueRange":$operandB, "ValueRange":$operandC,
3192+
"Value":$sparseMetadata, "Value":$sparsitySelector,
3193+
"ArrayRef<int64_t>":$shape,
3194+
"std::optional<MMAIntOverflow>":$intOverflow,
3195+
"std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
3196+
];
3197+
3198+
string llvmBuilder = [{
3199+
auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
3200+
*op, moduleTranslation, builder);
3201+
$res = createIntrinsicCall(builder, id, args);
3202+
}];
3203+
3204+
let hasCustomAssemblyFormat = 1;
3205+
let hasVerifier = 1;
3206+
}
3207+
29513208
//===----------------------------------------------------------------------===//
29523209
// NVVM TMA Ops
29533210
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)