@@ -3791,3 +3791,71 @@ structured_op: !LinalgStructuredOpConfig
37913791 scalar_const : ' 2.3283063999999999E-10 : f64'
37923792 - !ScalarExpression
37933793 scalar_arg : min
3794+ --- !LinalgOpConfig
3795+ metadata : !LinalgOpMetadata
3796+ name : reduce_batch_matmul
3797+ cpp_class_name : ReduceBatchMatmulOp
3798+ doc : |-
3799+ Performs a batched matrix multiplication of two 3D inputs.
3800+ Numeric casting is performed on the operands to the inner multiply, promoting
3801+ them to the same data type as the accumulator/output.
3802+ implements :
3803+ - LinalgContractionOpInterface
3804+ structured_op : !LinalgStructuredOpConfig
3805+ args :
3806+ - !LinalgOperandDefConfig
3807+ name : A
3808+ kind : input_tensor
3809+ type_var : T1
3810+ shape_map : affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
3811+ - !LinalgOperandDefConfig
3812+ name : B
3813+ kind : input_tensor
3814+ type_var : T2
3815+ shape_map : affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
3816+ - !LinalgOperandDefConfig
3817+ name : C
3818+ kind : output_tensor
3819+ type_var : U
3820+ shape_map : affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
3821+ indexing_maps : !LinalgIndexingMapsConfig
3822+ static_indexing_maps :
3823+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
3824+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
3825+ - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d1, d2)>
3826+ iterator_types :
3827+ - reduction
3828+ - parallel
3829+ - parallel
3830+ - reduction
3831+ assignments :
3832+ - !ScalarAssign
3833+ arg : C
3834+ value : !ScalarExpression
3835+ scalar_fn :
3836+ kind : binary
3837+ fn_name : add
3838+ operands :
3839+ - !ScalarExpression
3840+ scalar_arg : C
3841+ - !ScalarExpression
3842+ scalar_fn :
3843+ kind : binary
3844+ fn_name : mul
3845+ operands :
3846+ - !ScalarExpression
3847+ scalar_fn :
3848+ kind : type
3849+ fn_name : cast_signed
3850+ type_var : U
3851+ operands :
3852+ - !ScalarExpression
3853+ scalar_arg : A
3854+ - !ScalarExpression
3855+ scalar_fn :
3856+ kind : type
3857+ fn_name : cast_signed
3858+ type_var : U
3859+ operands :
3860+ - !ScalarExpression
3861+ scalar_arg : B
0 commit comments