Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 258 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
/// Generate the signature part of the mma intrinsic name.
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
// FP8/F8F6F4 ops are identified by A,B inputs & accomulator & result type.
!or(!eq(A.ptx_elt_type, "e4m3"),
!eq(A.ptx_elt_type, "e5m2"),
!eq(A.ptx_elt_type, "e3m2"),
!eq(A.ptx_elt_type, "e2m3"),
!eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
Expand Down Expand Up @@ -2257,6 +2263,31 @@ class NVVM_MMA_OPS {
list<list<WMMA_REGS>> all_mma_sync_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);

// Sparse (mma.sp) operand-fragment lists, mirroring the dense lists above.
// Each MMA_OPS<geoms, A-types, B-types, C-types, D-types> invocation expands
// to the register-fragment combinations for those geometries/types.
// NOTE(review): assumes MMA_OPS treats an empty B/D type list as "derived
// from the A/C list" — confirm against the MMA_OPS class definition.
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
[GEOM<16,8,16>, GEOM<16,8,32>],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
[GEOM<16,8,8>, GEOM<16,8,16>],
["tf32"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
[GEOM<16,8,16>, GEOM<16,8,32>],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
// FP8/FP6/FP4 element types; invalid combinations are filtered later by
// NVVM_MMA_SP_SUPPORTED (e.g. e3m2/e2m3/e2m1 require kind::f8f6f4).
list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
[GEOM<16,8,64>],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
[GEOM<16,8,64>, GEOM<16,8,128>],
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
[GEOM<16,8,32>, GEOM<16,8,64>],
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
// Union of all sparse op descriptors, consumed by MMA_SP_SYNC_INTR below.
list<list<WMMA_REGS>> all_mma_sp_sync_ops = !listconcat(
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);

}

def NVVM_MMA_OPS : NVVM_MMA_OPS;
Expand Down Expand Up @@ -2362,6 +2393,16 @@ def MMAIntOverflow : I32EnumAttr<"MMAIntOverflow", "MMA overflow options",
def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflow"> {
let assemblyFormat = "`<` $value `>`";
}
/// MMA kind attribute: selects a mixed-precision mode for MMA operations.
/// Currently only `f8f6f4`, which (per the MmaSpOp description below) enables
/// the e3m2/e2m3/e2m1 element types and an f16 accumulator.
def MMAKindF8F6F4 : I32EnumAttrCase<"f8f6f4", 0>;
def MMAKind : I32EnumAttr<"MMAKind", "MMA operation kind",
[MMAKindF8F6F4]> {
// Only the generic EnumAttr wrapper (MMAKindAttr) is generated, no
// specialized attribute class.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def MMAKindAttr : EnumAttr<NVVM_Dialect, MMAKind, "mma_kind"> {
let assemblyFormat = "`<` $value `>`";
}

/// Attribute to hold the MMA shape
def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
Expand Down Expand Up @@ -2506,12 +2547,18 @@ def MMATypeU4 : I32EnumAttrCase<"u4", 7>;
def MMATypeS4 : I32EnumAttrCase<"s4", 8>;
def MMATypeBF16 : I32EnumAttrCase<"bf16", 9>;
def MMATypeF64 : I32EnumAttrCase<"f64", 10>;
// Narrow floating-point element types used by the f8f6f4 MMA variants.
// NOTE(review): e4m3/e5m2 are FP8; e3m2/e2m3 and e2m1 are presumably the
// FP6/FP4 PTX types — confirm the names against the PTX ISA.
def MMATypeE4M3 : I32EnumAttrCase<"e4m3", 11>;
def MMATypeE5M2 : I32EnumAttrCase<"e5m2", 12>;
def MMATypeE3M2 : I32EnumAttrCase<"e3m2", 13>;
def MMATypeE2M3 : I32EnumAttrCase<"e2m3", 14>;
def MMATypeE2M1 : I32EnumAttrCase<"e2m1", 15>;

def MMATypes : I32EnumAttr<"MMATypes", "NVVM MMA types",
[MMATypeF16, MMATypeF32, MMATypeTF32,
MMATypeBF16, MMATypeS8, MMATypeU8,
MMATypeS32, MMATypeS4, MMATypeU4,
MMATypeB1, MMATypeF64,
MMATypeE4M3, MMATypeE5M2, MMATypeE3M2, MMATypeE2M3, MMATypeE2M1]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
Expand Down Expand Up @@ -2948,6 +2995,216 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}

/// Generate the C++ enum-value name of the LLVM intrinsic for an mma.sp.sync
/// variant. `Metadata` is "sp" or "sp::ordered_metadata" and `Kind` is ""
/// or "kind::f8f6f4"; "::" is rewritten to "_" so the pieces concatenate to
/// a valid C++ identifier (e.g. llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_...).
/// NOTE(review): the layout is hard-coded to "_row_col" — confirm sparse MMA
/// intrinsics exist only for row-major A / column-major B.
class MMA_SP_SYNC_NAME<string Metadata, string Kind, int Satfinite,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
// Type-suffix portion of the name, derived from the fragment types.
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string id = "llvm::Intrinsic::nvvm_mma"
# "_" # !subst("::", "_", Metadata)
# "_" # A.geom
# "_row_col"
# !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
# !if(Satfinite, "_satfinite", "")
# signature;
}

// Returns true if this combination of types/geometry/kind/satf for MMA.SP ops
// is supported; false otherwise.
// E.g.
// if NVVM_MMA_SP_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
string kind, int satf> {
// Element types and geometry of the [D, A, B, C]-ordered fragment list
// (the order produced by MMA_OPS: frags[0]=A, frags[1]=B, frags[2]=C,
// frags[3]=D — NOTE(review): confirm ordering against MMA_OPS).
string a_type = frags[0].ptx_elt_type;
string b_type = frags[1].ptx_elt_type;
string c_type = frags[2].ptx_elt_type;
string d_type = frags[3].ptx_elt_type;
string geom = frags[0].geom;

// Integer element types are the only ones that accept satfinite.
bit is_int = !or(!eq(a_type, "s8"),
!eq(a_type, "u8"),
!eq(a_type, "s4"),
!eq(a_type, "u4"));

// First matching !cond case wins; all cases below reject (return false),
// and the final catch-all accepts.
bit ret = !cond(

// Limit satf to valid types
!and(!eq(satf, 1),
!eq(is_int, 0)): false,

// f16/bf16/tf32 requires A and B to be the same type.
!and(!or(!eq(a_type, "f16"),
!eq(a_type, "bf16"),
!eq(a_type, "tf32")),
!ne(a_type, b_type)): false,

// m16n8k16, m16n8k32 and m16n8k64 requires C and D to be the same type.
!and(!or(!eq(geom, "m16n8k16"),
!eq(geom, "m16n8k32"),
!eq(geom, "m16n8k64")),
!ne(c_type, d_type)): false,

// The sub-FP8 types (e3m2/e2m3/e2m1) require an explicit kind.
!and(!eq(kind, ""),
!or(!eq(a_type, "e3m2"),
!eq(a_type, "e2m3"),
!eq(a_type, "e2m1"),
!eq(b_type, "e3m2"),
!eq(b_type, "e2m3"),
!eq(b_type, "e2m1"))): false,

// Without a kind, m16n8k64 FP ops do not support an f16 accumulator/result.
!and(!eq(kind, ""),
!eq(geom, "m16n8k64"),
!or(!eq(c_type, "f16"),
!eq(d_type, "f16"))): false,

// A kind is only valid with ordered metadata, the m16n8k64 geometry,
// and non-integer element types.
!and(!ne(kind, ""),
!or(!eq(metadata, "sp"),
!ne(geom, "m16n8k64"),
!eq(is_int, 1))): false,

// All other are OK.
true: true
);
}

/// Helper to create the mapping between the configuration and the mma.sp.sync
/// intrinsic enum value. For every supported (op, metadata, kind, satf)
/// combination this expands a C++ "if (...) return <intrinsic>;" statement;
/// the statements are spliced into MmaSpOp::getIntrinsicID (see NVVM_MmaSpOp
/// below). Unsupported combinations expand to "" and are harmless no-ops in
/// the generated code.
class MMA_SP_SYNC_INTR {
list<list<list<list<string>>>> cond0 =
!foreach(op, NVVM_MMA_OPS.all_mma_sp_sync_ops,
!foreach(metadata, ["sp", "sp::ordered_metadata"],
!foreach(kind, ["", "kind::f8f6f4"],
!foreach (satf, [0, 1],
!if(NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret,
"if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
# " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
# " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
# " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
# " && \"" # op[3].ptx_elt_type # "\" == eltypeD"
# " && (satf.has_value() ? " # satf # " == static_cast<int>(*satf) : true)"
# " && " # !if(!eq(metadata, "sp"), "!orderedMetadata", "orderedMetadata")
// Also guard on the optional `kind` attribute. Without this clause, a
// type combination supported both with and without kind::f8f6f4 (e.g.
// e4m3 x e4m3, f32 acc, m16n8k64, ordered metadata) would emit two
// identical conditions and the first match would always win, silently
// ignoring the op's `kind` attribute.
# " && " # !if(!eq(kind, ""),
"!kind.has_value()",
"(kind.has_value() && stringifyEnum(*kind) == \"" # !subst("kind::", "", kind) # "\")")
# ")\n"
# " return " #
MMA_SP_SYNC_NAME<metadata, kind, satf, op[0], op[1], op[2], op[3]>.id # ";",
"") // if supported
) // satf
) // kind
) // metadata
); // all_mma_sp_sync_ops
// Flatten the 4-D list of statements into a single newline-joined string.
list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
!listconcat(acc, el));
list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
string id = !foldl("", f3, acc, el, acc # "\n" # el);
}

// Sparse cooperative matrix-multiply-accumulate: D = matmul(A_sparse, B) + C
// with 2:4 structured sparsity in A. See the op description below.
def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {

let summary = "cooperative sparse matrix-multiply and accumulate";

let description = [{
The `nvvm.mma.sp.sync` operation collectively performs the sparse operation
`D = matmul(A_sparse, B) + C` using all threads in a warp.

This operation is similar to `nvvm.mma.sync` but with structured sparsity
in the A operand. The sparsity follows the 2:4 structured sparse pattern
where 2 out of every 4 elements are non-zero.

All the threads in the warp must execute the same `mma.sp.sync` operation.

The `sparseMetadata` operand provides the sparsity indices that indicate
which elements in the A operand are non-zero. The `sparsitySelector`
controls how the indices are distributed among threads in the warp and
should typically be 0 or 1.

The optional `orderedMetadata` attribute specifies the metadata ordering:
- Absence (default): Uses standard sparse metadata ordering
- Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)

The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
- `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
- Only valid with ordered metadata and m16n8k64 shape

The shapes, layouts, and data types follow the same constraints as the
regular `nvvm.mma.sync` operation, but the A operand contains only the
non-zero elements in compressed format.

Example:
```mlir
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>

// With ordered metadata:
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
{orderedMetadata, shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
```
}];

// Result is an LLVM struct of accumulator fragments (shape/type dependent).
let results = (outs LLVM_AnyStruct:$res);
// A/B/C are variadic fragment lists (AttrSizedOperandSegments tracks the
// split); sparseMetadata/sparsitySelector are the i32 sparsity operands.
let arguments = (ins NVVM_MMAShapeAttr:$shape,
OptionalAttr<MMAIntOverflowAttr>:$intOverflowBehavior,
OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
UnitAttr:$orderedMetadata,
OptionalAttr<MMAKindAttr>:$kind,
Variadic<LLVM_Type>:$operandA,
Variadic<LLVM_Type>:$operandB,
Variadic<LLVM_Type>:$operandC,
I32:$sparseMetadata,
I32:$sparsitySelector);

// getIntrinsicID is assembled by splicing the MMA_SP_SYNC_INTR-generated
// "if (...) return <intrinsic>;" chain between the prologue and epilogue
// literals; falling through all cases returns 0 (presumably
// Intrinsic::not_intrinsic — confirm callers handle it).
// NOTE(review): `k` is uint64_t while `m`/`n` are int64_t — likely an
// oversight; confirm call sites before unifying the types.
let extraClassDeclaration = !strconcat([{
static llvm::Intrinsic::ID getIntrinsicID(
int64_t m, int64_t n, uint64_t k,
std::optional<MMAIntOverflow> satf,
bool orderedMetadata,
std::optional<MMAKind> kind,
mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
mlir::NVVM::MMATypes eltypeCEnum, mlir::NVVM::MMATypes eltypeDEnum) {
llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
llvm::StringRef eltypeD = stringifyEnum(eltypeDEnum);
}],
MMA_SP_SYNC_INTR<>.id, [{
return 0;
}

static std::optional<mlir::NVVM::MMATypes> inferOperandMMAType(Type operandElType,
bool isAccumulator);

MMATypes accumPtxType();
MMATypes resultPtxType();

static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}]);

// Convenience builder; element types may be inferred when
// multiplicandPtxTypes is not provided (implementation lives in the .cpp).
let builders = [
OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
"ValueRange":$operandB, "ValueRange":$operandC,
"Value":$sparseMetadata, "Value":$sparsitySelector,
"ArrayRef<int64_t>":$shape,
"std::optional<MMAIntOverflow>":$intOverflow,
"std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes)>
];

// LLVM IR lowering: resolve the intrinsic ID and argument list, then emit
// the intrinsic call.
string llvmBuilder = [{
auto [id, args] = NVVM::MmaSpOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
$res = createIntrinsicCall(builder, id, args);
}];

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
Expand Down
Loading