|  | 
| 30 | 30 | #include <type_traits> | 
| 31 | 31 | #include <vector> | 
| 32 | 32 | 
 | 
|  | 33 | +#include <iostream> | 
|  | 34 | + | 
| 33 | 35 | namespace executorch { | 
| 34 | 36 | namespace backends { | 
| 35 | 37 | namespace vulkan { | 
| @@ -140,6 +142,14 @@ GraphConfig get_graph_config(ArrayRef<CompileSpec>& compile_specs) { | 
| 140 | 142 | 
 | 
| 141 | 143 |       config.set_memory_layout_override(memory_layout); | 
| 142 | 144 |     } | 
|  | 145 | +    if (strcmp(spec.key, "require_dynamic_shapes") == 0) { | 
|  | 146 | +      ET_CHECK_MSG(value_size == sizeof(uint8_t), "Unexpected value size!"); | 
|  | 147 | +      bool value = getBool(value_data); | 
|  | 148 | + | 
|  | 149 | +      if (value) { | 
|  | 150 | +        config.expect_dynamic_shapes = true; | 
|  | 151 | +      } | 
|  | 152 | +    } | 
| 143 | 153 |   } | 
| 144 | 154 | #ifdef ET_EVENT_TRACER_ENABLED | 
| 145 | 155 |   config.enable_querypool = true; | 
| @@ -500,9 +510,12 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { | 
| 500 | 510 |     compute_graph->encode_prepack(); | 
| 501 | 511 |     compute_graph->prepack(); | 
| 502 | 512 | 
 | 
| 503 |  | -    // TODO(ssjia): remove this once we can batch compile compute pipelines | 
| 504 |  | -    // during prepare(). | 
| 505 |  | -    compute_graph->encode_execute(); | 
|  | 513 | +    // If dynamic shapes are not expected, then the command buffer only needs to | 
|  | 514 | +    // be encoded once. Otherwise, wait until the first inference to encode the | 
|  | 515 | +    // the command buffer, when actual input shapes are known. | 
|  | 516 | +    if (!compute_graph->graphconfig().expect_dynamic_shapes) { | 
|  | 517 | +      compute_graph->encode_execute(); | 
|  | 518 | +    } | 
| 506 | 519 | 
 | 
| 507 | 520 |     return Error::Ok; | 
| 508 | 521 |   } | 
| @@ -574,7 +587,9 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface { | 
| 574 | 587 |     // constants are updated and DynamicDispatchNode can update the compute | 
| 575 | 588 |     // shader, global workgroup size, and local workgroup size to perform the | 
| 576 | 589 |     // model inference. | 
| 577 |  | -    if (should_propagate_resize) { | 
|  | 590 | +    if (should_propagate_resize || | 
|  | 591 | +        (compute_graph->graphconfig().expect_dynamic_shapes && | 
|  | 592 | +         compute_graph->execute_count() == 0u)) { | 
| 578 | 593 |       compute_graph->propagate_resize(); | 
| 579 | 594 |     } | 
| 580 | 595 | 
 | 
|  | 
0 commit comments