@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
568568 print (module )
569569
570570
# CHECK-LABEL: TEST: testBatchReduceMatmulOp
@run
def testBatchReduceMatmulOp():
    """Exercise linalg.batch_reduce_matmul builders on tensors and memrefs.

    Covers both the op-class (`linalg.BatchReduceMatmulOp` + explicit region
    fill) and the convenience wrapper (`linalg.batch_reduce_matmul`), each
    with the default indexing maps and with user-supplied maps that consume a
    transposed B operand. FileCheck directives below pin the printed IR.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # Batch-reduce contraction: (B, M, K) x (B, K, N) -> (M, N),
            # i.e. the batch dimension is reduced away in the output.
            a_shape = (5, 4, 8)
            b_shape = (5, 8, 12)
            b_transposed_shape = (5, 12, 8)
            c_shape = (4, 12)

            dimBatch = ir.AffineDimExpr.get(0)
            dimM = ir.AffineDimExpr.get(1)
            dimN = ir.AffineDimExpr.get(2)
            dimK = ir.AffineDimExpr.get(3)

            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
            # B is pre-transposed to (batch, N, K), so its map reads (d0, d2, d3).
            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
            c_map = ir.AffineMap.get(4, 0, [dimM, dimN])

            # CHECK: func.func @batch_reduce_matmul_op(
            @func.FuncOp.from_py_func(
                # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
                RankedTensorType.get(a_shape, f32),
                # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
                MemRefType.get(a_shape, f32),
                # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
                RankedTensorType.get(b_shape, f32),
                # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
                MemRefType.get(b_shape, f32),
                # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
                RankedTensorType.get(b_transposed_shape, f32),
                # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
                MemRefType.get(b_transposed_shape, f32),
                # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
                RankedTensorType.get(c_shape, f32),
                # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
                MemRefType.get(c_shape, f32),
            )
            def batch_reduce_matmul_op(
                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
            ):
                # Tensor operands, default indexing maps.
                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=(C.type,),
                    inputs=(A, B),
                    outputs=(C,),
                )
                # Op-class construction leaves the region empty; populate the
                # default mul/add payload explicitly.
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.batch_reduce_matmul(A, B, outs=(C,))

                # Tensor operands, explicit maps for the transposed B.
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=(C.type,),
                    inputs=(A, Btransposed),
                    outputs=(C,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.batch_reduce_matmul(
                    A,
                    Btransposed,
                    outs=(C,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )

                # Memref operands, default maps (no result tensors on memrefs).
                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=[],
                    inputs=(Amem, Bmem),
                    outputs=(Cmem,),
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                linalg.batch_reduce_matmul(Amem, Bmem, outs=(Cmem,))

                # Memref operands, explicit maps for the transposed B.
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=[],
                    inputs=(Amem, Btransposedmem),
                    outputs=(Cmem,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                linalg.batch_reduce_matmul(
                    Amem,
                    Btransposedmem,
                    outs=(Cmem,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )

        print(module)
670+
671+
571672# CHECK-LABEL: TEST: testPackUnPackOp
572673@run
573674def testPackUnPackOp ():
0 commit comments