Skip to content

Commit 7acedcf

Browse files
committed
Added Python test for linalg.batch_reduce_matmul
1 parent 7a7f7c3 commit 7acedcf

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def batch_matmul(
203203
)
204204

205205

206+
def batch_reduce_matmul(
    *ins: Union[Operation, OpView, Value],
    outs: Sequence[Union[Operation, OpView, Value]],
    indexing_maps: Optional[Sequence[AffineMapAttr]] = None,
    cast: Optional[Union[TypeFn, Attribute]] = None,
):
    """Build a `linalg.batch_reduce_matmul` op and return its result(s).

    Args:
      *ins: the two input operands (A, B) as values or ops producing them.
      outs: the single output/accumulator operand (C).
      indexing_maps: optional affine maps overriding the default access
        pattern (e.g. to express a transposed operand).
      cast: optional cast kind applied to the inputs before multiplying.

    Returns:
      The op's result when it produces tensors, or the op itself for the
      memref (no-result) form.
    """
    op = _create_matmul_like_op(
        BatchReduceMatmulOp,
        *ins,
        outs=outs,
        indexing_maps=indexing_maps,
        cast=cast,
    )
    return _get_op_result_or_op_results(op)
217+
218+
206219
def contract(
207220
*ins: Union[Operation, OpView, Value],
208221
outs: Sequence[Union[Operation, OpView, Value]],

mlir/test/python/dialects/linalg/ops.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,107 @@ def batch_matmul_op(A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem):
568568
print(module)
569569

570570

571+
# CHECK-LABEL: TEST: testBatchReduceMatmulOp
@run
def testBatchReduceMatmulOp():
    """Exercises building `linalg.batch_reduce_matmul` four ways: via the
    raw `BatchReduceMatmulOp` class and the `batch_reduce_matmul` builder,
    each on both tensors and memrefs, with and without explicit indexing
    maps (the explicit maps express a transposed B operand).

    The `# CHECK` comments are FileCheck directives matched against the
    printed module — do not edit them without updating the IR they assert.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # Batch=5, M=4, N=12, K=8; C has no batch dim because the op
            # reduces over the batch dimension.
            a_shape = (5, 4, 8)
            b_shape = (5, 8, 12)
            b_transposed_shape = (5, 12, 8)
            c_shape = (4, 12)

            # Iteration space dims: (batch, m, n, k).
            dimBatch = ir.AffineDimExpr.get(0)
            dimM = ir.AffineDimExpr.get(1)
            dimN = ir.AffineDimExpr.get(2)
            dimK = ir.AffineDimExpr.get(3)

            # CHECK: #[[$A_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
            # CHECK: #[[$BTrans_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>
            # CHECK: #[[$C_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
            # b_transposed_map reads B as (batch, n, k), i.e. B transposed
            # in its two trailing dims relative to the default.
            a_map = ir.AffineMap.get(4, 0, [dimBatch, dimM, dimK])
            b_transposed_map = ir.AffineMap.get(4, 0, [dimBatch, dimN, dimK])
            c_map = ir.AffineMap.get(4, 0, [dimM, dimN])

            # CHECK: func.func @batch_reduce_matmul_op(
            @func.FuncOp.from_py_func(
                # CHECK-SAME: %[[A:.*]]: tensor<5x4x8xf32>,
                RankedTensorType.get(a_shape, f32),
                # CHECK-SAME: %[[Amem:.*]]: memref<5x4x8xf32>,
                MemRefType.get(a_shape, f32),
                # CHECK-SAME: %[[B:.*]]: tensor<5x8x12xf32>,
                RankedTensorType.get(b_shape, f32),
                # CHECK-SAME: %[[Bmem:.*]]: memref<5x8x12xf32>,
                MemRefType.get(b_shape, f32),
                # CHECK-SAME: %[[BTrans:.*]]: tensor<5x12x8xf32>,
                RankedTensorType.get(b_transposed_shape, f32),
                # CHECK-SAME: %[[BTransmem:.*]]: memref<5x12x8xf32>,
                MemRefType.get(b_transposed_shape, f32),
                # CHECK-SAME: %[[C:.*]]: tensor<4x12xf32>,
                RankedTensorType.get(c_shape, f32),
                # CHECK-SAME: %[[Cmem:.*]]: memref<4x12xf32>)
                MemRefType.get(c_shape, f32),
            )
            def batch_reduce_matmul_op(
                A, Amem, B, Bmem, Btransposed, Btransposedmem, C, Cmem
            ):
                # --- tensor form, default indexing maps ---
                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=(C.type,),
                    inputs=(A, B),
                    outputs=(C,),
                )
                # The raw op class leaves the region empty; fill it with
                # the builtin mul/add payload.
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul ins(%[[A]], %[[B]] : tensor<5x4x8xf32>, tensor<5x8x12xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.batch_reduce_matmul(A, B, outs=(C,))

                # --- tensor form, explicit maps (transposed B) ---
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=(C.type,),
                    inputs=(A, Btransposed),
                    outputs=(C,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[A]], %[[BTrans]] : tensor<5x4x8xf32>, tensor<5x12x8xf32>) outs(%[[C]] : tensor<4x12xf32>)
                res = linalg.batch_reduce_matmul(
                    A,
                    Btransposed,
                    outs=(C,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )

                # --- memref form: no result tensors ---
                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=[],
                    inputs=(Amem, Bmem),
                    outputs=(Cmem,),
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul ins(%[[Amem]], %[[Bmem]] : memref<5x4x8xf32>, memref<5x8x12xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                linalg.batch_reduce_matmul(Amem, Bmem, outs=(Cmem,))

                # --- memref form, explicit maps (transposed B) ---
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                res = linalg.BatchReduceMatmulOp(
                    result_tensors=[],
                    inputs=(Amem, Btransposedmem),
                    outputs=(Cmem,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )
                linalg.fill_builtin_region(res.operation)
                # CHECK: linalg.batch_reduce_matmul indexing_maps = [#[[$A_MAP]], #[[$BTrans_MAP]], #[[$C_MAP]]] ins(%[[Amem]], %[[BTransmem]] : memref<5x4x8xf32>, memref<5x12x8xf32>) outs(%[[Cmem]] : memref<4x12xf32>)
                linalg.batch_reduce_matmul(
                    Amem,
                    Btransposedmem,
                    outs=(Cmem,),
                    indexing_maps=[a_map, b_transposed_map, c_map],
                )

        print(module)
670+
671+
571672
# CHECK-LABEL: TEST: testPackUnPackOp
572673
@run
573674
def testPackUnPackOp():

0 commit comments

Comments
 (0)