@@ -328,66 +328,71 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 #map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @rank_reduced_extract_slice(
-    %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
-    %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+    %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+    %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+    %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
 ) -> tensor<4x6xf32> {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %c6 = arith.constant 6 : index
-  %0 = linalg.generic
+  %mmul_prod = linalg.generic
     {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
+    ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
   ^bb0(%in: f32, %in_1: f32, %out: f32):
     %10 = arith.mulf %in, %in_1 : f32
     %11 = arith.addf %out, %10 : f32
     linalg.yield %11 : f32
   } -> tensor<1x6x6xf32>
-  %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
-    %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
-    %3 = linalg.generic
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract slice with a rank-reduced result type. When fused into the
+    // loop with sliced operands, the producer linalg.generic's now-sliced
+    // result must be rank-reduced as well to match the consumer's use type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %mmul_cons = linalg.generic
       {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
-      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
+      ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
       %20 = arith.mulf %in, %in_1 : f32
       %21 = arith.addf %out, %20 : f32
       linalg.yield %21 : f32
     } -> tensor<4x2xf32>
-    %4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
     scf.yield %4 : tensor<4x6xf32>
   }
-  return %1 : tensor<4x6xf32>
+  return %for : tensor<4x6xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
-// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
+// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
 
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
 
 // For loop right after tensor alloc & fill, no linalg.generic.
 // CHECK-NOT: linalg.generic
-// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
 
 // Producer linalg.generic now inside the loop, with tiled args sliced before
 // it.
-// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
-// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
 // CHECK: %[[MMUL_PROD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
-// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
 //
 // Consumer uses a rank-reduced version of producer result so a collapse_shape
 // is generated.
 // CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
 // CHECK: %[[MMUL_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
+// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
 // CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
 // CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
 // CHECK: return %[[FOR]] : tensor<4x6xf32>
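
Note (not part of the diff): a minimal sketch of the two slicing forms the comments above refer to, using hypothetical values %prod : tensor<1x6x6xf32> and an index %i. The rank-reducing tensor.extract_slice drops the unit dimension directly in its result type, while the non-rank-reducing form keeps it and needs the tensor.collapse_shape that the fusion emits in order to reach the same tensor<6x2xf32> type the consumer expects.

// Rank-reducing slice: the leading unit dim is folded away in the result type.
%rr = tensor.extract_slice %prod[0, 0, %i] [1, 6, 2] [1, 1, 1]
    : tensor<1x6x6xf32> to tensor<6x2xf32>

// Non-rank-reducing slice plus an explicit collapse; the reassociation
// [[0, 1], [2]] merges the unit dim 0 into dim 1 (1x6 -> 6), keeping dim 2.
%full = tensor.extract_slice %prod[0, 0, %i] [1, 6, 2] [1, 1, 1]
    : tensor<1x6x6xf32> to tensor<1x6x2xf32>
%rr2 = tensor.collapse_shape %full [[0, 1], [2]]
    : tensor<1x6x2xf32> into tensor<6x2xf32>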