Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,35 @@ def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}

def BfmmlaOp : ArmSVE_IntrOp<"bfmmla", [Pure,
AllTypesMatch<["src1", "src2"]>,
AllTypesMatch<["acc", "res"]>,
]> {
let summary = "BFloat16 matrix multiply-accumulate";
let description = [{
BFMMLA: BFloat16 matrix multiply-accumulate into 2×2 matrices";

This operation multiplies the 2x4 BFloat16 matrix held in each 128-bit
segment of the first source vector by the 4x2 BFloat16 matrix in the
corresponding segment of the second source vector, then accumulates
this intermediate result with the 2x2 Float32 matrix in the corresponding
segment of the accumulator vector, yielding the final 2x2 Float32
segment of the result.

Source:
https://developer.arm.com/documentation/100987/0000
}];
// Supports (vector<[8]xbf16>, vector<[8]xbf16>) -> (vector<[4]xf32>)
let arguments = (ins
ScalableVectorOfLengthAndType<[4], [F32]>:$acc,
ScalableVectorOfLengthAndType<[8], [BF16]>:$src1,
ScalableVectorOfLengthAndType<[8], [BF16]>:$src2
);
let results = (outs ScalableVectorOfLengthAndType<[4], [F32]>:$res);
let assemblyFormat =
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($res)";
}

class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
void mlir::configureArmSVELegalizeForExportTarget(
LLVMConversionTarget &target) {
// clang-format off
target.addLegalOp<ConvertFromSvboolIntrOp,
target.addLegalOp<BfmmlaOp,
ConvertFromSvboolIntrOp,
ConvertToSvboolIntrOp,
DupQLaneIntrOp,
PselIntrOp,
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/ArmSVE/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,63 @@ func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
return
}

// -----

func.func @bfmmla_invalid_element_type_lhs_rhs(%acc: vector<[4]xf32>,
%lhs: vector<[8]xf16>,
%rhs: vector<[8]xf16>) -> vector<[4]xf32> {
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[8]xf16>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// -----

func.func @bfmmla_invalid_dimension_lhs_rhs(%acc: vector<[4]xf32>,
%lhs: vector<[4]xbf16>,
%rhs: vector<[4]xbf16>) -> vector<[4]xf32> {
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<[4]xbf16>}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[4]xbf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// -----

func.func @bfmmla_fixed_dimension_lhs_rhs(%acc: vector<[4]xf32>,
%lhs: vector<8xbf16>,
%rhs: vector<8xbf16>) -> vector<[4]xf32> {
// expected-error@+1 {{operand #1 must be scalable vector of bfloat16 type values of length 8, but got 'vector<8xbf16>}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<8xbf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// -----

func.func @bfmmla_invalid_element_type_acc(%acc: vector<[4]xi32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<[4]xi32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[4]xi32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[4]xi32>
return %0 : vector<[4]xi32>
}

// -----

func.func @bfmmla_invalid_dimension_acc(%acc: vector<[8]xf32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<[8]xf32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<[8]xf32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<[8]xf32>
return %0 : vector<[8]xf32>
}

// -----

func.func @bfmmla_fixed_dimension_acc(%acc: vector<4xf32>,
%lhs: vector<[8]xbf16>,
%rhs: vector<[8]xbf16>) -> vector<4xf32> {
// expected-error@+1 {{operand #0 must be scalable vector of 32-bit float values of length 4, but got 'vector<4xf32>'}}
%0 = arm_sve.intr.bfmmla %acc, %lhs, %rhs : vector<[8]xbf16> to vector<4xf32>
return %0 : vector<4xf32>
}
10 changes: 10 additions & 0 deletions mlir/test/Dialect/ArmSVE/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func.func @arm_sve_usmmla(%a: vector<[16]xi8>,

// -----

func.func @arm_sve_bfmmla(%a: vector<[8]xbf16>,
%b: vector<[8]xbf16>,
%c: vector<[4]xf32>) -> vector<[4]xf32> {
// CHECK: arm_sve.intr.bfmmla {{.*}}: vector<[8]xbf16> to vector<[4]xf32>
%0 = arm_sve.intr.bfmmla %c, %a, %b : vector<[8]xbf16> to vector<[4]xf32>
return %0 : vector<[4]xf32>
}

// -----

func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
%b: vector<[4]xi32>,
%c: vector<[4]xi32>,
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Target/LLVMIR/arm-sve.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,18 @@ llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
llvm.return %0 : vector<[4]xi32>
}

// CHECK-LABEL: define <vscale x 4 x float> @arm_sve_bfmmla
llvm.func @arm_sve_bfmmla(%arg0: vector<[8]xbf16>,
%arg1: vector<[8]xbf16>,
%arg2: vector<[4]xf32>)
-> vector<[4]xf32> {
// CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.bfmmla(<vscale x 4 x float>
%0 = "arm_sve.intr.bfmmla"(%arg2, %arg0, %arg1) :
(vector<[4]xf32>, vector<[8]xbf16>, vector<[8]xbf16>)
-> vector<[4]xf32>
llvm.return %0 : vector<[4]xf32>
}

// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
%arg1: vector<[4]xi32>,
Expand Down