@@ -321,47 +321,78 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens

// -----

-func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
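+// Batch-matmul producer feeding a tiled matmul consumer through a
+// rank-reducing tensor.extract_slice (tensor<1x6x6xf32> to tensor<6x2xf32>).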
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
  %cst = arith.constant 0.0 : f32
-  %cst1 = arith.constant 1.0 : f32
-
-  %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-  %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
-  ^bb0(%out: f32):
-    linalg.yield %cst : f32
-  } -> tensor<6x6x1x1x1x1xf32>
-
-  %if = scf.if %cond -> tensor<6x2xf32> {
-    %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-
-    %init2 = tensor.empty() : tensor<6x2xf32>
-    %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %add = arith.addf %in, %cst1 : f32
-      linalg.yield %add : f32
-    } -> tensor<6x2xf32>
-    scf.yield %add1 : tensor<6x2xf32>
-  } else {
-    %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-    scf.yield %extract2 : tensor<6x2xf32>
+  %init1 = tensor.empty() : tensor<1x6x6xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+  %0 = linalg.generic
+      {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic
+        {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
  }
-
-  return %if : tensor<6x2xf32>
+  return %1 : tensor<4x6xf32>
}

// CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
-
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>

-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
-// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[FILL_CONS:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
+// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
+
+// The scf.for comes right after the tensor allocations and fills; no
+// linalg.generic remains before the loop.
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+
+// Producer linalg.generic now inside the loop, with tiled args sliced before
+// it.
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+//
+// The consumer uses a rank-reduced version of the producer result, so a
+// collapse_shape is generated.
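+// The reassociation [[0, 1], [2]] folds the unit batch dimension into the row
+// dimension, turning tensor<1x6x2xf32> into tensor<6x2xf32>.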
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>