Skip to content

Commit 8dc6064

Browse files
committed
Refactor XNNPACK tester to extract base class
1 parent 120eb85 commit 8dc6064

File tree

16 files changed

+1168
-762
lines changed

16 files changed

+1168
-762
lines changed

backends/test/harness/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .tester import Tester
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .export import Export
2+
from .partition import Partition
3+
from .quantize import Quantize
4+
from .run_passes import RunPasses
5+
from .serialize import Serialize
6+
from .stage import Stage, StageType
7+
from .to_edge import ToEdge
8+
from .to_edge_transform_and_lower import ToEdgeTransformAndLower
9+
from .to_executorch import ToExecutorch
10+
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Any, Optional, Sequence, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from torch.export import export, ExportedProgram
7+
8+
class Export(Stage):
9+
def __init__(self, dynamic_shapes: Optional[Tuple[Any]] = None):
10+
self.exported_program = None
11+
self.dynamic_shapes = dynamic_shapes
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.EXPORT
15+
16+
def run(
17+
self,
18+
artifact: torch.nn.Module,
19+
inputs: Tuple[torch.Tensor],
20+
) -> None:
21+
self.exported_program = export(
22+
artifact, inputs, dynamic_shapes=self.dynamic_shapes, strict=True
23+
)
24+
25+
@property
26+
def artifact(self) -> ExportedProgram:
27+
return self.exported_program
28+
29+
@property
30+
def graph_module(self) -> str:
31+
return self.exported_program.graph_module
32+
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from executorch.backends.test.harness.stages.stage import Stage, StageType
2+
from executorch.exir import (
3+
EdgeProgramManager,
4+
)
5+
from executorch.exir.backend.backend_api import validation_disabled
6+
from executorch.exir.backend.partitioner import Partitioner
7+
8+
class Partition(Stage):
9+
def __init__(self, partitioner: Partitioner):
10+
self.partitioner = partitioner
11+
self.delegate_module = None
12+
13+
def stage_type(self) -> StageType:
14+
return StageType.PARTITION
15+
16+
def run(self, artifact: EdgeProgramManager, inputs=None):
17+
with validation_disabled():
18+
self.delegate_module = artifact
19+
self.delegate_module = self.delegate_module.to_backend(self.partitioner)
20+
21+
@property
22+
def artifact(self) -> EdgeProgramManager:
23+
return self.delegate_module
24+
25+
@property
26+
def graph_module(self) -> str:
27+
return self.delegate_module.exported_program().graph_module
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import Any, Optional, Sequence, Tuple
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
7+
DuplicateDynamicQuantChainPass,
8+
)
9+
10+
from torch.export import export_for_training
11+
12+
from torchao.quantization.pt2e.quantize_pt2e import (
13+
convert_pt2e,
14+
prepare_pt2e,
15+
prepare_qat_pt2e,
16+
)
17+
from torchao.quantization.pt2e.quantizer import Quantizer
18+
19+
class Quantize(Stage):
20+
def __init__(
21+
self,
22+
quantizer: Optional[Quantizer] = None,
23+
quantization_config: Optional[Any] = None,
24+
calibrate: bool = True,
25+
calibration_samples: Optional[Sequence[Any]] = None,
26+
is_qat: Optional[bool] = False,
27+
):
28+
self.quantizer = quantizer
29+
self.quantization_config = quantization_config
30+
self.calibrate = calibrate
31+
self.calibration_samples = calibration_samples
32+
33+
self.quantizer.set_global(self.quantization_config)
34+
35+
self.converted_graph = None
36+
self.is_qat = is_qat
37+
38+
def stage_type(self) -> str:
39+
return StageType.QUANTIZE
40+
41+
def run(
42+
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
43+
) -> None:
44+
assert inputs is not None
45+
if self.is_qat:
46+
artifact.train()
47+
captured_graph = export_for_training(artifact, inputs, strict=True).module()
48+
49+
assert isinstance(captured_graph, torch.fx.GraphModule)
50+
51+
if self.is_qat:
52+
prepared = prepare_qat_pt2e(captured_graph, self.quantizer)
53+
else:
54+
prepared = prepare_pt2e(captured_graph, self.quantizer)
55+
56+
if self.calibrate:
57+
# Calibrate prepared model to provide data to quantization observers.
58+
if self.calibration_samples is not None:
59+
for inp in self.calibration_samples:
60+
prepared(*inp)
61+
else:
62+
prepared(*inputs)
63+
64+
converted = convert_pt2e(prepared)
65+
DuplicateDynamicQuantChainPass()(converted)
66+
67+
self.converted_graph = converted
68+
69+
@property
70+
def artifact(self) -> torch.fx.GraphModule:
71+
return self.converted_graph
72+
73+
@property
74+
def graph_module(self) -> str:
75+
return self.converted_graph
76+
77+
def run_artifact(self, inputs):
78+
return self.converted_graph.forward(*inputs)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any, Callable, List, Optional, Sequence, Type, Tuple, Union
2+
3+
import torch
4+
5+
from executorch.backends.test.harness.stages.stage import Stage, StageType
6+
from executorch.exir import (
7+
EdgeCompileConfig,
8+
EdgeProgramManager,
9+
)
10+
from executorch.exir.backend.partitioner import Partitioner
11+
from torch._export.pass_base import PassType
12+
from torch.export import ExportedProgram
13+
14+
class RunPasses(Stage):
15+
def __init__(
16+
self,
17+
pass_manager_cls: Type,
18+
pass_list: Optional[List[Type[PassType]]] = None,
19+
pass_functions: Optional[List[Callable]] = None,
20+
):
21+
self.pass_manager_cls = pass_manager_cls
22+
self.pass_list = pass_list
23+
self.pass_functions = pass_functions
24+
self.edge_or_aten_program = None
25+
26+
def stage_type(self) -> StageType:
27+
return StageType.RUN_PASSES
28+
29+
def run(
30+
self, artifact: Union[EdgeProgramManager, ExportedProgram], inputs=None
31+
) -> None:
32+
if isinstance(artifact, EdgeProgramManager):
33+
self.edge_or_aten_program = artifact
34+
if self.pass_list:
35+
pass_manager = self.pass_manager_cls(
36+
artifact.exported_program(), self.pass_list
37+
)
38+
self.edge_or_aten_program._edge_programs["forward"] = (
39+
pass_manager.transform()
40+
)
41+
if self.pass_functions:
42+
assert isinstance(self.pass_functions, list)
43+
for pass_function in self.pass_functions:
44+
self.edge_or_aten_program._edge_programs["forward"] = pass_function(
45+
self.edge_or_aten_program.exported_program()
46+
)
47+
else:
48+
transformed_ep = artifact
49+
if self.pass_list:
50+
assert isinstance(self.pass_list, list)
51+
for pass_ in self.pass_list:
52+
transformed_ep = _transform(transformed_ep, pass_())
53+
54+
if self.pass_functions:
55+
assert isinstance(self.pass_functions, list)
56+
for pass_function in self.pass_functions:
57+
transformed_ep = pass_function(transformed_ep)
58+
59+
self.edge_or_aten_program = transformed_ep
60+
61+
@property
62+
def artifact(self) -> Union[EdgeProgramManager, ExportedProgram]:
63+
return self.edge_or_aten_program
64+
65+
@property
66+
def graph_module(self) -> str:
67+
if isinstance(self.edge_or_aten_program, EdgeProgramManager):
68+
return self.edge_or_aten_program.exported_program().graph_module
69+
else:
70+
return self.edge_or_aten_program.graph_module
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import copy
2+
import logging
3+
4+
from typing import Optional
5+
6+
from executorch.backends.test.harness.stages.stage import Stage, StageType
7+
from executorch.exir import (
8+
EdgeCompileConfig,
9+
EdgeProgramManager,
10+
ExecutorchBackendConfig,
11+
ExecutorchProgramManager,
12+
)
13+
14+
from torch.utils._pytree import tree_flatten
15+
16+
logger = logging.getLogger(__name__)
17+
logger.setLevel(logging.INFO)
18+
try:
19+
from executorch.extension.pybindings.portable_lib import ( # @manual
20+
_load_for_executorch_from_buffer,
21+
)
22+
except ImportError as e:
23+
logger.warning(f"{e=}")
24+
pass
25+
26+
class Serialize(Stage):
27+
def __init__(self):
28+
self.buffer = None
29+
30+
def stage_type(self) -> StageType:
31+
return StageType.SERIALIZE
32+
33+
def run(self, artifact: ExecutorchProgramManager, inputs=None) -> None:
34+
self.buffer = artifact.buffer
35+
36+
@property
37+
def artifact(self) -> bytes:
38+
return self.buffer
39+
40+
@property
41+
def graph_module(self) -> None:
42+
return None
43+
44+
def run_artifact(self, inputs):
45+
inputs_flattened, _ = tree_flatten(inputs)
46+
executorch_module = _load_for_executorch_from_buffer(self.buffer)
47+
executorch_output = copy.deepcopy(
48+
executorch_module.run_method("forward", tuple(inputs_flattened))
49+
)
50+
return executorch_output
51+
52+
def dump_artifact(self, path_to_dump: Optional[str]):
53+
"""
54+
dump_artifact is overridden to dump the serialized bytes into pte file
55+
"""
56+
if not path_to_dump:
57+
raise RuntimeError("path_to_dump file not provided")
58+
else:
59+
with open(path_to_dump, "wb") as f:
60+
f.write(self.artifact)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from abc import ABC, abstractmethod
2+
from enum import Enum
3+
from typing import Optional
4+
5+
from executorch.exir import (
6+
EdgeProgramManager,
7+
)
8+
9+
10+
class StageType(Enum):
11+
QUANTIZE = 0
12+
EXPORT = 1
13+
RUN_PASSES = 2
14+
TO_EDGE = 3
15+
TO_EDGE_TRANSFORM_AND_LOWER = 4
16+
PARTITION = 5
17+
TO_EXECUTORCH = 6
18+
SERIALIZE = 7
19+
20+
21+
class Stage(ABC):
22+
"""
23+
Interface for a Stage in the PT2.0 lowering pipeline
24+
"""
25+
26+
@abstractmethod
27+
def stage_type(self) -> StageType:
28+
"""
29+
Returns the type of the stage.
30+
"""
31+
pass
32+
33+
@abstractmethod
34+
def run(self, artifact, inputs):
35+
"""
36+
Executes this stage, generates the 'artifact', for later stages.
37+
"""
38+
pass
39+
40+
@property
41+
@abstractmethod
42+
def artifact(self):
43+
"""
44+
Returns the artifact generated by this stage. To be used by the next stage in the pipeline.
45+
"""
46+
pass
47+
48+
@property
49+
@abstractmethod
50+
def graph_module(self):
51+
"""
52+
Return the artifact's graph module for this stage
53+
"""
54+
pass
55+
56+
def run_artifact(self, inputs):
57+
"""
58+
Returns the output of calling the artifact generated by this stage with inputs
59+
"""
60+
if isinstance(self.artifact, ExportedProgram):
61+
return self.artifact(*inputs)
62+
else:
63+
return self.artifact.exported_program().module()(*inputs)
64+
65+
# Debug Tools for stages
66+
def artifact_str(self):
67+
"""
68+
Return string printable artifact for this stage
69+
"""
70+
if isinstance(self.artifact, EdgeProgramManager):
71+
return self.artifact.exported_program()
72+
return self.artifact
73+
74+
def stage_banner(self):
75+
"""
76+
Returns banner string for this stage
77+
"""
78+
return "#" * 36 + " " + str(self.__class__.__name__) + " " + "#" * 36 + "\n"
79+
80+
def dump_artifact(self, path_to_dump: Optional[str]):
81+
"""
82+
Dumps string printable artifact to path. If path_to_dump, then it is printed to terminal
83+
"""
84+
if path_to_dump:
85+
with open(path_to_dump, "a") as fp:
86+
fp.write(str(self.stage_banner() + "\n"))
87+
fp.write(str(self.artifact_str()))
88+
else:
89+
print(self.stage_banner() + "\n")
90+
print(self.artifact_str())

0 commit comments

Comments
 (0)