@@ -1130,3 +1130,42 @@ module {
11301130 return %1 : tensor <?x1 x61 x1 xf32 >
11311131 }
11321132}
1133+
1134+ // -----
1135+
// Negative test: the size operand of tensor.empty comes from a tensor.dim
// whose index operand is the constant 7, while %arg0 has the rank-3 type
// tensor<1x?x10xf32> (valid dimension indices are 0..2).
// The FileCheck expectations that follow this function require tensor.dim,
// the dynamically sized tensor.empty and linalg.fill to appear unchanged in
// the output, i.e. no static size is substituted for the dynamic dimension.
// NOTE(review): the pass under test is set by the file's RUN line, which is
// not visible in this chunk - confirm there.
1136+ func.func @no_fold_empty_tensor_dim_out_of_bounds (%arg0: tensor <1 x?x10 xf32 >) -> tensor <1 x?xf32 > {
1137+ %cst = arith.constant 1.000000e+00 : f32
// Deliberately past the last dimension of the rank-3 argument.
1138+ %cst7 = arith.constant 7 : index
1139+ %dim = tensor.dim %arg0 , %cst7 : tensor <1 x?x10 xf32 >
// %dim supplies the single dynamic extent of the result type.
1140+ %0 = tensor.empty (%dim ) : tensor <1 x?xf32 >
1141+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <1 x?xf32 >) -> tensor <1 x?xf32 >
1142+ return %1 : tensor <1 x?xf32 >
1143+ }
1144+ // CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
1145+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1146+ // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1147+ // CHECK: %[[C7:.*]] = arith.constant 7
1148+ // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
1149+ // CHECK: %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
1150+ // CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1151+ // CHECK: return %[[VAL_1]] : tensor<1x?xf32>
1152+ // CHECK: }
1153+
1154+ // -----
1155+
// Positive test: the size operand of tensor.empty comes from a tensor.dim
// that queries dimension 2 of %arg0, and that dimension is the static extent
// 10 in the argument type tensor<1x?x10xf32>.
// The dimension index is produced by index.constant (index dialect) instead
// of arith.constant.
// The FileCheck expectations that follow this function require a statically
// shaped tensor.empty() : tensor<1x10xf32>, a tensor.cast back to
// tensor<1x?xf32>, and linalg.fill consuming the cast result.
// NOTE(review): the pass under test is set by the file's RUN line, which is
// not visible in this chunk - confirm there.
1156+ func.func @fold_empty_tensor_dim_op (%arg0: tensor <1 x?x10 xf32 >) -> tensor <1 x?xf32 > {
1157+ %cst = arith.constant 1.000000e+00 : f32
// In-bounds index selecting the trailing static dimension (extent 10).
1158+ %cst2 = index.constant 2
1159+ %dim10 = tensor.dim %arg0 , %cst2 : tensor <1 x?x10 xf32 >
// The result type is still declared dynamic here; the expected output makes it static.
1160+ %0 = tensor.empty (%dim10 ) : tensor <1 x?xf32 >
1161+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <1 x?xf32 >) -> tensor <1 x?xf32 >
1162+ return %1 : tensor <1 x?xf32 >
1163+ }
1164+ // CHECK-LABEL: func.func @fold_empty_tensor_dim_op
1165+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1166+ // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1167+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
1168+ // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
1169+ // CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1170+ // CHECK: return %[[VAL_2]] : tensor<1x?xf32>
1171+ // CHECK: }
0 commit comments