Skip to content

Commit 878061b

Browse files
committed
[MLIR] Linalg introduce BRGEMM
1 parent 1f734b0 commit 878061b

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,3 +3791,71 @@ structured_op: !LinalgStructuredOpConfig
37913791
scalar_const: '2.3283063999999999E-10 : f64'
37923792
- !ScalarExpression
37933793
scalar_arg: min
3794+
--- !LinalgOpConfig
3795+
metadata: !LinalgOpMetadata
3796+
name: reduce_batch_matmul
3797+
cpp_class_name: ReduceBatchMatmulOp
3798+
doc: |-
3799+
Performs a batch-reduce matrix multiplication of two 3D inputs: the per-batch matrix products are summed into a single 2D output.
3800+
Numeric casting is performed on the operands to the inner multiply, promoting
3801+
them to the same data type as the accumulator/output.
3802+
implements:
3803+
- LinalgContractionOpInterface
3804+
structured_op: !LinalgStructuredOpConfig
3805+
args:
3806+
- !LinalgOperandDefConfig
3807+
name: A
3808+
kind: input_tensor
3809+
type_var: T1
3810+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
3811+
- !LinalgOperandDefConfig
3812+
name: B
3813+
kind: input_tensor
3814+
type_var: T2
3815+
shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
3816+
- !LinalgOperandDefConfig
3817+
name: C
3818+
kind: output_tensor
3819+
type_var: U
3820+
# C is rank-2: the batch symbol s0 does not appear, matching C's (d1, d2) indexing map below.
shape_map: affine_map<()[s0, s1, s2, s3] -> (s1, s2)>
3821+
indexing_maps: !LinalgIndexingMapsConfig
3822+
static_indexing_maps:
3823+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
3824+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
3825+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
3826+
# d0 (batch) and d3 (contraction) are reduced; d1 and d2 index the 2D output.
iterator_types:
3827+
- reduction
3828+
- parallel
3829+
- parallel
3830+
- reduction
3831+
assignments:
3832+
- !ScalarAssign
3833+
arg: C
3834+
value: !ScalarExpression
3835+
scalar_fn:
3836+
kind: binary
3837+
fn_name: add
3838+
operands:
3839+
- !ScalarExpression
3840+
scalar_arg: C
3841+
- !ScalarExpression
3842+
scalar_fn:
3843+
kind: binary
3844+
fn_name: mul
3845+
operands:
3846+
- !ScalarExpression
3847+
scalar_fn:
3848+
kind: type
3849+
fn_name: cast_signed
3850+
type_var: U
3851+
operands:
3852+
- !ScalarExpression
3853+
scalar_arg: A
3854+
- !ScalarExpression
3855+
scalar_fn:
3856+
kind: type
3857+
fn_name: cast_signed
3858+
type_var: U
3859+
operands:
3860+
- !ScalarExpression
3861+
scalar_arg: B

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,3 +762,11 @@ func.func @conv_interface_wrong_num_operands(
762762
}) {dilations = dense<1> : tensor<2xi64>, linalg.memoized_indexing_maps = [#map0, #map1, #map2], operand_segment_sizes = array<i32: 2, 1>, strides = dense<1> : tensor<2xi64>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
763763
return %0 : tensor<?x?x?x?xf32>
764764
}
765+
766+
// -----
767+
768+
// Exercises linalg.reduce_batch_matmul with two 3D inputs (8x128x256 and
// 8x256x512) and a 2D output/accumulator (128x512); the CHECK line only
// requires the op name to appear in the tool output.
// NOTE(review): the RUN line for this file is not visible here — confirm
// which pass pipeline / FileCheck prefix this case is run under.
func.func @brgemm_test(%arg0: tensor<8x128x256xf32>, %arg1: tensor<8x256x512xf32>, %arg2: tensor<128x512xf32>) -> tensor<128x512xf32> {
769+
// CHECK: linalg.reduce_batch_matmul
770+
%0 = linalg.reduce_batch_matmul ins(%arg0, %arg1 : tensor<8x128x256xf32>, tensor<8x256x512xf32>) outs(%arg2: tensor<128x512xf32>) -> tensor<128x512xf32>
771+
return %0: tensor<128x512xf32>
772+
}

0 commit comments

Comments
 (0)