Skip to content

Commit 71e2f13

Browse files
[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla
1 parent 02e68ee commit 71e2f13

File tree

5 files changed

+71
-0
lines changed

5 files changed

+71
-0
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
273273
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
274274
}
275275

276+
def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
277+
AllTypesMatch<["src1", "src2"]>,
278+
AllTypesMatch<["acc", "dst"]>]> {
279+
let summary = "Matrix-matrix multiply and accumulate op";
280+
let description = [{
281+
USMMLA: Unsigned by signed integer matrix multiply-accumulate.
282+
283+
The unsigned by signed integer matrix multiply-accumulate operation
284+
multiplies the 2×8 matrix of unsigned 8-bit integer values held in
285+
the first source vector by the 8×2 matrix of signed 8-bit integer
286+
values in the second source vector. The resulting 2×2 widened 32-bit
287+
integer matrix product is then added to the 32-bit integer matrix
288+
accumulator.
289+
290+
Source:
291+
https://developer.arm.com/documentation/100987/0000
292+
}];
293+
// Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
294+
let arguments = (ins
295+
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
296+
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
297+
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
298+
);
299+
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
300+
let assemblyFormat =
301+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
302+
}
303+
276304
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
277305
"expected corresponding svbool type widened to [16]xi1",
278306
lhsArg, rhsArg,
@@ -568,6 +596,10 @@ def SmmlaIntrOp :
568596
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
569597
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
570598

599+
def UsmmlaIntrOp :
600+
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
601+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
602+
571603
def SdotIntrOp :
572604
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
573605
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
2424
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
2525
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
2626
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
27+
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
2728
using DupQLaneLowering =
2829
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
2930
using ScalableMaskedAddIOpLowering =
@@ -194,6 +195,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
194195
SmmlaOpLowering,
195196
UdotOpLowering,
196197
UmmlaOpLowering,
198+
UsmmlaOpLowering,
197199
DupQLaneLowering,
198200
ScalableMaskedAddIOpLowering,
199201
ScalableMaskedAddFOpLowering,
@@ -222,6 +224,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
222224
SmmlaIntrOp,
223225
UdotIntrOp,
224226
UmmlaIntrOp,
227+
UsmmlaIntrOp,
225228
DupQLaneIntrOp,
226229
ScalableMaskedAddIIntrOp,
227230
ScalableMaskedAddFIntrOp,
@@ -242,6 +245,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
242245
SmmlaOp,
243246
UdotOp,
244247
UmmlaOp,
248+
UsmmlaOp,
245249
DupQLaneOp,
246250
ScalableMaskedAddIOp,
247251
ScalableMaskedAddFOp,

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
4848

4949
// -----
5050

51+
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
52+
%b: vector<[16]xi8>,
53+
%c: vector<[4]xi32>)
54+
-> vector<[4]xi32> {
55+
// CHECK: arm_sve.intr.usmmla
56+
%0 = arm_sve.usmmla %c, %a, %b :
57+
vector<[16]xi8> to vector<[4]xi32>
58+
return %0 : vector<[4]xi32>
59+
}
60+
61+
// -----
62+
5163
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
5264
%b: vector<[4]xi32>,
5365
%c: vector<[4]xi32>,

mlir/test/Dialect/ArmSVE/roundtrip.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
4444

4545
// -----
4646

47+
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
48+
%b: vector<[16]xi8>,
49+
%c: vector<[4]xi32>) -> vector<[4]xi32> {
50+
// CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi32>
51+
%0 = arm_sve.usmmla %c, %a, %b :
52+
vector<[16]xi8> to vector<[4]xi32>
53+
return %0 : vector<[4]xi32>
54+
}
55+
56+
// -----
57+
4758
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
4859
%b: vector<[4]xi32>,
4960
%c: vector<[4]xi32>,

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
4848
llvm.return %0 : vector<[4]xi32>
4949
}
5050

51+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
52+
llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
53+
%arg1: vector<[16]xi8>,
54+
%arg2: vector<[4]xi32>)
55+
-> vector<[4]xi32> {
56+
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
57+
%0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
58+
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
59+
-> vector<[4]xi32>
60+
llvm.return %0 : vector<[4]xi32>
61+
}
62+
5163
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
5264
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
5365
%arg1: vector<[4]xi32>,

0 commit comments

Comments
 (0)