Skip to content

Commit 77c329f

Browse files
[mlir][ROCDL] Adds wmma scaled intrinsics for gfx1250 (#165915)
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 8c3f59f commit 77c329f

File tree

3 files changed

+256
-29
lines changed

3 files changed

+256
-29
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -599,105 +599,155 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
599599
class ROCDL_WMMA_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
600600
[0], [0], [], 1, 0, 0, 0, [], []>,
601601
Arguments<(ins
602-
LLVM_ScalarOrVectorOf<AB>:$A,
603-
LLVM_ScalarOrVectorOf<AB>:$B,
604-
LLVM_ScalarOrVectorOf<CD>:$C)> {
602+
LLVM_ScalarOrVectorOf<AB>:$a,
603+
LLVM_ScalarOrVectorOf<AB>:$b,
604+
LLVM_ScalarOrVectorOf<CD>:$c)> {
605605
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
606606
let assemblyFormat = [{
607-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
607+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
608608
}];
609609
}
610610

611611
class ROCDL_WMMA_Opsel_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
612612
[0], [1], [], 1, 0, 0, 0, [3], ["opsel"]>,
613613
Arguments<(ins
614-
LLVM_ScalarOrVectorOf<AB>:$A,
615-
LLVM_ScalarOrVectorOf<AB>:$B,
616-
LLVM_ScalarOrVectorOf<CD>:$C,
614+
LLVM_ScalarOrVectorOf<AB>:$a,
615+
LLVM_ScalarOrVectorOf<AB>:$b,
616+
LLVM_ScalarOrVectorOf<CD>:$c,
617617
DefaultValuedAttr<I1Attr, "0">:$opsel)> {
618618
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
619619
let assemblyFormat = [{
620-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
620+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
621621
}];
622622
}
623623

624624
class ROCDL_WMMA_IU_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
625625
[0], [1], [], 1, 0, 0, 0, [0, 2, 5], ["signA", "signB", "clamp"]>,
626626
Arguments<(ins
627627
DefaultValuedAttr<I1Attr, "0">:$signA,
628-
LLVM_ScalarOrVectorOf<AB>:$A,
628+
LLVM_ScalarOrVectorOf<AB>:$a,
629629
DefaultValuedAttr<I1Attr, "0">:$signB,
630-
LLVM_ScalarOrVectorOf<AB>:$B,
631-
LLVM_ScalarOrVectorOf<CD>:$C,
630+
LLVM_ScalarOrVectorOf<AB>:$b,
631+
LLVM_ScalarOrVectorOf<CD>:$c,
632632
DefaultValuedAttr<I1Attr, "0">:$clamp)> {
633633
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
634634
let assemblyFormat = [{
635-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
635+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
636636
}];
637637
}
638638

639639
class ROCDL_WMMA_ModsAll_Reuse_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
640640
[0], [1], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
641641
Arguments<(ins
642642
DefaultValuedAttr<I1Attr, "0">:$signA,
643-
LLVM_ScalarOrVectorOf<AB>:$A,
643+
LLVM_ScalarOrVectorOf<AB>:$a,
644644
DefaultValuedAttr<I1Attr, "0">:$signB,
645-
LLVM_ScalarOrVectorOf<AB>:$B,
645+
LLVM_ScalarOrVectorOf<AB>:$b,
646646
DefaultValuedAttr<I16Attr, "0">:$modC,
647-
LLVM_ScalarOrVectorOf<CD>:$C,
647+
LLVM_ScalarOrVectorOf<CD>:$c,
648648
DefaultValuedAttr<I1Attr, "0">:$reuseA,
649649
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
650650
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
651651
let assemblyFormat = [{
652-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
652+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
653653
}];
654654
}
655655

656656
class ROCDL_WMMA_ModsC_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
657657
[0], [0], [], 1, 0, 0, 0, [2, 4, 5], ["modC","reuseA","reuseB"]>,
658658
Arguments<(ins
659-
LLVM_ScalarOrVectorOf<AB>:$A,
660-
LLVM_ScalarOrVectorOf<AB>:$B,
659+
LLVM_ScalarOrVectorOf<AB>:$a,
660+
LLVM_ScalarOrVectorOf<AB>:$b,
661661
DefaultValuedAttr<I16Attr, "0">:$modC,
662-
LLVM_ScalarOrVectorOf<CD>:$C,
662+
LLVM_ScalarOrVectorOf<CD>:$c,
663663
DefaultValuedAttr<I1Attr, "0">:$reuseA,
664664
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
665665
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
666666
let assemblyFormat = [{
667-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
667+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
668668
}];
669669
}
670670

671671
class ROCDL_WMMA_ModsAll_Diff_IntrOp<string mnemonic, Type AB, Type C, Type D> : ROCDL_IntrOp<mnemonic,
672672
[0], [1, 5], [], 1, 0, 0, 0, [0, 2, 4, 6, 7], ["signA", "signB","modC","reuseA","reuseB"]>,
673673
Arguments<(ins
674674
DefaultValuedAttr<I1Attr, "0">:$signA,
675-
LLVM_ScalarOrVectorOf<AB>:$A,
675+
LLVM_ScalarOrVectorOf<AB>:$a,
676676
DefaultValuedAttr<I1Attr, "0">:$signB,
677-
LLVM_ScalarOrVectorOf<AB>:$B,
677+
LLVM_ScalarOrVectorOf<AB>:$b,
678678
DefaultValuedAttr<I16Attr, "0">:$modC,
679-
LLVM_ScalarOrVectorOf<C>:$C,
679+
LLVM_ScalarOrVectorOf<C>:$c,
680680
DefaultValuedAttr<I1Attr, "0">:$reuseA,
681681
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
682682
let results = (outs LLVM_ScalarOrVectorOf<D>:$res);
683683
let assemblyFormat = [{
684-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
684+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
685685
}];
686686
}
687687

688688
class ROCDL_WMMA_ModsAB_IntrOp<string mnemonic, Type AB, Type CD> : ROCDL_IntrOp<mnemonic,
689689
[0], [1], [], 1, 0, 0, 0, [0, 2, 5, 6], ["signA", "signB", "reuseA","reuseB"]>,
690690
Arguments<(ins
691691
DefaultValuedAttr<I1Attr, "0">:$signA,
692-
LLVM_ScalarOrVectorOf<AB>:$A,
692+
LLVM_ScalarOrVectorOf<AB>:$a,
693693
DefaultValuedAttr<I1Attr, "0">:$signB,
694-
LLVM_ScalarOrVectorOf<AB>:$B,
695-
LLVM_ScalarOrVectorOf<CD>:$C,
694+
LLVM_ScalarOrVectorOf<AB>:$b,
695+
LLVM_ScalarOrVectorOf<CD>:$c,
696696
DefaultValuedAttr<I1Attr, "0">:$reuseA,
697697
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
698698
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
699699
let assemblyFormat = [{
700-
$A `,` $B `,` $C attr-dict `:` functional-type(operands, $res)
700+
$a `,` $b `,` $c attr-dict `:` functional-type(operands, $res)
701+
}];
702+
}
703+
704+
// Overloaded operands: [1, 3] refers to LLVM intrinsic parameter positions where
705+
// A is at position 1 and B is at position 3 (after format parameters).
706+
class ROCDL_WMMA_Scale_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
707+
[0], [1, 3], [], 1, 0, 0, 0, [0, 2, 4, 6, 7, 9, 10, 12, 13],
708+
["fmtA", "fmtB", "modC", "scaleAType", "fmtScaleA",
709+
"scaleBType", "fmtScaleB", "reuseA", "reuseB"]>,
710+
Arguments<(ins
711+
DefaultValuedAttr<I32Attr, "0">:$fmtA,
712+
LLVM_ScalarOrVectorOf<AB>:$a,
713+
DefaultValuedAttr<I32Attr, "0">:$fmtB,
714+
LLVM_ScalarOrVectorOf<AB>:$b,
715+
DefaultValuedAttr<I16Attr, "0">:$modC,
716+
LLVM_ScalarOrVectorOf<CD>:$c,
717+
DefaultValuedAttr<I32Attr, "0">:$scaleAType,
718+
DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
719+
ScaleExpTy:$scaleA,
720+
DefaultValuedAttr<I32Attr, "0">:$scaleBType,
721+
DefaultValuedAttr<I32Attr, "0">:$fmtScaleB,
722+
ScaleExpTy:$scaleB,
723+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
724+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
725+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
726+
let assemblyFormat = [{
727+
$a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res)
728+
}];
729+
}
730+
731+
class ROCDL_WMMA_Scale_F4_IntrOp<string mnemonic, Type AB, Type CD, Type ScaleExpTy> : ROCDL_IntrOp<mnemonic,
732+
[0], [0, 1], [], 1, 0, 0, 0, [2, 4, 5, 7, 8, 10, 11],
733+
["modC", "scaleAType", "fmtScaleA",
734+
"scaleBType", "fmtScaleB", "reuseA", "reuseB"]>,
735+
Arguments<(ins
736+
LLVM_ScalarOrVectorOf<AB>:$a,
737+
LLVM_ScalarOrVectorOf<AB>:$b,
738+
DefaultValuedAttr<I16Attr, "0">:$modC,
739+
LLVM_ScalarOrVectorOf<CD>:$c,
740+
DefaultValuedAttr<I32Attr, "0">:$scaleAType,
741+
DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
742+
ScaleExpTy:$scaleA,
743+
DefaultValuedAttr<I32Attr, "0">:$scaleBType,
744+
DefaultValuedAttr<I32Attr, "0">:$fmtScaleB,
745+
ScaleExpTy:$scaleB,
746+
DefaultValuedAttr<I1Attr, "0">:$reuseA,
747+
DefaultValuedAttr<I1Attr, "0">:$reuseB)> {
748+
let results = (outs LLVM_ScalarOrVectorOf<CD>:$res);
749+
let assemblyFormat = [{
750+
$a `,` $b `,` $c `,` $scaleA `,` $scaleB attr-dict `:` functional-type(operands, $res)
701751
}];
702752
}
703753

@@ -739,6 +789,12 @@ def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x1
739789
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_WMMA_ModsC_IntrOp<"wmma.f16.16x16x128.bf8_bf8", AnyInteger, F16>;
740790
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_WMMA_ModsAB_IntrOp<"wmma.i32.16x16x64.iu8", AnyInteger, AnyInteger>;
741791

792+
// Scaled wmma intrinsics (available from gfx1250)
793+
def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", AnyInteger, F32, I32>;
794+
def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_WMMA_Scale_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", AnyInteger, F32, I64>;
795+
def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale.f32.32x16x128.f4", AnyInteger, F32, I32>;
796+
def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_WMMA_Scale_F4_IntrOp<"wmma.scale16.f32.32x16x128.f4", AnyInteger, F32, I64>;
797+
742798
//===---------------------------------------------------------------------===//
743799
// LDS transpose intrinsics (available in GFX950)
744800

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1351,6 +1351,26 @@ llvm.func @rocdl.cvt.scalef32.sr.pk16(%v16xf32: vector<16xf32>,
13511351

13521352
// -----
13531353

1354+
// CHECK-LABEL: @rocdl_wmma_scale_ops
1355+
llvm.func @rocdl_wmma_scale_ops(%a_f8: vector<8xi32>, %a_f4: vector<4xi32>, %c_f32: vector<4xf32>, %c16_f32: vector<16xf32>,
1356+
%scale_i32: i32, %scale_i64: i64) {
1357+
// CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
1358+
%r0 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i32, %scale_i32 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
1359+
1360+
// CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
1361+
%r1 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %a_f8, %a_f8, %c_f32, %scale_i64, %scale_i64 : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
1362+
1363+
// CHECK: rocdl.wmma.scale.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
1364+
%r2 = rocdl.wmma.scale.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i32, %scale_i32 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32) -> vector<16xf32>
1365+
1366+
// CHECK: rocdl.wmma.scale16.f32.32x16x128.f4 %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
1367+
%r3 = rocdl.wmma.scale16.f32.32x16x128.f4 %a_f4, %a_f4, %c16_f32, %scale_i64, %scale_i64 : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i64, i64) -> vector<16xf32>
1368+
1369+
llvm.return
1370+
}
1371+
1372+
// -----
1373+
13541374
// expected-error@below {{attribute attached to unexpected op}}
13551375
func.func private @expected_llvm_func() attributes { rocdl.kernel }
13561376

0 commit comments

Comments
 (0)