@@ -84,6 +84,7 @@ def named_form(lhs, rhs):
8484
8585 print (module )
8686
87+
8788# CHECK-LABEL: TEST: testIdentityRegionOps
8889@run
8990def testIdentityRegionOps ():
@@ -161,3 +162,63 @@ def broadcast_op(op1, op2, op3):
161162 op5 = linalg .broadcast (op3 , outs = [op2 ], dimensions = [0 ])
162163
163164 print (module )
165+
166+
167+ # CHECK-LABEL: TEST: testGenericOp
168+ @run
169+ def testGenericOp ():
170+ with Context (), Location .unknown ():
171+ module = Module .create ()
172+ f32 = F32Type .get ()
173+ with InsertionPoint (module .body ):
174+ id_map = AffineMap .get_identity (2 )
175+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<16x16xf32>
176+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<16x16xf32>
177+ x = tensor .empty ((16 , 16 ), f32 )
178+ y = tensor .empty ((16 , 16 ), f32 )
179+
180+ # CHECK: %[[VAL_2:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_1]] : tensor<16x16xf32>) {
181+ # CHECK: ^bb0(%in: f32, %out: f32):
182+ # CHECK: linalg.yield %in : f32
183+ # CHECK: } -> tensor<16x16xf32>
184+ @linalg .generic (
185+ [x ],
186+ [y ],
187+ [id_map , id_map ],
188+ [linalg .IteratorType .parallel , linalg .IteratorType .parallel ],
189+ )
190+ def f (x , y ):
191+ return x
192+
193+ assert isinstance (f , Value )
194+ assert isinstance (f .type , RankedTensorType )
195+
196+ # CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<16x16x16xf32>
197+ z = tensor .empty ((16 , 16 , 16 ), f32 )
198+
199+ minor_id = AffineMap .get_minor_identity (3 , 2 )
200+ id_map = AffineMap .get_identity (3 )
201+
202+ # CHECK: %[[VAL_4:.+]]:2 = linalg.generic {indexing_maps = [#map1, #map2, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%[[VAL_0]] : tensor<16x16xf32>) outs(%[[VAL_3]], %[[VAL_3]] : tensor<16x16x16xf32>, tensor<16x16x16xf32>) {
203+ # CHECK: ^bb0(%in: f32, %out: f32, %out_0: f32):
204+ # CHECK: linalg.yield %in, %out : f32, f32
205+ # CHECK: } -> (tensor<16x16x16xf32>, tensor<16x16x16xf32>)
206+ @linalg .generic (
207+ [x ],
208+ [z , z ],
209+ [minor_id , id_map , id_map ],
210+ [
211+ linalg .IteratorType .parallel ,
212+ linalg .IteratorType .parallel ,
213+ linalg .IteratorType .parallel ,
214+ ],
215+ )
216+ def g (x , z1 , z2 ):
217+ return x , z1
218+
219+ assert isinstance (g , OpResultList )
220+ assert len (g ) == 2
221+ assert isinstance (g [0 ].type , RankedTensorType )
222+ assert isinstance (g [1 ].type , RankedTensorType )
223+
224+ print (module )
0 commit comments