@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
15201520 ValueRef c = graph.add_tensor (size_big, vkapi::kFloat );
15211521 ValueRef e = graph.add_tensor (size_big, vkapi::kFloat );
15221522
1523+ ValueRef w1_packed = graph.add_tensor (size_small, vkapi::kFloat );
1524+ ValueRef w2_packed = graph.add_tensor (size_small, vkapi::kFloat );
1525+
1526+ auto prepackFn = VK_GET_OP_FN (" et_vk.prepack.default" );
1527+ prepackFn (graph, {w1, w1_packed});
1528+ prepackFn (graph, {w2, w2_packed});
1529+
15231530 auto addFn = VK_GET_OP_FN (" aten.add.Tensor" );
1524- addFn (graph, {a.value , w1 , kDummyValueRef , c});
1531+ addFn (graph, {a.value , w1_packed , kDummyValueRef , c});
15251532
15261533 auto mulFn = VK_GET_OP_FN (" aten.mul.Tensor" );
1527- mulFn (graph, {c, w2 , e});
1534+ mulFn (graph, {c, w2_packed , e});
15281535
15291536 IOValueRef out = {};
15301537 out.value = e;
@@ -2597,24 +2604,16 @@ void test_binary_op(
25972604 std::vector<int64_t > sizes_big,
25982605 std::vector<int64_t > sizes_small,
25992606 vkapi::ScalarType dtype,
2600- utils::GPUMemoryLayout memory_layout,
2601- bool prepack = true ) {
2607+ utils::GPUMemoryLayout memory_layout) {
26022608 GraphConfig config;
26032609 ComputeGraph graph (config);
26042610
26052611 IOValueRef arg2{};
26062612
2607- CREATE_WEIGHT_TENSOR (arg2_w, sizes_small, dtype, 2 .5f );
2608-
26092613 // Build graph
26102614
26112615 IOValueRef arg1 = graph.add_input_tensor (sizes_big, dtype, memory_layout);
2612-
2613- if (prepack) {
2614- arg2.value = arg2_w;
2615- } else {
2616- arg2 = graph.add_input_tensor (sizes_small, dtype, memory_layout);
2617- }
2616+ arg2 = graph.add_input_tensor (sizes_small, dtype, memory_layout);
26182617
26192618 IOValueRef out;
26202619 out.value = graph.add_tensor (sizes_big, dtype, memory_layout);
@@ -2635,7 +2634,7 @@ void test_binary_op(
26352634
26362635 for (int i = 1 ; i < 4 ; i++) {
26372636 float val_arg1 = i + 1.5 ;
2638- float val_arg2 = prepack ? 2 . 5f : i - 3.5 ;
2637+ float val_arg2 = i - 3.5 ;
26392638
26402639 float val_out = val_arg1 + val_arg2;
26412640 if (op_name == " sub" ) {
@@ -2648,21 +2647,14 @@ void test_binary_op(
26482647 val_out = val_arg1 / val_arg2;
26492648 }
26502649
2651- if (prepack) {
2652- execute_graph_and_check_output (graph, {val_arg1}, {val_out});
2653- } else {
2654- execute_graph_and_check_output (graph, {val_arg1, val_arg2}, {val_out});
2655- }
2650+ execute_graph_and_check_output (graph, {val_arg1, val_arg2}, {val_out});
26562651 }
26572652}
26582653
2659- #define CALL_TEST_FN_FORALL_CONDITIONS (_ ) \
2660- _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , false ) \
2661- _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked , false ) \
2662- _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked , false ) \
2663- _(vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , true ) \
2664- _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked , true ) \
2665- _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked , true )
2654+ #define CALL_TEST_FN_FORALL_CONDITIONS (_ ) \
2655+ _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked ) \
2656+ _(vkapi::kFloat , utils::kTexture3D , utils::kHeightPacked ) \
2657+ _(vkapi::kFloat , utils::kTexture3D , utils::kChannelsPacked )
26662658
26672659#define CALL_TEST_FN_FOR_W_PACKED (_ ) \
26682660 _ (vkapi::kFloat , utils::kTexture3D , utils::kWidthPacked , false ) \
@@ -2677,15 +2669,15 @@ void test_binary_op(
26772669 _(vkapi::kFloat , utils::kBuffer , utils::kChannelsPacked , true )
26782670
26792671TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
2680- #define RUN_TESTS (dtype, storage, layout, prepack ) \
2681- test_binary_op (" add" , {17 , 21 }, {17 , 21 }, dtype, layout, prepack ); \
2682- test_binary_op (" add" , {17 , 21 }, {1 , 1 }, dtype, layout, prepack ); \
2683- test_binary_op (" sub" , {11 , 22 }, {11 , 22 }, dtype, layout, prepack ); \
2684- test_binary_op (" sub" , {11 , 22 }, {11 , 1 }, dtype, layout, prepack ); \
2685- test_binary_op (" add" , {7 , 17 , 17 }, {7 , 17 , 17 }, dtype, layout, prepack ); \
2686- test_binary_op (" add" , {7 , 17 , 17 }, {7 , 1 , 17 }, dtype, layout, prepack ); \
2687- test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 9 , 7 }, dtype, layout, prepack ); \
2688- test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 1 , 1 }, dtype, layout, prepack );
2672+ #define RUN_TESTS (dtype, storage, layout ) \
2673+ test_binary_op (" add" , {17 , 21 }, {17 , 21 }, dtype, layout); \
2674+ test_binary_op (" add" , {17 , 21 }, {1 , 1 }, dtype, layout); \
2675+ test_binary_op (" sub" , {11 , 22 }, {11 , 22 }, dtype, layout); \
2676+ test_binary_op (" sub" , {11 , 22 }, {11 , 1 }, dtype, layout); \
2677+ test_binary_op (" add" , {7 , 17 , 17 }, {7 , 17 , 17 }, dtype, layout); \
2678+ test_binary_op (" add" , {7 , 17 , 17 }, {7 , 1 , 17 }, dtype, layout); \
2679+ test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 9 , 7 }, dtype, layout); \
2680+ test_binary_op (" sub" , {9 , 9 , 7 }, {9 , 1 , 1 }, dtype, layout);
26892681
26902682 CALL_TEST_FN_FORALL_CONDITIONS (RUN_TESTS);
26912683
0 commit comments