Skip to content

Commit e38ca0d

Browse files
[ET-VK] 6/n Split dispatches between multiple command buffers. Replaced encode_execute function with invalidate_execute_encoding and moved encoding logic to execute function(). (#13054)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13016 by @trivedivivek ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/128/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/128/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/128/orig @diff-train-skip-merge Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 54b1e6c commit e38ca0d

12 files changed

+15
-67
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -509,13 +509,6 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
509509

510510
compute_graph->prepack();
511511

512-
// If dynamic shapes are not expected, then the command buffer only needs to
513-
// be encoded once. Otherwise, wait until the first inference to encode the
514-
// the command buffer, when actual input shapes are known.
515-
if (!compute_graph->graphconfig().expect_dynamic_shapes) {
516-
compute_graph->encode_execute();
517-
}
518-
519512
return Error::Ok;
520513
}
521514

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -860,21 +860,20 @@ void ComputeGraph::prepack() {
860860
staging_nbytes_in_cmd_ = 0;
861861
}
862862

863-
void ComputeGraph::encode_execute() {
864-
clear_deferred_cmds();
865-
context_->flush();
866-
context_->set_cmd(/*reusable = */ true);
863+
void ComputeGraph::execute() {
864+
if (deferred_cmd_list_.empty()) {
865+
context_->flush();
866+
context_->set_cmd(/*reusable = */ true);
867867

868-
context_->cmd_reset_querypool();
868+
context_->cmd_reset_querypool();
869869

870-
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
871-
node->encode(this);
872-
}
870+
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
871+
node->encode(this);
872+
}
873873

874-
deferred_cmd_list_.emplace_back(std::move(context_->extract_cmd()));
875-
}
874+
deferred_cmd_list_.emplace_back(std::move(context_->extract_cmd()));
875+
}
876876

877-
void ComputeGraph::execute() {
878877
submit_deferred_cmds_and_wait();
879878
execute_count_++;
880879
}
@@ -898,7 +897,7 @@ void ComputeGraph::propagate_resize() {
898897
}
899898
// Only re-encode on resize if dynamic shapes are expected
900899
if (config_.expect_dynamic_shapes) {
901-
encode_execute();
900+
clear_deferred_cmds();
902901
}
903902
}
904903

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,6 @@ class ComputeGraph final {
892892
// Graph Execution
893893
//
894894

895-
void encode_execute();
896895
void execute();
897896

898897
//

backends/vulkan/test/op_tests/choose_qparams_test.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,6 @@ void test_vulkan_choose_qparams_tensor_impl(
458458
graph.prepare();
459459

460460
graph.prepack();
461-
graph.encode_execute();
462461

463462
// Run Vulkan choose_qparams_tensor
464463
graph.copy_into_staging(
@@ -678,7 +677,6 @@ void test_vulkan_choose_qparams_per_token_asymmetric_impl(
678677
graph.prepare();
679678

680679
graph.prepack();
681-
graph.encode_execute();
682680

683681
// Run Vulkan choose_qparams_per_token_asymmetric
684682
graph.copy_into_staging(

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,6 @@ void test_vulkan_dequantize_per_token_impl(
11401140
graph.prepare();
11411141

11421142
graph.prepack();
1143-
graph.encode_execute();
11441143

11451144
// Copy input data to GPU
11461145
graph.copy_into_staging(
@@ -1671,7 +1670,6 @@ void test_vulkan_dequantize_per_channel_impl(
16711670

16721671
graph.prepare();
16731672
graph.prepack();
1674-
graph.encode_execute();
16751673

16761674
// Copy input data to GPU
16771675
graph.copy_into_staging(
@@ -2345,7 +2343,6 @@ void test_vulkan_dequantize_per_tensor_tensor_impl(
23452343

23462344
graph.prepare();
23472345
graph.prepack();
2348-
graph.encode_execute();
23492346

23502347
// Run Vulkan dequantize_per_tensor.tensor
23512348
graph.copy_into_staging(

backends/vulkan/test/op_tests/quantize_affine_test.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,6 @@ void test_vulkan_quantize_affine_impl(
491491

492492
graph.prepare();
493493
graph.prepack();
494-
graph.encode_execute();
495494

496495
// Copy input data to GPU
497496
graph.copy_into_staging(
@@ -789,7 +788,6 @@ void test_vulkan_dequantize_affine_impl(
789788

790789
graph.prepare();
791790
graph.prepack();
792-
graph.encode_execute();
793791

794792
// Copy input data to GPU
795793
graph.copy_into_staging(
@@ -1079,7 +1077,6 @@ void test_vulkan_choose_qparams_affine_impl(
10791077

10801078
graph.prepare();
10811079
graph.prepack();
1082-
graph.encode_execute();
10831080

10841081
// Copy input data to GPU
10851082
graph.copy_into_staging(

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,6 @@ void test_vulkan_quantize_per_token_impl(
931931
graph.prepare();
932932

933933
graph.prepack();
934-
graph.encode_execute();
935934

936935
// Copy input data to GPU
937936
graph.copy_into_staging(
@@ -1413,7 +1412,6 @@ void test_vulkan_quantize_per_channel_impl(
14131412

14141413
graph.prepare();
14151414
graph.prepack();
1416-
graph.encode_execute();
14171415

14181416
// Copy input data to GPU
14191417
graph.copy_into_staging(
@@ -2042,7 +2040,6 @@ void test_vulkan_quantize_per_tensor_tensor_impl(
20422040

20432041
graph.prepare();
20442042
graph.prepack();
2045-
graph.encode_execute();
20462043

20472044
// Run Vulkan quantize_per_tensor.tensor
20482045
graph.copy_into_staging(

backends/vulkan/test/op_tests/quantized_linear_test.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,6 @@ void test_vulkan_linear_qga4w_impl(
456456
graph.prepare();
457457

458458
graph.prepack();
459-
graph.encode_execute();
460459

461460
//
462461
// Run model
@@ -551,7 +550,6 @@ void test_vulkan_linear_qcs4w_impl(
551550
graph.prepare();
552551

553552
graph.prepack();
554-
graph.encode_execute();
555553

556554
//
557555
// Run model
@@ -685,7 +683,6 @@ void test_vulkan_linear_qta8a_qga4w_impl(
685683
graph.prepare();
686684

687685
graph.prepack();
688-
graph.encode_execute();
689686

690687
//
691688
// Run model
@@ -900,4 +897,4 @@ TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemv) {
900897
/*M = */ 1,
901898
/*K = */ 256,
902899
/*N = */ 256);
903-
}
900+
}

backends/vulkan/test/op_tests/rotary_embedding_test.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ void test_reference(
114114
graph.prepare();
115115

116116
graph.prepack();
117-
graph.encode_execute();
118117

119118
//
120119
// Run model

backends/vulkan/test/op_tests/sdpa_test.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,6 @@ void test_vulkan_sdpa(
352352
graph.prepare();
353353

354354
graph.prepack();
355-
graph.encode_execute();
356355

357356
//
358357
// Run model
@@ -586,7 +585,6 @@ void test_vulkan_flash_attention(
586585
graph.prepare();
587586
graph.encode_prepack();
588587
graph.prepack();
589-
graph.encode_execute();
590588

591589
// Copy inputs and run
592590
graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel());
@@ -845,7 +843,6 @@ void test_reference_flash_attention(
845843
graph.prepare();
846844
graph.encode_prepack();
847845
graph.prepack();
848-
graph.encode_execute();
849846

850847
graph.copy_into_staging(r_q.staging, q.const_data_ptr(), q.numel());
851848
graph.copy_into_staging(r_k.staging, k.const_data_ptr(), k.numel());

0 commit comments

Comments
 (0)