@@ -1130,3 +1130,42 @@ module {
1130
1130
return %1 : tensor <?x1 x61 x1 xf32 >
1131
1131
}
1132
1132
}
1133
+
1134
+ // -----
1135
+
1136
+ func.func @no_fold_empty_tensor_dim_out_of_bounds (%arg0: tensor <1 x?x10 xf32 >) -> tensor <1 x?xf32 > {
1137
+ %cst = arith.constant 1.000000e+00 : f32
1138
+ %cst7 = arith.constant 7 : index
1139
+ %dim = tensor.dim %arg0 , %cst7 : tensor <1 x?x10 xf32 >
1140
+ %0 = tensor.empty (%dim ) : tensor <1 x?xf32 >
1141
+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <1 x?xf32 >) -> tensor <1 x?xf32 >
1142
+ return %1 : tensor <1 x?xf32 >
1143
+ }
1144
+ // CHECK-LABEL: func.func @no_fold_empty_tensor_dim_out_of_bounds
1145
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1146
+ // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1147
+ // CHECK: %[[C7:.*]] = arith.constant 7
1148
+ // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C7]] : tensor<1x?x10xf32>
1149
+ // CHECK: %[[VAL_0:.*]] = tensor.empty(%[[DIM]]) : tensor<1x?xf32>
1150
+ // CHECK: %[[VAL_1:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_0]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1151
+ // CHECK: return %[[VAL_1]] : tensor<1x?xf32>
1152
+ // CHECK: }
1153
+
1154
+ // -----
1155
+
1156
+ func.func @fold_empty_tensor_dim_op (%arg0: tensor <1 x?x10 xf32 >) -> tensor <1 x?xf32 > {
1157
+ %cst = arith.constant 1.000000e+00 : f32
1158
+ %cst2 = index.constant 2
1159
+ %dim10 = tensor.dim %arg0 , %cst2 : tensor <1 x?x10 xf32 >
1160
+ %0 = tensor.empty (%dim10 ) : tensor <1 x?xf32 >
1161
+ %1 = linalg.fill ins (%cst : f32 ) outs (%0 : tensor <1 x?xf32 >) -> tensor <1 x?xf32 >
1162
+ return %1 : tensor <1 x?xf32 >
1163
+ }
1164
+ // CHECK-LABEL: func.func @fold_empty_tensor_dim_op
1165
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x10xf32>) -> tensor<1x?xf32> {
1166
+ // CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
1167
+ // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x10xf32>
1168
+ // CHECK: %[[VAL_1:.*]] = tensor.cast %[[VAL_0]] : tensor<1x10xf32> to tensor<1x?xf32>
1169
+ // CHECK: %[[VAL_2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAL_1]] : tensor<1x?xf32>) -> tensor<1x?xf32>
1170
+ // CHECK: return %[[VAL_2]] : tensor<1x?xf32>
1171
+ // CHECK: }
0 commit comments