@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
15201520 ValueRef c = graph.add_tensor (size_big, vkapi::kFloat );
15211521 ValueRef e = graph.add_tensor (size_big, vkapi::kFloat );
15221522
1523+ ValueRef w1_packed = graph.add_tensor (size_small, vkapi::kFloat );
1524+ ValueRef w2_packed = graph.add_tensor (size_small, vkapi::kFloat );
1525+
1526+ auto prepackFn = VK_GET_OP_FN (" et_vk.prepack.default" );
1527+ prepackFn (graph, {w1, w1_packed});
1528+ prepackFn (graph, {w2, w2_packed});
1529+
15231530 auto addFn = VK_GET_OP_FN (" aten.add.Tensor" );
1524- addFn (graph, {a.value , w1 , kDummyValueRef , c});
1531+ addFn (graph, {a.value , w1_packed , kDummyValueRef , c});
15251532
15261533 auto mulFn = VK_GET_OP_FN (" aten.mul.Tensor" );
1527- mulFn (graph, {c, w2 , e});
1534+ mulFn (graph, {c, w2_packed , e});
15281535
15291536 IOValueRef out = {};
15301537 out.value = e;
@@ -2597,8 +2604,7 @@ void test_binary_op(
25972604 std::vector<int64_t > sizes_big,
25982605 std::vector<int64_t > sizes_small,
25992606 vkapi::ScalarType dtype,
2600- utils::GPUMemoryLayout memory_layout,
2601- bool prepack = true ) {
2607+ utils::GPUMemoryLayout memory_layout) {
26022608 GraphConfig config;
26032609 ComputeGraph graph (config);
26042610
@@ -2609,12 +2615,7 @@ void test_binary_op(
26092615 // Build graph
26102616
26112617 IOValueRef arg1 = graph.add_input_tensor (sizes_big, dtype, memory_layout);
2612-
2613- if (prepack) {
2614- arg2.value = arg2_w;
2615- } else {
2616- arg2 = graph.add_input_tensor (sizes_small, dtype, memory_layout);
2617- }
2618+ arg2 = graph.add_input_tensor (sizes_small, dtype, memory_layout);
26182619
26192620 IOValueRef out;
26202621 out.value = graph.add_tensor (sizes_big, dtype, memory_layout);
@@ -2635,7 +2636,7 @@ void test_binary_op(
26352636
26362637 for (int i = 1 ; i < 4 ; i++) {
26372638 float val_arg1 = i + 1.5 ;
2638- float val_arg2 = prepack ? 2 . 5f : i - 3.5 ;
2639+ float val_arg2 = i - 3.5 ;
26392640
26402641 float val_out = val_arg1 + val_arg2;
26412642 if (op_name == " sub" ) {
@@ -2648,21 +2649,14 @@ void test_binary_op(
26482649 val_out = val_arg1 / val_arg2;
26492650 }
26502651
2651- if (prepack) {
2652- execute_graph_and_check_output (graph, {val_arg1}, {val_out});
2653- } else {
2654- execute_graph_and_check_output (graph, {val_arg1, val_arg2}, {val_out});
2655- }
2652+ execute_graph_and_check_output (graph, {val_arg1, val_arg2}, {val_out});
26562653 }
26572654}
26582655
2659- #define CALL_TEST_FN_FORALL_CONDITIONS (_ ) \
2660- _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , false ) \
2661- _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked , false ) \
2662- _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked , false ) \
2663- _(vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , true ) \
2664- _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked , true ) \
2665- _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked , true )
2656+ #define CALL_TEST_FN_FORALL_CONDITIONS (_ ) \
2657+ _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked ) \
2658+ _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked ) \
2659+ _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked )
26662660
26672661#define CALL_TEST_FN_FOR_W_PACKED (_ ) \
26682662 _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , false ) \
@@ -2677,15 +2671,15 @@ void test_binary_op(
26772671 _(vkapi::kFloat , utils::kBuffer , utils::kChannelsPacked , true )
26782672
26792673TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
2680- #define RUN_TESTS (dtype, storage, layout, prepack ) \
2681- test_binary_op (" add" , {17 , 21 }, {17 , 21 }, dtype, layout, prepack ); \
2682- test_binary_op (" add" , {17 , 21 }, {1 , 1 }, dtype, layout, prepack ); \
2683- test_binary_op (" sub" , {11 , 22 }, {11 , 22 }, dtype, layout, prepack ); \
2684- test_binary_op (" sub" , {11 , 22 }, {11 , 1 }, dtype, layout, prepack ); \
2685- test_binary_op (" add" , {7 , 17 , 17 }, {7 , 17 , 17 }, dtype, layout, prepack ); \
2686- test_binary_op (" add" , {7 , 17 , 17 }, {7 , 1 , 17 }, dtype, layout, prepack ); \
2687- test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 9 , 7 }, dtype, layout, prepack ); \
2688- test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 1 , 1 }, dtype, layout, prepack );
2674+ #define RUN_TESTS (dtype, storage, layout ) \
2675+ test_binary_op (" add" , {17 , 21 }, {17 , 21 }, dtype, layout); \
2676+ test_binary_op (" add" , {17 , 21 }, {1 , 1 }, dtype, layout); \
2677+ test_binary_op (" sub" , {11 , 22 }, {11 , 22 }, dtype, layout); \
2678+ test_binary_op (" sub" , {11 , 22 }, {11 , 1 }, dtype, layout); \
2679+ test_binary_op (" add" , {7 , 17 , 17 }, {7 , 17 , 17 }, dtype, layout); \
2680+ test_binary_op (" add" , {7 , 17 , 17 }, {7 , 1 , 17 }, dtype, layout); \
2681+ test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 9 , 7 }, dtype, layout); \
2682+ test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 1 , 1 }, dtype, layout);
26892683
26902684 CALL_TEST_FN_FORALL_CONDITIONS (RUN_TESTS);
26912685
0 commit comments