@@ -1209,3 +1209,110 @@ func.func @hoist_linalg_ops_div_by_zero(%a : tensor<128x128xi32>,
12091209
12101210 func.return %final : tensor <?x128 xi32 >
12111211}
1212+
// -----
1214+
// Positive case: the loop body reads a 4x4 vector from the tensor argument %a
// at indices (%ida, %idb), both of which are function arguments and therefore
// do not change across iterations. The directives below expect the read to
// appear before the loop and not inside it, while the accumulation (addf) and
// the yield stay in the loop body.
// NOTE(review): presumably exercised by the loop-invariant code motion pass
// named on this file's RUN line, which is not visible here - confirm.
// CHECK-LABEL: func @hoist_vector_transfer_ops
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK-NOT: vector.transfer_read
// CHECK: arith.addf
// CHECK: scf.yield
func.func @hoist_vector_transfer_ops(
                            %a : tensor<128x128xf32>,
                            %lb : index,
                            %ub : index,
                            %step : index,
                            %ida : index,
                            %idb : index) -> vector<4x4xf32> {
  // Padding value for the read and the zero-initialised accumulator.
  %cst_0 = arith.constant 0.0 : f32
  %cst = arith.constant dense<0.0> : vector<4x4xf32>
  %final =
  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
    // No operand of this read depends on %i or %acc.
    %read = vector.transfer_read %a[%ida, %idb], %cst_0 : tensor<128x128xf32>, vector<4x4xf32>
    %out = arith.addf %read, %acc : vector<4x4xf32>
    scf.yield %out : vector<4x4xf32>
  }
  func.return %final : vector<4x4xf32>
}
1238+
// -----
1240+
// Positive case with a write/read pair: inside the loop a constant vector is
// written into a tensor.empty value at (%c0, %c0) and the resulting tensor is
// read back at the same constant indices. Every operand of both transfer ops
// is defined outside the loop, and the directives below expect the write and
// then the read to appear before the loop, with neither left inside it.
// This function reuses the symbol name of the preceding case, so the two must
// stay in separate split-input chunks.
// NOTE(review): presumably the RUN line uses -split-input-file - confirm.
// CHECK-LABEL: func @hoist_vector_transfer_ops
// CHECK: vector.transfer_write
// CHECK: vector.transfer_read
// CHECK: scf.for
// CHECK-NOT: vector.transfer_write
// CHECK-NOT: vector.transfer_read
// CHECK: arith.addf
// CHECK: scf.yield
func.func @hoist_vector_transfer_ops(
                            %lb : index,
                            %ub : index,
                            %step : index,
                            %ida : index,
                            %idb : index) -> vector<4x4xf32> {
  %c0 = arith.constant 0 : index
  // Padding value for the read and the zero-initialised accumulator.
  %cst_0 = arith.constant 0.0 : f32
  %cst = arith.constant dense<0.0> : vector<4x4xf32>
  %empty = tensor.empty() : tensor<4x4xf32>
  %final =
  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
    // The write produces a new tensor value %a; the read consumes that value.
    %a = vector.transfer_write %cst, %empty[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
    %read = vector.transfer_read %a[%c0, %c0], %cst_0 : tensor<4x4xf32>, vector<4x4xf32>
    %out = arith.addf %read, %acc : vector<4x4xf32>
    scf.yield %out : vector<4x4xf32>
  }
  func.return %final : vector<4x4xf32>
}
1268+
// -----
1270+
// Negative case (loop dependence): the second index of the read is the loop
// induction variable %i, so the read uses a value defined by the loop itself.
// The directives below expect no read before the loop and the read to remain
// inside the loop body, ahead of the accumulation and the yield.
// CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_loop_dep
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK: vector.transfer_read
// CHECK: arith.addf
// CHECK: scf.yield
func.func @do_not_hoist_vector_transfer_ops_loop_dep(
                            %a : tensor<128x128xf32>,
                            %lb : index,
                            %ub : index,
                            %step : index,
                            %ida : index) -> vector<4x4xf32> {
  // Padding value for the read and the zero-initialised accumulator.
  %cst_0 = arith.constant 0.0 : f32
  %cst = arith.constant dense<0.0> : vector<4x4xf32>
  %final =
  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
    // Indexed by %i: this is what ties the read to the loop.
    %read = vector.transfer_read %a[%ida, %i], %cst_0 : tensor<128x128xf32>, vector<4x4xf32>
    %out = arith.addf %read, %acc : vector<4x4xf32>
    scf.yield %out : vector<4x4xf32>
  }
  func.return %final : vector<4x4xf32>
}
1293+
// -----
1295+
// Negative case (memref source): the read has loop-invariant indices
// (%ida, %idb), but its source %a is a memref rather than a tensor. The
// directives below expect no read before the loop and the read to remain
// inside the loop body.
// NOTE(review): presumably the read stays put because reading a memref is a
// memory effect rather than a pure value computation - confirm against the
// pass implementation.
// CHECK-LABEL: func @do_not_hoist_vector_transfer_ops_memref
// CHECK-NOT: vector.transfer_read
// CHECK: scf.for
// CHECK: vector.transfer_read
// CHECK: arith.addf
// CHECK: scf.yield
func.func @do_not_hoist_vector_transfer_ops_memref(
                            %a : memref<128x128xf32>,
                            %lb : index,
                            %ub : index,
                            %step : index,
                            %ida : index,
                            %idb : index) -> vector<4x4xf32> {
  // Padding value for the read and the zero-initialised accumulator.
  %cst_0 = arith.constant 0.0 : f32
  %cst = arith.constant dense<0.0> : vector<4x4xf32>
  %final =
  scf.for %i = %lb to %ub step %step iter_args(%acc = %cst) -> vector<4x4xf32> {
    // Same index operands as the tensor-based positive case; only the source
    // type differs.
    %read = vector.transfer_read %a[%ida, %idb], %cst_0 : memref<128x128xf32>, vector<4x4xf32>
    %out = arith.addf %read, %acc : vector<4x4xf32>
    scf.yield %out : vector<4x4xf32>
  }
  func.return %final : vector<4x4xf32>
}
0 commit comments