@@ -312,8 +312,8 @@ func.func @hoist_pack_unpack_multiple_loop(%arg0 : tensor<1x1x4x2x16x16xbf16>, %
312312
313313// -----
314314
315- func.func @propagate_extract_basic (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
316- %empty = tensor.empty () : tensor <128 xf32 >
315+ func.func @propagate_extract_basic (%input : tensor < 128 x f32 >, % arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
316+ %empty = util.optimization_barrier %input : tensor <128 xf32 >
317317 %extracted_slice = tensor.extract_slice %empty [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
318318 %generic = linalg.generic {index ing_maps = [affine_map <(d0 ) -> (d0 )>, affine_map <(d0 ) -> (d0 )>], iterator_types = [" parallel" ]} ins (%extracted_slice : tensor <?xf32 >) outs (%arg1 : tensor <?xbf16 >) {
319319 ^bb0 (%in: f32 , %out: bf16 ):
@@ -346,8 +346,8 @@ func.func @no_propagate_extract_blockargument(%input : tensor<128xf32>, %arg0 :
346346
347347// -----
348348
349- func.func @no_propagate_extract_differentblock_1 (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
350- %empty = tensor.empty () : tensor <128 xf32 >
349+ func.func @no_propagate_extract_differentblock_1 (%input : tensor < 128 x f32 >, % arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
350+ %empty = util.optimization_barrier %input : tensor <128 xf32 >
351351 %c0 = arith.constant 0 : index
352352 %c32 = arith.constant 32 : index
353353 %for = scf.for %arg2 = %c0 to %c32 step %arg0 iter_args (%arg4 = %arg1 ) -> tensor <?xbf16 > {
@@ -368,8 +368,8 @@ func.func @no_propagate_extract_differentblock_1(%arg0 : index, %arg1 : tensor<?
368368
369369// -----
370370
371- func.func @no_propagate_extract_differentblock_2 (%arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
372- %empty = tensor.empty () : tensor <128 xf32 >
371+ func.func @no_propagate_extract_differentblock_2 (%input : tensor < 128 x f32 >, % arg0 : index , %arg1 : tensor <?xbf16 >) -> tensor <?xbf16 > {
372+ %empty = util.optimization_barrier %input : tensor <128 xf32 >
373373 %c0 = arith.constant 0 : index
374374 %c32 = arith.constant 32 : index
375375 %extracted_slice = tensor.extract_slice %empty [%arg0 ] [%arg0 ] [1 ] : tensor <128 xf32 > to tensor <?xf32 >
0 commit comments