@@ -41,3 +41,47 @@ func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tens
4141 outs (%C: tensor <?x?xf32 >) -> tensor <?x?xf32 >
4242 return %result : tensor <?x?xf32 >
4343}
44+
45+ // -----
46+
47+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
48+ // CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
49+ //
50+ // CHECK: func.func @unary_broadcasted(%[[A:.+]]: tensor<8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
51+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
52+ // CHECK-SAME: indexing_maps = [#[[BROADCASTED]], #[[IDENTITY]]]
53+ // CHECK-SAME: ins(%[[A]] : tensor<8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
54+ // CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
55+ //
// Input IR: %A is first materialized at the full 8x16x32 shape by an explicit
// linalg.broadcast, and the unary exp then consumes that broadcast result.
// The expected output above instead has the elementwise op read %A directly
// through a broadcasting indexing map.
func.func @unary_broadcasted(%A : tensor<8x32xf32>, %B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
  // Destination for the explicit broadcast.
  %init = tensor.empty() : tensor<8x16x32xf32>
  // Insert a new dimension of extent 16 at position 1 of the result.
  %bcast = linalg.broadcast
      ins(%A : tensor<8x32xf32>)
      outs(%init : tensor<8x16x32xf32>)
      dimensions = [1]
  // Element-wise exp over the broadcast operand, written into %B.
  %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
      ins(%bcast : tensor<8x16x32xf32>)
      outs(%B : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
  return %exp : tensor<8x16x32xf32>
}
63+
64+ // -----
65+
66+ // CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
67+ // CHECK-DAG: #[[BROADCASTED:.+]] = affine_map<(d0, d1) -> (d0)>
68+ //
69+ // CHECK: func.func @binary_broadcasted(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
70+ // CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
71+ // CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[BROADCASTED]], #[[IDENTITY]]]
72+ // CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
73+ // CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
74+ //
// Input IR: the 1-D operand %B is explicitly broadcast to the 2-D shape of %A
// and then added to %A element-wise. The expected output above has the
// elementwise op read %B directly through the (d0, d1) -> (d0) indexing map.
func.func @binary_broadcasted(%A : tensor<?x?xf32>, %B : tensor<?xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Runtime extents of %A, used to size the broadcast destination.
  %dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
  %dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>

  // The broadcast result is added element-wise to %A, so its destination must
  // have %A's shape, i.e. (dim0, dim1). With `dimensions = [1]` the new
  // dimension is output dim 1 and %B's only dimension maps to output dim 0.
  // Passing the extents as (dim1, dim0) would only be shape-correct when %A
  // happens to be square.
  %empty = tensor.empty(%dim0, %dim1) : tensor<?x?xf32>
  %broadcasted_B = linalg.broadcast
      ins(%B : tensor<?xf32>)
      outs(%empty : tensor<?x?xf32>)
      dimensions = [1]
  // Element-wise add of %A and the broadcast %B, written into %C.
  %result = linalg.elementwise kind=#linalg.elementwise_kind<add>
      ins(%A, %broadcasted_B : tensor<?x?xf32>, tensor<?x?xf32>)
      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
  return %result : tensor<?x?xf32>
}
0 commit comments