@@ -606,3 +606,189 @@ def tensor_pack(src, dst):
606
606
# CHECK: return %[[VAL_4]] : tensor<128x128xf32>
607
607
# CHECK: }
608
608
print (module )
609


# CHECK-LABEL: TEST: testElementwiseOp
@run
def testElementwiseOp():
    """Exercise the linalg.elementwise op builders.

    Builds a func.func containing linalg.elementwise ops in four flavors:
    inferred (identity) indexing maps, explicitly-provided identity maps,
    broadcasting (non-identity) input maps, and a transposed input map —
    on both tensor and memref operands. The emitted IR is verified by the
    FileCheck directives in the comments, so statement order is significant.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            rect_shape = (8, 16)
            vert_line_shape = (8,)
            hor_line_shape = (16,)
            transposed_rect_shape = (16, 8)

            # CHECK-DAG: #[[$IdentMap2D:.*]] = affine_map<(d0, d1) -> (d0, d1)>
            # CHECK-DAG: #[[$TransMap2D:.*]] = affine_map<(d0, d1) -> (d1, d0)>
            # CHECK-DAG: #[[$VertLineBCastMap:.*]] = affine_map<(d0, d1) -> (d0)>
            # CHECK-DAG: #[[$HorLineBCastMap:.*]] = affine_map<(d0, d1) -> (d1)>

            ident_map_2d = AffineMap.get_identity(2)
            transposed_map_2d = AffineMap.get_permutation((1, 0))
            vert_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(0)])
            hor_line_bcast_map = AffineMap.get(2, 0, [AffineDimExpr.get(1)])

            # CHECK: func.func @elementwise_op(
            @func.FuncOp.from_py_func(
                # CHECK-SAME: %[[Rect:.*]]: tensor<8x16xf32>,
                RankedTensorType.get(rect_shape, f32),
                # CHECK-SAME: %[[RectMem:.*]]: memref<8x16xf32>,
                MemRefType.get(rect_shape, f32),
                # CHECK-SAME: %[[VertLine:.*]]: tensor<8xf32>,
                RankedTensorType.get(vert_line_shape, f32),
                # CHECK-SAME: %[[VertLineMem:.*]]: memref<8xf32>,
                MemRefType.get(vert_line_shape, f32),
                # CHECK-SAME: %[[HorLine:.*]]: tensor<16xf32>,
                RankedTensorType.get(hor_line_shape, f32),
                # CHECK-SAME: %[[HorLineMem:.*]]: memref<16xf32>,
                MemRefType.get(hor_line_shape, f32),
                # CHECK-SAME: %[[TransRect:.*]]: tensor<16x8xf32>,
                RankedTensorType.get(transposed_rect_shape, f32),
                # CHECK-SAME: %[[TransRectMem:.*]]: memref<16x8xf32>)
                MemRefType.get(transposed_rect_shape, f32),
            )
            def elementwise_op(
                rect,
                rect_mem,
                vert_line,
                vert_line_mem,
                hor_line,
                hor_line_mem,
                trans_rect,
                trans_rect_mem,
            ):
                # CHECK: %[[OutRect:.*]] = tensor.empty() : tensor<8x16xf32>
                out_rect = tensor.EmptyOp(rect_shape, f32)
                # CHECK: %[[OutRectMem:.*]] = memref.alloca() : memref<8x16xf32>
                out_rect_mem = memref.alloca(MemRefType.get(rect_shape, f32), [], [])

                # Walrus-in-`if` is used purely as a labelled section marker.
                if _inferred_affine_maps := True:
                    # CHECK: linalg.elementwise
                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    op1 = linalg.ElementwiseOp(
                        result_tensors=(out_rect.result.type,),
                        inputs=(rect,),
                        outputs=(out_rect,),
                        kind=linalg.ElementwiseKind.exp,
                    )
                    linalg.fill_builtin_region(op1.operation)

                    # CHECK: linalg.elementwise
                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    linalg.elementwise(
                        rect,
                        outs=(out_rect,),
                        kind=linalg.ElementwiseKind.exp,
                    )

                    # CHECK: linalg.elementwise
                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                    linalg.elementwise(
                        rect_mem,
                        outs=(out_rect_mem,),
                        kind=linalg.ElementwiseKind.exp,
                    )

                if _explicit_ident_affine_maps := True:
                    # Same as above but with default identity indexing_maps explicitly provided.
                    # CHECK: linalg.elementwise
                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: ins(%[[Rect]] : tensor<8x16xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    op3 = linalg.ElementwiseOp(
                        result_tensors=(out_rect.result.type,),
                        inputs=(rect,),
                        outputs=(out_rect,),
                        kind=linalg.ElementwiseKind.exp,
                        indexing_maps=[ident_map_2d, ident_map_2d],
                    )
                    linalg.fill_builtin_region(op3.operation)

                    # CHECK: linalg.elementwise
                    # CHECK-SAME: kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: ins(%[[RectMem]] : memref<8x16xf32>)
                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                    linalg.elementwise(
                        rect_mem,
                        outs=(out_rect_mem,),
                        kind=linalg.ElementwiseKind.exp,
                        indexing_maps=[ident_map_2d, ident_map_2d],
                    )

                if _ops_with_non_ident_input_maps := True:
                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<exp>
                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$IdentMap2D]]]
                    # CHECK-SAME: ins(%[[VertLine]] : tensor<8xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    op4 = linalg.ElementwiseOp(
                        result_tensors=(out_rect.result.type,),
                        inputs=(vert_line,),
                        outputs=(out_rect,),
                        kind=linalg.ElementwiseKind.exp,
                        indexing_maps=[vert_line_bcast_map, ident_map_2d],
                    )
                    linalg.fill_builtin_region(op4.operation)

                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<add>
                    # CHECK-SAME: indexing_maps = [#[[$IdentMap2D]], #[[$VertLineBCastMap]], #[[$IdentMap2D]]]
                    # CHECK-SAME: ins(%[[Rect]], %[[VertLine]] : tensor<8x16xf32>, tensor<8xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    op4 = linalg.ElementwiseOp(
                        result_tensors=(out_rect.result.type,),
                        inputs=(rect, vert_line),
                        outputs=(out_rect,),
                        kind=linalg.ElementwiseKind.add,
                        indexing_maps=[ident_map_2d, vert_line_bcast_map, ident_map_2d],
                    )
                    linalg.fill_builtin_region(op4.operation)

                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<div>
                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
                    # CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
                    # CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
                    linalg.elementwise(
                        vert_line,
                        hor_line,
                        outs=(out_rect,),
                        kind=linalg.ElementwiseKind.div,
                        indexing_maps=[
                            vert_line_bcast_map,
                            hor_line_bcast_map,
                            ident_map_2d,
                        ],
                    )

                if _ops_with_non_ident_and_transposed_input_maps := True:
                    # CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
                    vert_line_bools_mem = memref.alloca(
                        MemRefType.get(vert_line_shape, IntegerType.get_signless(1)),
                        [],
                        [],
                    )
                    # CHECK: linalg.elementwise kind=#linalg.elementwise_kind<select>
                    # CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$TransMap2D]], #[[$IdentMap2D]]]
                    # CHECK-SAME: ins(%[[VertLineBoolsMem]], %[[HorLineMem]], %[[TransRectMem]] : memref<8xi1>, memref<16xf32>, memref<16x8xf32>)
                    # CHECK-SAME: outs(%[[OutRectMem]] : memref<8x16xf32>)
                    linalg.elementwise(
                        vert_line_bools_mem,
                        hor_line_mem,
                        trans_rect_mem,
                        outs=(out_rect_mem,),
                        kind=linalg.ElementwiseKind.select,
                        indexing_maps=[
                            vert_line_bcast_map,
                            hor_line_bcast_map,
                            transposed_map_2d,
                            ident_map_2d,
                        ],
                    )

        print(module)