@@ -582,51 +582,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
582582
583583//===---------------------------------------------------------------------===//
584584// WMMA intrinsics
585- class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
586- list<Trait> traits = []> :
587- ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
588- Arguments<(ins Variadic<LLVM_Type>:$args)> {
589- let assemblyFormat =
590- "$args attr-dict `:` functional-type($args, $res)";
585+ class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
586+ [0], [0], [], 1, 0, 0, 0, [], []>,
587+ Arguments<(ins
588+ LLVM_ScalarOrVectorOf<AB>:$A,
589+ LLVM_ScalarOrVectorOf<AB>:$B,
590+ LLVM_ScalarOrVectorOf<CD>:$C)> {
591+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
592+ let assemblyFormat = [{
593+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
594+ }];
595+ }
596+
597+ class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
598+ [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
599+ Arguments<(ins
600+ LLVM_ScalarOrVectorOf<AB>:$A,
601+ LLVM_ScalarOrVectorOf<AB>:$B,
602+ LLVM_ScalarOrVectorOf<CD>:$C,
603+ DefaultValuedAttr<I1Attr, "0">:$opsel)> {
604+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
605+ let assemblyFormat = [{
606+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
607+ }];
608+ }
609+
610+ class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
611+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
612+ Arguments<(ins
613+ DefaultValuedAttr<I1Attr, "0">:$signA,
614+ LLVM_ScalarOrVectorOf<AB>:$A,
615+ DefaultValuedAttr<I1Attr, "0">:$signB,
616+ LLVM_ScalarOrVectorOf<AB>:$B,
617+ LLVM_ScalarOrVectorOf<CD>:$C,
618+ DefaultValuedAttr<I1Attr, "0">:$clamp)> {
619+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
620+ let assemblyFormat = [{
621+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
622+ }];
623+ }
624+
625+ class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
626+ [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
627+ Arguments<(ins
628+ DefaultValuedAttr<I1Attr, "0">:$signA,
629+ LLVM_ScalarOrVectorOf<AB>:$A,
630+ DefaultValuedAttr<I1Attr, "0">:$signB,
631+ LLVM_ScalarOrVectorOf<AB>:$B,
632+ DefaultValuedAttr<I16Attr, "0">:$modC,
633+ LLVM_ScalarOrVectorOf<CD>:$C,
634+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
635+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
636+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
637+ let assemblyFormat = [{
638+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
639+ }];
640+ }
641+
642+ class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
643+ [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
644+ Arguments<(ins
645+ LLVM_ScalarOrVectorOf<AB>:$A,
646+ LLVM_ScalarOrVectorOf<AB>:$B,
647+ DefaultValuedAttr<I16Attr, "0">:$modC,
648+ LLVM_ScalarOrVectorOf<CD>:$C,
649+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
650+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
651+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
652+ let assemblyFormat = [{
653+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
654+ }];
655+ }
656+
657+ class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
658+ [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
659+ Arguments<(ins
660+ DefaultValuedAttr<I1Attr, "0">:$signA,
661+ LLVM_ScalarOrVectorOf<AB>:$A,
662+ DefaultValuedAttr<I1Attr, "0">:$signB,
663+ LLVM_ScalarOrVectorOf<AB>:$B,
664+ DefaultValuedAttr<I16Attr, "0">:$modC,
665+ LLVM_ScalarOrVectorOf<C>:$C,
666+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
667+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
668+ let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
669+ let assemblyFormat = [{
670+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
671+ }];
672+ }
673+
674+ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
675+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
676+ Arguments<(ins
677+ DefaultValuedAttr<I1Attr, "0">:$signA,
678+ LLVM_ScalarOrVectorOf<AB>:$A,
679+ DefaultValuedAttr<I1Attr, "0">:$signB,
680+ LLVM_ScalarOrVectorOf<AB>:$B,
681+ LLVM_ScalarOrVectorOf<CD>:$C,
682+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
683+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
684+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
685+ let assemblyFormat = [{
686+ $A `,` $B `,` $C attr-dict `:` `(`type($A) `,` type($B) `,` type($C)`)` `->` type($res)
687+ }];
591688}
592689
593690// Available from gfx11
594- def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.f16", [0] >;
595- def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf16", [0] >;
596- def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x16.f16", [0] >;
597- def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16.16x16x16.bf16", [0] >;
598- def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x16.iu8", [1] >;
599- def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x16.iu4", [1] >;
691+ def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32 >;
692+ def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf16", AnyInteger, F32 >;
693+ def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp <"wmma.f16.16x16x16.f16", F16, F16 >;
694+ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp <"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger >;
695+ def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger >;
696+ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger >;
600697// Available from gfx12
601- def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.fp8_fp8", [1] >;
602- def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.fp8_bf8", [1] >;
603- def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf8_bf8", [1] >;
604- def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf8_fp8", [1] >;
605- def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x32.iu4", [1] >;
698+ def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32 >;
699+ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32 >;
700+ def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32 >;
701+ def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32 >;
702+ def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger >;
606703// Available from gfx1250
607- def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x4.f32", [1] >;
608- def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x32.bf16", [1] >;
609- def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x32.f16", [1] >;
610- def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x32.f16", [1] >;
611- def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16.16x16x32.bf16", [1] >;
612- def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16f32.16x16x32.bf16", [1,5] >;
613- def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.fp8_fp8", [0] >;
614- def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.fp8_bf8", [0] >;
615- def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.bf8_fp8", [0] >;
616- def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.bf8_bf8", [0] >;
617- def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.fp8_fp8", [0] >;
618- def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.fp8_bf8", [0] >;
619- def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.bf8_fp8", [0] >;
620- def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.bf8_bf8", [0] >;
621- def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.fp8_fp8", [0] >;
622- def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.fp8_bf8", [0] >;
623- def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.bf8_fp8", [0] >;
624- def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.bf8_bf8", [0] >;
625- def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.fp8_fp8", [0] >;
626- def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.fp8_bf8", [0] >;
627- def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.bf8_fp8", [0] >;
628- def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.bf8_bf8", [0] >;
629- def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x64.iu8", [1] >;
704+ def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x4.f32", F32, F32 >;
705+ def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x32.bf16", BF16, F32 >;
706+ def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x32.f16", F16, F32 >;
707+ def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f16.16x16x32.f16", F16, F16 >;
708+ def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.bf16.16x16x32.bf16", BF16, BF16 >;
709+ def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp <"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16 >;
710+ def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32 >;
711+ def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32 >;
712+ def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32 >;
713+ def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32 >;
714+ def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16 >;
715+ def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16 >;
716+ def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16 >;
717+ def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16 >;
718+ def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32 >;
719+ def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32 >;
720+ def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32 >;
721+ def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32 >;
722+ def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16 >;
723+ def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16 >;
724+ def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16 >;
725+ def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16 >;
726+ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp <"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger >;
630727
631728//===---------------------------------------------------------------------===//
632729// LDS transpose intrinsics (available in GFX950)
0 commit comments