@@ -606,3 +606,189 @@ def tensor_pack(src, dst):
606606 # CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607607 # CHECK: }
608608 print (module )
609+
610+
# CHECK-LABEL: TEST: testElementwiseOp
@run
def testElementwiseOp():
    """Emit a series of `linalg.elementwise` ops and print the resulting module.

    A single function, `elementwise_op`, is built with eight arguments: a
    tensor and a memref for each of the 8x16, 8, 16 and 16x8 shapes. Inside it
    the test creates elementwise ops through both the `linalg.ElementwiseOp`
    class (followed by `linalg.fill_builtin_region`) and the
    `linalg.elementwise` helper, in four groups:

    * no `indexing_maps` given,
    * identity `indexing_maps` given explicitly,
    * non-identity (broadcasting) input maps,
    * non-identity plus transposed input maps.

    The FileCheck directive comments are matched against the printed IR in
    file order, so the ops below must be created in the same order as the
    directives appear.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # Shapes used by the function arguments and the output buffers.
            rect_dims = (8, 16)
            vert_dims = (8,)
            hor_dims = (16,)
            trans_dims = (16, 8)

            # One tensor type and one memref type per shape.
            rect_tensor_t = RankedTensorType.get(rect_dims, f32)
            rect_memref_t = MemRefType.get(rect_dims, f32)
            vert_tensor_t = RankedTensorType.get(vert_dims, f32)
            vert_memref_t = MemRefType.get(vert_dims, f32)
            hor_tensor_t = RankedTensorType.get(hor_dims, f32)
            hor_memref_t = MemRefType.get(hor_dims, f32)
            trans_tensor_t = RankedTensorType.get(trans_dims, f32)
            trans_memref_t = MemRefType.get(trans_dims, f32)

            # CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
            # CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
            # CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
            # CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>

            identity_2d = AffineMap.get_identity(2)
            transpose_2d = AffineMap.get_permutation((1, 0))
            bcast_along_d0 = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
            bcast_along_d1 = AffineMap.get(2, 0, [AffineDimExpr.get(1)])

            # CHECK: func.func @elementwise_op(
            @func.FuncOp.from_py_func(
                # CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
                rect_tensor_t,
                # CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
                rect_memref_t,
                # CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
                vert_tensor_t,
                # CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
                vert_memref_t,
                # CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
                hor_tensor_t,
                # CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
                hor_memref_t,
                # CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
                trans_tensor_t,
                # CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
                trans_memref_t,
            )
            def elementwise_op(
                rect,
                rect_mem,
                vert_line,
                vert_line_mem,
                hor_line,
                hor_line_mem,
                trans_rect,
                trans_rect_mem,
            ):
                # Output buffers reused by every op below.
                # CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
                out_rect = tensor.EmptyOp(rect_dims, f32)
                # CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
                out_rect_mem = memref.alloca(rect_memref_t, [], [])

                # Result type tuple for the class-based tensor ops.
                tensor_results = (out_rect.result.type,)

                # ---- Group 1: indexing maps not supplied (inferred) ----

                # CHECK: linalg.elementwise
                # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                exp_inferred = linalg.ElementwiseOp(
                    kind=linalg.ElementwiseKind.exp,
                    inputs=(rect,),
                    outputs=(out_rect,),
                    result_tensors=tensor_results,
                )
                linalg.fill_builtin_region(exp_inferred.operation)

                # CHECK: linalg.elementwise
                # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                linalg.elementwise(
                    rect,
                    kind=linalg.ElementwiseKind.exp,
                    outs=(out_rect,),
                )

                # CHECK: linalg.elementwise
                # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
                # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                linalg.elementwise(
                    rect_mem,
                    kind=linalg.ElementwiseKind.exp,
                    outs=(out_rect_mem,),
                )

                # ---- Group 2: identity indexing maps supplied explicitly ----
                # Same as above but with default identity indexing_maps explicitly provided.

                # CHECK: linalg.elementwise
                # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                exp_identity = linalg.ElementwiseOp(
                    kind=linalg.ElementwiseKind.exp,
                    inputs=(rect,),
                    outputs=(out_rect,),
                    result_tensors=tensor_results,
                    indexing_maps=[identity_2d, identity_2d],
                )
                linalg.fill_builtin_region(exp_identity.operation)

                # CHECK: linalg.elementwise
                # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
                # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                linalg.elementwise(
                    rect_mem,
                    kind=linalg.ElementwiseKind.exp,
                    outs=(out_rect_mem,),
                    indexing_maps=[identity_2d, identity_2d],
                )

                # ---- Group 3: non-identity input maps ----

                # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
                # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
                # CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                exp_bcast = linalg.ElementwiseOp(
                    kind=linalg.ElementwiseKind.exp,
                    inputs=(vert_line,),
                    outputs=(out_rect,),
                    result_tensors=tensor_results,
                    indexing_maps=[bcast_along_d0, identity_2d],
                )
                linalg.fill_builtin_region(exp_bcast.operation)

                # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
                # CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
                # CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                add_bcast = linalg.ElementwiseOp(
                    kind=linalg.ElementwiseKind.add,
                    inputs=(rect, vert_line),
                    outputs=(out_rect,),
                    result_tensors=tensor_results,
                    indexing_maps=[identity_2d, bcast_along_d0, identity_2d],
                )
                linalg.fill_builtin_region(add_bcast.operation)

                # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
                # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
                # CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
                # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                linalg.elementwise(
                    vert_line,
                    hor_line,
                    kind=linalg.ElementwiseKind.div,
                    outs=(out_rect,),
                    indexing_maps=[bcast_along_d0, bcast_along_d1, identity_2d],
                )

                # ---- Group 4: non-identity and transposed input maps ----

                # The i1 buffer is allocated here, not earlier, so that it
                # appears in the IR at the position the directives expect.
                # CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
                i1 = IntegerType.get_signless(1)
                vert_line_bools_mem = memref.alloca(
                    MemRefType.get(vert_dims, i1), [], []
                )
                # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
                # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
                # CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
                # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                linalg.elementwise(
                    vert_line_bools_mem,
                    hor_line_mem,
                    trans_rect_mem,
                    kind=linalg.ElementwiseKind.select,
                    outs=(out_rect_mem,),
                    indexing_maps=[
                        bcast_along_d0,
                        bcast_along_d1,
                        transpose_2d,
                        identity_2d,
                    ],
                )

        print(module)
0 commit comments