From 0239e6e4e4db6ea2984f0a12fcfc68f105f7bd87 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 15 Oct 2025 09:06:44 -0700 Subject: [PATCH] [ET-VK][ez] Accept `sample_kwargs` as an argument in several test util functions Title says it all! This makes it possible to export models that require kwargs to be defined instead of args. Differential Revision: [D84716455](https://our.internmc.facebook.com/intern/diff/D84716455/) [ghstack-poisoned] --- backends/vulkan/test/utils.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/test/utils.py b/backends/vulkan/test/utils.py index a887c53473a..90edc094ec7 100644 --- a/backends/vulkan/test/utils.py +++ b/backends/vulkan/test/utils.py @@ -50,11 +50,16 @@ class QuantizationMode(Enum): def get_exported_graph( model, sample_inputs, + sample_kwargs=None, dynamic_shapes=None, qmode=QuantizationMode.NONE, ) -> torch.fx.GraphModule: export_training_graph = export( - model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True + model, + sample_inputs, + kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + strict=True, ).module() if qmode == QuantizationMode.NONE: @@ -82,6 +87,7 @@ def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None): def export_model_to_vulkan( model, sample_inputs, + sample_kwargs=None, dynamic_shapes=None, operator_blocklist=None, operator_allowlist=None, @@ -91,11 +97,16 @@ def export_model_to_vulkan( ): compile_options = {} exported_graph = get_exported_graph( - model, sample_inputs, dynamic_shapes=dynamic_shapes, qmode=qmode + model, + sample_inputs, + sample_kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + qmode=qmode, ) program = export( exported_graph, sample_inputs, + kwargs=sample_kwargs, dynamic_shapes=dynamic_shapes, strict=True, ) @@ -422,6 +433,7 @@ def save_bundled_program( sample_inputs: Tuple[torch.Tensor], output_path: str, method_name: str = "forward", + sample_kwargs=None, et_program: Optional[ExecutorchProgramManager] = None, dynamic_shapes=None, ) -> str: @@ -441,13 +453,21 @@ def save_bundled_program( """ # If no ExecutorchProgramManager provided, export to Vulkan if et_program is None: - et_program = export_model_to_vulkan(model, sample_inputs, dynamic_shapes) + et_program = export_model_to_vulkan( + model, + sample_inputs, + sample_kwargs=sample_kwargs, + dynamic_shapes=dynamic_shapes, + ) + + if sample_kwargs is None: + sample_kwargs = {} # Generate expected outputs by running the model - expected_outputs = [getattr(model, method_name)(*sample_inputs)] + expected_outputs = [getattr(model, method_name)(*sample_inputs, **sample_kwargs)] # Flatten sample inputs to match expected format - inputs_flattened, _ = tree_flatten(sample_inputs) + inputs_flattened, _ = tree_flatten((sample_inputs, sample_kwargs)) # Create test suite with the sample inputs and expected outputs test_suites = [