Commit 89e2c5d

SS-JIA and ssjia authored
[ET-VK] Add some utility compile options + improve export script (pytorch#15795)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* pytorch#15829
* pytorch#15796
* __->__ pytorch#15795
* pytorch#15794
* pytorch#15793

Title says it all! Add two additional export options:

1. `skip_memory_planning` - skips the memory planning pass, which can be useful for debugging.
2. `small_texture_limits` - sets the default texture limit to (2048, 2048, 2048), which is compatible with more devices (e.g. desktop/laptop GPUs) than the default (16384, 16384, 2048), which targets mobile GPUs.

Also adds some improvements to the export script that were made while debugging the `YOLO_NAS` model (pytorch#15700).

Differential Revision: [D86910640](https://our.internmc.facebook.com/intern/diff/D86910640/)

---------

Co-authored-by: ssjia <[email protected]>
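For readers who want to try the new options, the sketch below shows one hypothetical way to wire them into a lowering flow. It assumes `VulkanPartitioner` accepts a plain dict of compile options that `parse_compile_spec` later reads in `vulkan_preprocess.py`; the actual export script in this stack may plumb the options differently, and `TinyModel` is invented for illustration.

```python
# Hypothetical usage sketch; the real export script may differ.
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower
from torch.export import export


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1.0


compile_options = {
    # Skip the memory planning pass entirely (useful when debugging lowering issues).
    "skip_memory_planning": True,
    # Clamp the default texture limits to (2048, 2048, 2048) for desktop/laptop GPUs.
    "small_texture_limits": True,
}

exported = export(TinyModel(), (torch.randn(1, 16),))
edge = to_edge_transform_and_lower(
    exported, partitioner=[VulkanPartitioner(compile_options)]
)
executorch_program = edge.to_executorch()
```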
1 parent 4415bc6 commit 89e2c5d

3 files changed: +192 −70


backends/vulkan/test/utils.py

Lines changed: 42 additions & 15 deletions
@@ -8,18 +8,14 @@
 import logging
 from collections import OrderedDict
 from copy import deepcopy
-
 from enum import auto, Enum
 from typing import Any, List, Optional, Tuple

 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
@@ -36,7 +32,6 @@
 )
 from executorch.extension.pytree import tree_flatten
 from torch.export import export
-
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -275,16 +270,25 @@ def check_outputs_equal(
             )
             return result
         else:
+            result = True
             for i in range(len(ref_output)):
-                if not torch.allclose(
-                    model_output[i], ref_output[i], atol=atol, rtol=rtol
-                ):
-                    print(f"\n=== Output {i} comparison failed ===")
-                    print_tensor_comparison_errors(
-                        model_output[i], ref_output[i], atol, rtol
-                    )
-                    return False
-            return True
+                if isinstance(ref_output[i], torch.Tensor):
+                    if not torch.allclose(
+                        model_output[i], ref_output[i], atol=atol, rtol=rtol
+                    ):
+                        print(f"\n=== Output {i} comparison failed ===")
+                        print_tensor_comparison_errors(
+                            model_output[i], ref_output[i], atol, rtol
+                        )
+                        result = False
+                elif isinstance(ref_output[i], int):
+                    if not model_output[i] == ref_output[i]:
+                        print(f"\n=== Output {i} comparison failed ===")
+                        print(f"{model_output[i]} vs {ref_output[i]}")
+                        result = False
+                else:
+                    print(f"WARNING: Output {i} has type {type(ref_output[i])}")
+            return result
     else:
         # If one output, eager returns tensor while executor tuple of size 1
         result = torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
@@ -326,7 +330,7 @@ def run_and_check_output(
     model_output = executorch_module.run_method("forward", tuple(inputs_flattened))

     # Generate reference outputs using the reference model
-    ref_output = reference_model(*sample_inputs)
+    ref_output, _ = tree_flatten(reference_model(*sample_inputs))

     # Check if outputs are equal
     return check_outputs_equal(
@@ -805,3 +809,26 @@ def find_bad_operators(
         "all_operators": all_operators,
         "test_count": test_count,
     }
+
+
+def make_indent(indent_level):
+    indent_str = ""
+    for _ in range(indent_level):
+        indent_str += " "
+    return indent_str
+
+
+def print_output(outputs, n: int = 0, indent_level: int = 0):
+    if isinstance(outputs, (list, tuple)):
+        print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}")
+        new_indent_level = indent_level + 2
+        for n, test_out in enumerate(outputs):
+            print_output(test_out, n, new_indent_level)
+    elif isinstance(outputs, torch.Tensor):
+        print(
+            f"{make_indent(indent_level)}output_{n} = test_utils.random_uniform_tensor({outputs.shape}, low={outputs.min().item()}, high={outputs.max().item()}, dtype={outputs.dtype})"
+        )
+    elif isinstance(outputs, int):
+        print(f"{make_indent(indent_level)}output_{n} = {outputs}")
+    else:
+        print(f"{make_indent(indent_level)}output_{n} = {type(outputs)}")

backends/vulkan/vulkan_preprocess.py

Lines changed: 36 additions & 25 deletions
@@ -6,6 +6,7 @@

 # pyre-strict

+import copy
 from functools import partial
 from typing import Any, Callable, Dict, final, List

@@ -127,15 +128,21 @@ def preprocess( # noqa: C901
         module_compile_spec: List[CompileSpec],
     ) -> PreprocessResult:
         compile_options = parse_compile_spec(module_compile_spec)
-        limits_x = compile_options.get(
-            "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0]
-        )
-        limits_y = compile_options.get(
-            "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1]
-        )
-        limits_z = compile_options.get(
-            "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2]
-        )
+
+        default_texture_limits = copy.deepcopy(utils.DEFAULT_TEXTURE_LIMITS)
+        # 2048 is the typical limit value for 3D textures, but mobile GPUs often support
+        # 16384. Since the Vulkan delegate primarily targets mobile GPUs at the moment,
+        # 16384 is the default texture limit used. This option is provided as a
+        # convenient way to switch to using a limit of 2048 for image textures, which
+        # will be compatible with most GPUs.
+        if compile_options.get("small_texture_limits", False):
+            default_texture_limits[0] = 2048
+            default_texture_limits[1] = 2048
+            default_texture_limits[2] = 2048
+
+        limits_x = compile_options.get("texture_limits_x", default_texture_limits[0])
+        limits_y = compile_options.get("texture_limits_y", default_texture_limits[1])
+        limits_z = compile_options.get("texture_limits_z", default_texture_limits[2])
         texture_limits = (limits_x, limits_y, limits_z)

         default_storage_type = compile_options.get(
@@ -204,22 +211,26 @@ def preprocess( # noqa: C901

         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
-        greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
-        mem_planning_suite = MemoryPlanningAlgorithmSuite(
-            algo_list=[greedy_memory_planning]
-        )
-        # This is a workaround to allow the memory planning pass to work without having
-        # to first apply ToOutVarPass(). See the `greedy()` function in
-        # `exir.memory_planning`; if this attribute isn't set, assertions in
-        # `collect_spec_from_nodes()` will fail.
-        program.graph_module.encounter_to_out_var_failure = True
-        program = apply_passes(
-            program,
-            [
-                ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
-            ],
-        )
+        final_passes = [
+            ConstraintBasedSymShapeEvalPass(),
+        ]
+        if not compile_options.get("skip_memory_planning", False):
+            greedy_memory_planning = partial(
+                greedy, allow_overlapping_allocations=False
+            )
+            mem_planning_suite = MemoryPlanningAlgorithmSuite(
+                algo_list=[greedy_memory_planning]
+            )
+            # This is a workaround to allow the memory planning pass to work without having
+            # to first apply ToOutVarPass(). See the `greedy()` function in
+            # `exir.memory_planning`; if this attribute isn't set, assertions in
+            # `collect_spec_from_nodes()` will fail.
+            program.graph_module.encounter_to_out_var_failure = True
+            final_passes.append(
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite)
+            )
+
+        program = apply_passes(program, final_passes)

         graph_builder = VkGraphBuilder(
             program,