Commit 56a99e9

refactor preprocess to use EagerModelBase

1 parent 8234c14

File tree

3 files changed: +113 -99 lines changed

- examples/models/llama3_2_vision/preprocess/export_preprocess.py (modified)
- examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py (deleted)
- examples/models/llama3_2_vision/preprocess/model.py (added)

examples/models/llama3_2_vision/preprocess/export_preprocess.py

Lines changed: 44 additions & 14 deletions
@@ -4,30 +4,60 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 import torch
-from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
-    export_preprocess,
-    get_example_inputs,
-    lower_to_executorch_preprocess,
-)
+from executorch.examples.models.llama3_2_vision.preprocess.model import (
+    CLIPImageTransformModel,
+    PreprocessConfig,
+)
+from executorch.exir import EdgeCompileConfig, to_edge


 def main():
+    # Eager model.
+    model = CLIPImageTransformModel(PreprocessConfig())

-    # ExecuTorch
-    ep_et = export_preprocess()
-    et = lower_to_executorch_preprocess(ep_et)
-    with open("preprocess_et.pte", "wb") as file:
-        et.write_to_file(file)
-
-    # AOTInductor
-    ep_aoti = export_preprocess()
-    torch._inductor.aot_compile(
-        ep_aoti.module(),
-        get_example_inputs(),
-        options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # Export with dynamic image height and width.
+    ep = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
     )

+    # Lower to ExecuTorch.
+    edge_program = to_edge(
+        ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
+    )
+    et_program = edge_program.to_executorch()
+    with open("preprocess_et.pte", "wb") as file:
+        et_program.write_to_file(file)
+
+    # AOTInductor path, currently disabled:
+    # ep = torch.export.export(
+    #     model.get_eager_model(),
+    #     model.get_example_inputs(),
+    #     dynamic_shapes=model.get_dynamic_shapes(),
+    #     strict=False,
+    # )
+    #
+    # torch._inductor.aot_compile(
+    #     ep.module(),
+    #     model.get_example_inputs(),
+    #     options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # )
+

 if __name__ == "__main__":
     main()
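
Not part of the commit, but a useful sanity check for an export refactor like this one: re-export the wrapped model and confirm the exported program matches the eager module on an image size other than the export-time example, which exercises the dynamic img_h/img_w dims declared in get_dynamic_shapes(). A minimal sketch, assuming torchtune and the tile_crop custom-op library are importable:

import torch

from executorch.examples.models.llama3_2_vision.preprocess.model import (
    CLIPImageTransformModel,
    PreprocessConfig,
)

model = CLIPImageTransformModel(PreprocessConfig())
ep = torch.export.export(
    model.get_eager_model(),
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
    strict=False,
)

# A fresh (height, width) within the declared Dim ranges; the example
# inputs used 800x600 with a 448x336 target size, so 400x300 keeps the
# same 4:3 aspect ratio.
image = torch.rand(3, 400, 300)
target_size = torch.tensor([448, 336])
canvas_size = torch.tensor([448, 448])

# assert_close recurses into container outputs, so this works whether
# the transform returns one tensor or a tuple of them.
torch.testing.assert_close(
    ep.module()(image, target_size, canvas_size),
    model.get_eager_model()(image, target_size, canvas_size),
)

Once that holds, the preprocess_et.pte written by the script can be run with the ExecuTorch runtime, provided the tile_crop custom op is registered in that build.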

examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py

Lines changed: 0 additions & 85 deletions
This file was deleted; its export_preprocess, get_example_inputs, and lower_to_executorch_preprocess helpers are superseded by the EagerModelBase flow above.
examples/models/llama3_2_vision/preprocess/model.py (new file)

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
+from torch.export import Dim, ExportedProgram
+from torchtune.models.clip.inference._transform import _CLIPImageTransform
+
+from ...model_base import EagerModelBase
+
+
+@dataclass
+class PreprocessConfig:
+    image_mean: Optional[List[float]] = None
+    image_std: Optional[List[float]] = None
+    resample: str = "bilinear"
+    max_num_tiles: int = 4
+    tile_size: int = 224
+    antialias: bool = False
+
+
+class CLIPImageTransformModel(EagerModelBase):
+    def __init__(
+        self,
+        config: PreprocessConfig,
+    ):
+        super().__init__()
+
+        # Eager model.
+        self.model = _CLIPImageTransform(
+            image_mean=config.image_mean,
+            image_std=config.image_std,
+            resample=config.resample,
+            max_num_tiles=config.max_num_tiles,
+            tile_size=config.tile_size,
+            antialias=config.antialias,
+        )
+
+        # Replace non-exportable ops with custom ops.
+        self.model.tile_crop = torch.ops.preprocess.tile_crop.default
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return self.model
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        image = torch.ones(3, 800, 600)
+        target_size = torch.tensor([448, 336])
+        canvas_size = torch.tensor([448, 448])
+        return (image, target_size, canvas_size)
+
+    def get_dynamic_shapes(self) -> Dict[str, Dict[int, Dim]]:
+        img_h = Dim("img_h", min=1, max=4000)
+        img_w = Dim("img_w", min=1, max=4000)
+
+        dynamic_shapes = {
+            "image": {1: img_h, 2: img_w},
+            "target_size": None,
+            "canvas_size": None,
+        }
+        return dynamic_shapes
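
The point of the refactor is that the preprocess model now conforms to the shared EagerModelBase interface used across examples/models, so generic export tooling can drive it. For reference, a minimal sketch of that contract as exercised by this commit; the authoritative definition lives in examples/models/model_base.py:

from abc import ABC, abstractmethod

import torch


class EagerModelBase(ABC):
    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module:
        """Return the eager PyTorch module to be exported."""

    @abstractmethod
    def get_example_inputs(self):
        """Return a tuple of example inputs for torch.export.export()."""

get_dynamic_shapes() is an extra method on top of this contract: export_preprocess.py relies on it to mark image height and width dynamic, so any model driven through that script needs to provide it as well.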
