@@ -327,35 +327,32 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
327327#map3 = affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>
328328#map4 = affine_map <(d0 , d1 , d2 ) -> (d2 , d1 )>
329329#map5 = affine_map <(d0 , d1 , d2 ) -> (d0 , d1 )>
330- func.func @rank_reduced_extract_slice (%arg0: tensor <1 x6 x5 xf32 >, %arg1: tensor <1 x5 x6 xf32 >, %arg2: tensor <4 x6 xf32 >) -> tensor <4 x6 xf32 > {
330+ func.func @rank_reduced_extract_slice (
331+ %arg0: tensor <1 x6 x5 xf32 >, %arg1: tensor <1 x5 x6 xf32 >, %arg2: tensor <4 x6 xf32 >,
332+ %arg3: tensor <1 x6 x6 xf32 >, %arg4: tensor <4 x6 xf32 >, %arg5: tensor <4 x2 xf32 >
333+ ) -> tensor <4 x6 xf32 > {
331334 %c0 = arith.constant 0 : index
332335 %c2 = arith.constant 2 : index
333336 %c6 = arith.constant 6 : index
334- %cst = arith.constant 0.0 : f32
335- %init1 = tensor.empty () : tensor <1 x6 x6 xf32 >
336- %fill1 = linalg.fill ins (%cst : f32 ) outs (%init1 : tensor <1 x6 x6 xf32 >) -> tensor <1 x6 x6 xf32 >
337337 %0 = linalg.generic
338338 {index ing_maps = [#map0 , #map1 , #map2 ], iterator_types = [" parallel" , " parallel" , " parallel" , " reduction" ]}
339- ins (%arg0 , %arg1 : tensor <1 x6 x5 xf32 >, tensor <1 x5 x6 xf32 >) outs (%fill1 : tensor <1 x6 x6 xf32 >) {
339+ ins (%arg0 , %arg1 : tensor <1 x6 x5 xf32 >, tensor <1 x5 x6 xf32 >) outs (%arg3 : tensor <1 x6 x6 xf32 >) {
340340 ^bb0 (%in: f32 , %in_1: f32 , %out: f32 ):
341341 %10 = arith.mulf %in , %in_1 : f32
342342 %11 = arith.addf %out , %10 : f32
343343 linalg.yield %11 : f32
344344 } -> tensor <1 x6 x6 xf32 >
345- %init2 = tensor.empty () : tensor <4 x6 xf32 >
346- %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args (%arg3 = %init2 ) -> (tensor <4 x6 xf32 >) {
347- %2 = tensor.extract_slice %0 [0 , 0 , %arg4 ] [1 , 6 , 2 ] [1 , 1 , 1 ] : tensor <1 x6 x6 xf32 > to tensor <6 x2 xf32 >
348- %init3 = tensor.empty () : tensor <4 x2 xf32 >
349- %fill3 = linalg.fill ins (%cst : f32 ) outs (%init3 : tensor <4 x2 xf32 >) -> tensor <4 x2 xf32 >
345+ %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args (%arg6 = %arg4 ) -> (tensor <4 x6 xf32 >) {
346+ %2 = tensor.extract_slice %0 [0 , 0 , %arg7 ] [1 , 6 , 2 ] [1 , 1 , 1 ] : tensor <1 x6 x6 xf32 > to tensor <6 x2 xf32 >
350347 %3 = linalg.generic
351348 {index ing_maps = [#map3 , #map4 , #map5 ], iterator_types = [" parallel" , " parallel" , " reduction" ]}
352- ins (%arg2 , %2 : tensor <4 x6 xf32 >, tensor <6 x2 xf32 >) outs (%fill3 : tensor <4 x2 xf32 >) {
349+ ins (%arg2 , %2 : tensor <4 x6 xf32 >, tensor <6 x2 xf32 >) outs (%arg5 : tensor <4 x2 xf32 >) {
353350 ^bb0 (%in: f32 , %in_1: f32 , %out: f32 ):
354351 %20 = arith.mulf %in , %in_1 : f32
355352 %21 = arith.addf %out , %20 : f32
356353 linalg.yield %21 : f32
357354 } -> tensor <4 x2 xf32 >
358- %4 = tensor.insert_slice %3 into %arg3 [0 , %arg4 ] [4 , 2 ] [1 , 1 ] : tensor <4 x2 xf32 > into tensor <4 x6 xf32 >
355+ %4 = tensor.insert_slice %3 into %arg6 [0 , %arg7 ] [4 , 2 ] [1 , 1 ] : tensor <4 x2 xf32 > into tensor <4 x6 xf32 >
359356 scf.yield %4 : tensor <4 x6 xf32 >
360357 }
361358 return %1 : tensor <4 x6 xf32 >
@@ -365,24 +362,22 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
365362// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
366363// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
367364// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
365+ // CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
366+ // CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
367+ // CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
368368
369369// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
370370// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
371371// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
372- // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
373- // CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
374- // CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
375- // CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
376- // CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
377- // CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
378372
379373// For loop right after tensor alloc & fill, no linalg.generic.
380- // CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
374+ // CHECK-NOT: linalg.generic
375+ // CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
381376
382377// Producer linalg.generic now inside the loop, with tiled args sliced before
383378// it.
384379// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
385- // CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD ]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
380+ // CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3 ]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
386381// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
387382// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
388383// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
@@ -392,7 +387,7 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
392387// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
393388// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
394389// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
395- // CHECK-SAME: outs(%[[FILL_CONS ]] : tensor<4x2xf32>)
390+ // CHECK-SAME: outs(%[[ARG5 ]] : tensor<4x2xf32>)
396391// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
397392// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
398393// CHECK: return %[[FOR]] : tensor<4x6xf32>
0 commit comments