@@ -596,51 +596,148 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
596596
597597//===---------------------------------------------------------------------===//
598598// WMMA intrinsics
599- class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
600- list<Trait> traits = []> :
601- ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
602- Arguments<(ins Variadic<LLVM_Type>:$args)> {
603- let assemblyFormat =
604- "$args attr-dict `:` functional-type($args, $res)";
// Plain WMMA intrinsic op: three SSA operands and one result, no attributes.
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
// The first two index lists passed to ROCDL_IntrOp are both [0]; in the
// argument list below index 0 is $A. The last two lists are empty because
// this class declares no attributes.
// NOTE(review): the meaning of the remaining positional ROCDL_IntrOp
// arguments (1, 0, 0, 0) is defined outside this section - confirm there.
// Assembly form: `$A, $B, $C attr-dict : functional-type(operands, $res)`.
599+ class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
600+ [0], [0], [], 1, 0, 0, 0, [], []>,
601+ Arguments<(ins
602+ LLVM_ScalarOrVectorOf<AB>:$A,
603+ LLVM_ScalarOrVectorOf<AB>:$B,
604+ LLVM_ScalarOrVectorOf<CD>:$C)> {
605+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
606+ let assemblyFormat = [{
607+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
608+ }];
609+ }
610+
// WMMA intrinsic op with a trailing `opsel` attribute.
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
//   opsel : I1Attr, defaults to "0"
// The last two lists passed to ROCDL_IntrOp, [3] and ["opsel"], pair the
// attribute name with its position in the argument list below ($A=0, $B=1,
// $C=2, opsel=3).
// NOTE(review): the second index list is [1], which is the position of $B in
// the argument list, not $A. $A and $B share the same element type parameter
// here, so this may be intentional or harmless - confirm.
// `opsel` is not spelled out in the assembly format; it is only reachable
// through attr-dict.
611+ class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
612+ [0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
613+ Arguments<(ins
614+ LLVM_ScalarOrVectorOf<AB>:$A,
615+ LLVM_ScalarOrVectorOf<AB>:$B,
616+ LLVM_ScalarOrVectorOf<CD>:$C,
617+ DefaultValuedAttr<I1Attr, "0">:$opsel)> {
618+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
619+ let assemblyFormat = [{
620+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
621+ }];
622+ }
623+
// WMMA intrinsic op with `signA`, `signB` and `clamp` attributes interleaved
// with the SSA operands.
//   signA, signB, clamp : I1Attr, each defaulting to "0"
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
// Argument positions below: signA=0, $A=1, signB=2, $B=3, $C=4, clamp=5.
// The attribute index list [0, 2, 5] and the name list match those positions;
// the second index list [1] is the position of $A.
// The attributes are not spelled out in the assembly format; they are only
// reachable through attr-dict.
624+ class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
625+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
626+ Arguments<(ins
627+ DefaultValuedAttr<I1Attr, "0">:$signA,
628+ LLVM_ScalarOrVectorOf<AB>:$A,
629+ DefaultValuedAttr<I1Attr, "0">:$signB,
630+ LLVM_ScalarOrVectorOf<AB>:$B,
631+ LLVM_ScalarOrVectorOf<CD>:$C,
632+ DefaultValuedAttr<I1Attr, "0">:$clamp)> {
633+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
634+ let assemblyFormat = [{
635+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
636+ }];
637+ }
638+
// WMMA intrinsic op carrying five attributes: signA, signB, modC, reuseA and
// reuseB.
//   signA, signB, reuseA, reuseB : I1Attr, each defaulting to "0"
//   modC : I16Attr, defaulting to "0"
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
// Argument positions below: signA=0, $A=1, signB=2, $B=3, modC=4, $C=5,
// reuseA=6, reuseB=7. The attribute index list [0, 2, 4, 6, 7] and the name
// list match those positions; the second index list [1] is the position of $A.
// The attributes are not spelled out in the assembly format; they are only
// reachable through attr-dict.
639+ class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
640+ [0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
641+ Arguments<(ins
642+ DefaultValuedAttr<I1Attr, "0">:$signA,
643+ LLVM_ScalarOrVectorOf<AB>:$A,
644+ DefaultValuedAttr<I1Attr, "0">:$signB,
645+ LLVM_ScalarOrVectorOf<AB>:$B,
646+ DefaultValuedAttr<I16Attr, "0">:$modC,
647+ LLVM_ScalarOrVectorOf<CD>:$C,
648+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
649+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
650+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
651+ let assemblyFormat = [{
652+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
653+ }];
654+ }
655+
// WMMA intrinsic op with modC, reuseA and reuseB attributes and no
// signA/signB attributes.
//   modC : I16Attr, defaulting to "0"
//   reuseA, reuseB : I1Attr, each defaulting to "0"
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
// Argument positions below: $A=0, $B=1, modC=2, $C=3, reuseA=4, reuseB=5.
// The attribute index list [2, 4, 5] and the name list match those positions;
// the second index list [0] is the position of $A.
// The attributes are not spelled out in the assembly format; they are only
// reachable through attr-dict.
656+ class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
657+ [0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
658+ Arguments<(ins
659+ LLVM_ScalarOrVectorOf<AB>:$A,
660+ LLVM_ScalarOrVectorOf<AB>:$B,
661+ DefaultValuedAttr<I16Attr, "0">:$modC,
662+ LLVM_ScalarOrVectorOf<CD>:$C,
663+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
664+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
665+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
666+ let assemblyFormat = [{
667+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
668+ }];
669+ }
670+
// Same argument layout as a signA/signB/modC/reuseA/reuseB WMMA op, but the
// accumulator input and the result take separate type parameters (`C` and
// `D`) instead of a shared one.
//   signA, signB, reuseA, reuseB : I1Attr, each defaulting to "0"
//   modC : I16Attr, defaulting to "0"
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C : LLVM_ScalarOrVectorOf<C>
//   $res : LLVM_ScalarOrVectorOf<D>
// Argument positions below: signA=0, $A=1, signB=2, $B=3, modC=4, $C=5,
// reuseA=6, reuseB=7. The attribute index list [0, 2, 4, 6, 7] and the name
// list match those positions; the second index list [1, 5] holds the
// positions of $A and $C.
// The attributes are not spelled out in the assembly format; they are only
// reachable through attr-dict.
671+ class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
672+ [0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
673+ Arguments<(ins
674+ DefaultValuedAttr<I1Attr, "0">:$signA,
675+ LLVM_ScalarOrVectorOf<AB>:$A,
676+ DefaultValuedAttr<I1Attr, "0">:$signB,
677+ LLVM_ScalarOrVectorOf<AB>:$B,
678+ DefaultValuedAttr<I16Attr, "0">:$modC,
679+ LLVM_ScalarOrVectorOf<C>:$C,
680+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
681+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
682+ let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
683+ let assemblyFormat = [{
684+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
685+ }];
686+ }
687+
// WMMA intrinsic op with signA, signB, reuseA and reuseB attributes and no
// modC attribute.
//   signA, signB, reuseA, reuseB : I1Attr, each defaulting to "0"
//   $A, $B : LLVM_ScalarOrVectorOf<AB>
//   $C, $res : LLVM_ScalarOrVectorOf<CD>
// Argument positions below: signA=0, $A=1, signB=2, $B=3, $C=4, reuseA=5,
// reuseB=6. The attribute index list [0, 2, 5, 6] and the name list match
// those positions; the second index list [1] is the position of $A.
// The attributes are not spelled out in the assembly format; they are only
// reachable through attr-dict.
688+ class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
689+ [0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
690+ Arguments<(ins
691+ DefaultValuedAttr<I1Attr, "0">:$signA,
692+ LLVM_ScalarOrVectorOf<AB>:$A,
693+ DefaultValuedAttr<I1Attr, "0">:$signB,
694+ LLVM_ScalarOrVectorOf<AB>:$B,
695+ LLVM_ScalarOrVectorOf<CD>:$C,
696+ DefaultValuedAttr<I1Attr, "0">:$reuseA,
697+ DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
698+ let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
699+ let assemblyFormat = [{
700+ $A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
701+ }];
605702}
606703
607704// Available from gfx11
608- def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.f16", [0] >;
609- def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf16", [0] >;
610- def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x16.f16", [0] >;
611- def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16.16x16x16.bf16", [0] >;
612- def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x16.iu8", [1] >;
613- def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x16.iu4", [1] >;
// Each def picks a WMMA op class and passes the element type for A/B first,
// then the element type for C/result. The two f16-input ops use F16 for A/B;
// the bf16-input and iu8/iu4 ops use AnyInteger for A/B. The two ops whose
// result is f16/bf16 use the Opsel class; the iu8/iu4 ops use the IU class.
// NOTE(review): bf16 values are typed AnyInteger here (A/B, and C/D for
// wmma.bf16.16x16x16.bf16) - presumably these intrinsics carry bf16 as
// integer vectors; confirm against the LLVM intrinsic definitions.
705+ def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.f16", /*Type AB=*/F16, /*Type CD=*/F32 >;
706+ def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf16", AnyInteger, F32 >;
707+ def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_WMMA_Opsel_IntrOp <"wmma.f16.16x16x16.f16", F16, F16 >;
708+ def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_WMMA_Opsel_IntrOp <"wmma.bf16.16x16x16.bf16", AnyInteger, AnyInteger >;
709+ def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x16.iu8", AnyInteger, AnyInteger >;
710+ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x16.iu4", AnyInteger, AnyInteger >;
614711// Available from gfx12
615- def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.fp8_fp8", [1] >;
616- def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.fp8_bf8", [1] >;
617- def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf8_bf8", [1] >;
618- def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x16.bf8_fp8", [1] >;
619- def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x32.iu4", [1] >;
// The four fp8/bf8 ops use the attribute-free ROCDL_WMMA_IntrOp class with
// AnyInteger A/B and F32 C/result; the 16x16x32 iu4 op uses the IU class with
// AnyInteger for both type parameters.
// NOTE(review): fp8/bf8 inputs are typed AnyInteger - presumably they are
// passed packed in integers; confirm against the LLVM intrinsic definitions.
712+ def ROCDL_wmma_f32_16x16x16_fp8_fp8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.fp8_fp8", AnyInteger, F32 >;
713+ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.fp8_bf8", AnyInteger, F32 >;
714+ def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf8_bf8", AnyInteger, F32 >;
715+ def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_WMMA_IntrOp <"wmma.f32.16x16x16.bf8_fp8", AnyInteger, F32 >;
716+ def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_WMMA_IU_IntrOp <"wmma.i32.16x16x32.iu4", AnyInteger, AnyInteger >;
620717// Available from gfx1250
621- def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x4.f32", [1] >;
622- def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x32.bf16", [1] >;
623- def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x32.f16", [1] >;
624- def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x32.f16", [1] >;
625- def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16.16x16x32.bf16", [1] >;
626- def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp <"wmma.bf16f32.16x16x32.bf16", [1,5] >;
627- def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.fp8_fp8", [0] >;
628- def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.fp8_bf8", [0] >;
629- def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.bf8_fp8", [0] >;
630- def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x64.bf8_bf8", [0] >;
631- def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.fp8_fp8", [0] >;
632- def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.fp8_bf8", [0] >;
633- def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.bf8_fp8", [0] >;
634- def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x64.bf8_bf8", [0] >;
635- def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.fp8_fp8", [0] >;
636- def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.fp8_bf8", [0] >;
637- def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.bf8_fp8", [0] >;
638- def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f32.16x16x128.bf8_bf8", [0] >;
639- def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.fp8_fp8", [0] >;
640- def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.fp8_bf8", [0] >;
641- def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.bf8_fp8", [0] >;
642- def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp <"wmma.f16.16x16x128.bf8_bf8", [0] >;
643- def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp <"wmma.i32.16x16x64.iu8", [1] >;
// Class selection in this group:
//   - f32/f16/bf16-input ops use ModsAll_Reuse with float element types for
//     A/B and C/result.
//   - wmma.bf16f32.16x16x32.bf16 uses ModsAll_Diff, with F32 for C and BF16
//     for the result.
//   - the 16x16x64 and 16x16x128 fp8/bf8 ops use ModsC, with AnyInteger A/B
//     and F32 or F16 C/result.
//   - wmma.i32.16x16x64.iu8 uses ModsAB, with AnyInteger for both type
//     parameters.
// NOTE(review): fp8/bf8 inputs are typed AnyInteger - presumably they are
// passed packed in integers; confirm against the LLVM intrinsic definitions.
718+ def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x4.f32", F32, F32 >;
719+ def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x32.bf16", BF16, F32 >;
720+ def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f32.16x16x32.f16", F16, F32 >;
721+ def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.f16.16x16x32.f16", F16, F16 >;
722+ def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Reuse_IntrOp <"wmma.bf16.16x16x32.bf16", BF16, BF16 >;
723+ def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_WMMA_ModsAll_Diff_IntrOp <"wmma.bf16f32.16x16x32.bf16", BF16, /*Type C=*/F32, /*Type D=*/BF16 >;
724+ def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.fp8_fp8", AnyInteger, F32 >;
725+ def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.fp8_bf8", AnyInteger, F32 >;
726+ def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.bf8_fp8", AnyInteger, F32 >;
727+ def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x64.bf8_bf8", AnyInteger, F32 >;
728+ def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.fp8_fp8", AnyInteger, F16 >;
729+ def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.fp8_bf8", AnyInteger, F16 >;
730+ def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.bf8_fp8", AnyInteger, F16 >;
731+ def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x64.bf8_bf8", AnyInteger, F16 >;
732+ def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.fp8_fp8", AnyInteger, F32 >;
733+ def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.fp8_bf8", AnyInteger, F32 >;
734+ def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.bf8_fp8", AnyInteger, F32 >;
735+ def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f32.16x16x128.bf8_bf8", AnyInteger, F32 >;
736+ def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.fp8_fp8", AnyInteger, F16 >;
737+ def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.fp8_bf8", AnyInteger, F16 >;
738+ def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.bf8_fp8", AnyInteger, F16 >;
739+ def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp <"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16 >;
740+ def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp <"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger >;
644741
645742//===---------------------------------------------------------------------===//
646743// LDS transpose intrinsics (available in GFX950)
0 commit comments