@@ -309,3 +309,81 @@ func.func @hoist_pack_unpack_multiple_loop(%arg0 : tensor<1x1x4x2x16x16xbf16>, %
309309// CHECK: scf.yield %[[INNER_FOR_RESULT]] : tensor<1x1x4x2x16x16xf32>
310310// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x64x32xf32>
311311// CHECK: linalg.unpack %[[OUTER_FOR_RESULT]] inner_dims_pos = [2, 3] inner_tiles = [16, 16] into %[[EMPTY:.+]] : tensor<1x1x4x2x16x16xf32> -> tensor<1x1x64x32xf32>
312+
313+ // -----
314+
315+ func.func @propagate_extract_basic (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
316+ %empty = tensor.empty () : tensor <128 xf32 >
317+ %extracted_slice = tensor.extract_slice %empty [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
318+ %generic = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>, affine_map <(d0 ) -> (d0 )>], iterator_types = [" parallel" ]} ins (%extracted_slice : tensor <?xf32 >) outs (%arg1 : tensor <?xbf16 >) {
319+ ^bb0 (%in: f32 , %out: bf16 ):
320+ %1 = arith.truncf %in : f32 to bf16
321+ linalg.yield %1 : bf16
322+ } -> tensor <?xbf16 >
323+ return %generic : tensor <?xbf16 >
324+ }
325+
326+ // CHECK-LABEL: func.func @propagate_extract_basic
327+ // CHECK: linalg.generic
328+ // CHECK: tensor.extract_slice
329+
330+ // -----
331+
332+ func.func @no_propagate_extract_blockargument (%input : tensor <128 xf32 >, %arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
333+ %extracted_slice = tensor.extract_slice %input [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
334+ %generic = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>, affine_map <(d0 ) -> (d0 )>], iterator_types = [" parallel" ]} ins (%extracted_slice : tensor <?xf32 >) outs (%arg1 : tensor <?xbf16 >) {
335+ ^bb0 (%in: f32 , %out: bf16 ):
336+ %1 = arith.truncf %in : f32 to bf16
337+ linalg.yield %1 : bf16
338+ } -> tensor <?xbf16 >
339+ return %generic : tensor <?xbf16 >
340+ }
341+
342+ // CHECK-LABEL: func.func @no_propagate_extract_blockargument
343+ // CHECK: tensor.extract_slice
344+ // CHECK: linalg.generic
345+
346+
347+ // -----
348+
349+ func.func @no_propagate_extract_differentblock_1 (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
350+ %empty = tensor.empty () : tensor <128 xf32 >
351+ %c0 = arith.constant 0 : index
352+ %c32 = arith.constant 32 : index
353+ %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args (%arg4 = %arg1 ) -> tensor <?xbf16 > {
354+ %extracted_slice = tensor.extract_slice %empty [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
355+ %generic = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>, affine_map <(d0 ) -> (d0 )>], iterator_types = [" parallel" ]} ins (%extracted_slice : tensor <?xf32 >) outs (%arg4 : tensor <?xbf16 >) {
356+ ^bb0 (%in: f32 , %out: bf16 ):
357+ %1 = arith.truncf %in : f32 to bf16
358+ linalg.yield %1 : bf16
359+ } -> tensor <?xbf16 >
360+ scf.yield %generic : tensor <?xbf16 >
361+ }
362+ return %for : tensor <?xbf16 >
363+ }
364+
365+ // CHECK-LABEL: func.func @no_propagate_extract_differentblock_1
366+ // CHECK: tensor.extract_slice
367+ // CHECK: linalg.generic
368+
369+ // -----
370+
371+ func.func @no_propagate_extract_differentblock_2 (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
372+ %empty = tensor.empty () : tensor <128 xf32 >
373+ %c0 = arith.constant 0 : index
374+ %c32 = arith.constant 32 : index
375+ %extracted_slice = tensor.extract_slice %empty [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
376+ %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args (%arg4 = %arg1 ) -> tensor <?xbf16 > {
377+ %generic = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>, affine_map <(d0 ) -> (d0 )>], iterator_types = [" parallel" ]} ins (%extracted_slice : tensor <?xf32 >) outs (%arg4 : tensor <?xbf16 >) {
378+ ^bb0 (%in: f32 , %out: bf16 ):
379+ %1 = arith.truncf %in : f32 to bf16
380+ linalg.yield %1 : bf16
381+ } -> tensor <?xbf16 >
382+ scf.yield %generic : tensor <?xbf16 >
383+ }
384+ return %for : tensor <?xbf16 >
385+ }
386+
387+ // CHECK-LABEL: func.func @no_propagate_extract_differentblock_2
388+ // CHECK: tensor.extract_slice
389+ // CHECK: linalg.generic
0 commit comments