Skip to content

Commit 79c7e73

Browse files
author
ssjia
committed
[ET-VK][ez] Accept sample_kwargs as an argument in several test util functions
Title says it all! This makes it possible to export models that require kwargs to be defined instead of args. Differential Revision: [D84716455](https://our.internmc.facebook.com/intern/diff/D84716455/) ghstack-source-id: 316415596 Pull Request resolved: #15153
1 parent f1e2548 commit 79c7e73

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

backends/vulkan/test/utils.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,16 @@ class QuantizationMode(Enum):
5050
def get_exported_graph(
5151
model,
5252
sample_inputs,
53+
sample_kwargs=None,
5354
dynamic_shapes=None,
5455
qmode=QuantizationMode.NONE,
5556
) -> torch.fx.GraphModule:
5657
export_training_graph = export(
57-
model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
58+
model,
59+
sample_inputs,
60+
kwargs=sample_kwargs,
61+
dynamic_shapes=dynamic_shapes,
62+
strict=True,
5863
).module()
5964

6065
if qmode == QuantizationMode.NONE:
@@ -82,6 +87,7 @@ def random_uniform_tensor(shape, low=0.0, high=1.0, device=None, dtype=None):
8287
def export_model_to_vulkan(
8388
model,
8489
sample_inputs,
90+
sample_kwargs=None,
8591
dynamic_shapes=None,
8692
operator_blocklist=None,
8793
operator_allowlist=None,
@@ -91,11 +97,16 @@ def export_model_to_vulkan(
9197
):
9298
compile_options = {}
9399
exported_graph = get_exported_graph(
94-
model, sample_inputs, dynamic_shapes=dynamic_shapes, qmode=qmode
100+
model,
101+
sample_inputs,
102+
sample_kwargs=sample_kwargs,
103+
dynamic_shapes=dynamic_shapes,
104+
qmode=qmode,
95105
)
96106
program = export(
97107
exported_graph,
98108
sample_inputs,
109+
kwargs=sample_kwargs,
99110
dynamic_shapes=dynamic_shapes,
100111
strict=True,
101112
)
@@ -422,6 +433,7 @@ def save_bundled_program(
422433
sample_inputs: Tuple[torch.Tensor],
423434
output_path: str,
424435
method_name: str = "forward",
436+
sample_kwargs=None,
425437
et_program: Optional[ExecutorchProgramManager] = None,
426438
dynamic_shapes=None,
427439
) -> str:
@@ -441,13 +453,21 @@ def save_bundled_program(
441453
"""
442454
# If no ExecutorchProgramManager provided, export to Vulkan
443455
if et_program is None:
444-
et_program = export_model_to_vulkan(model, sample_inputs, dynamic_shapes)
456+
et_program = export_model_to_vulkan(
457+
model,
458+
sample_inputs,
459+
sample_kwargs=sample_kwargs,
460+
dynamic_shapes=dynamic_shapes,
461+
)
462+
463+
if sample_kwargs is None:
464+
sample_kwargs = {}
445465

446466
# Generate expected outputs by running the model
447-
expected_outputs = [getattr(model, method_name)(*sample_inputs)]
467+
expected_outputs = [getattr(model, method_name)(*sample_inputs, **sample_kwargs)]
448468

449469
# Flatten sample inputs to match expected format
450-
inputs_flattened, _ = tree_flatten(sample_inputs)
470+
inputs_flattened, _ = tree_flatten((sample_inputs, sample_kwargs))
451471

452472
# Create test suite with the sample inputs and expected outputs
453473
test_suites = [

0 commit comments

Comments
 (0)