Merged
28 changes: 28 additions & 0 deletions .github/workflows/pull.yml
@@ -231,6 +231,34 @@ jobs:
# run e2e (export, tokenizer and runner)
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llava.sh

test-preprocess-linux:
name: test-preprocess-linux
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
strategy:
fail-fast: false
with:
runner: linux.24xlarge
docker-image: executorch-ubuntu-22.04-clang12
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"

PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"

# install pybind
bash install_requirements.sh --pybind xnnpack

# install preprocess requirements
bash examples/models/llama3_2_vision/install_requirements.sh

# run python unittest
python -m unittest examples.models.llama3_2_vision.preprocess.test_preprocess
Contributor
Hmm, I don't think we need a new job, right? Just adding it to pytest.ini should be good enough.
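A rough sketch of the kind of pytest.ini entry being suggested here, assuming the repository's root pytest.ini collects tests by listing paths under addopts; the surrounding layout is an assumption and only the added path is the point:

# Hypothetical excerpt of the root pytest.ini; the real file will carry other options.
[pytest]
addopts =
    examples/models/llama3_2_vision/preprocess/test_preprocess.py

With an entry like this, whichever CI job already runs pytest would collect the new test automatically, without a dedicated workflow job.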

Contributor Author
Updated to use pytest.

Contributor
I'm not sure you need to run bash examples/models/llama3_2_vision/install_requirements.sh though. Let's wait and see.



test-quantized-aot-lib-linux:
name: test-quantized-aot-lib-linux
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
3 changes: 3 additions & 0 deletions examples/models/llama3_2_vision/preprocess/model.py
@@ -26,6 +26,9 @@ class PreprocessConfig:
max_num_tiles: int = 4
tile_size: int = 224
antialias: bool = False
# Used for reference eager model from torchtune.
resize_to_max_canvas: bool = False
possible_resolutions: Optional[List[Tuple[int, int]]] = None


class CLIPImageTransformModel(EagerModelBase):
238 changes: 124 additions & 114 deletions examples/models/llama3_2_vision/preprocess/test_preprocess.py
@@ -3,34 +3,32 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Tuple

import numpy as np
import PIL
import torch

# Import these first. Otherwise, the custom ops are not registered.
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
export_preprocess,
get_example_inputs,
lower_to_executorch_preprocess,
from executorch.extension.llm.custom_ops import op_tile_crop_aot # noqa # usort: skip

from executorch.examples.models.llama3_2_vision.preprocess.model import (
CLIPImageTransformModel,
PreprocessConfig,
)

from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.pybindings.portable_lib import (
_load_for_executorch_from_buffer,
)

from parameterized import parameterized
from PIL import Image

from torchtune.models.clip.inference._transform import (
_CLIPImageTransform,
CLIPImageTransform,
)
from torchtune.models.clip.inference._transform import CLIPImageTransform

from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
find_supported_resolutions,
@@ -43,18 +41,6 @@
from torchvision.transforms.v2 import functional as F


@dataclass
class PreprocessConfig:
image_mean: Optional[List[float]] = None
image_std: Optional[List[float]] = None
resize_to_max_canvas: bool = True
resample: str = "bilinear"
antialias: bool = False
tile_size: int = 224
max_num_tiles: int = 4
possible_resolutions = None


class TestImageTransform(unittest.TestCase):
"""
This unittest checks that the exported image transform model produces the
@@ -66,6 +52,58 @@ class TestImageTransform(unittest.TestCase):
https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
"""

@staticmethod
def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)

reference_model = CLIPImageTransform(
image_mean=config.image_mean,
image_std=config.image_std,
resample=config.resample,
antialias=config.antialias,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
resize_to_max_canvas=config.resize_to_max_canvas,
possible_resolutions=None,
)

model = CLIPImageTransformModel(config)

exported_model = torch.export.export(
model.get_eager_model(),
model.get_example_inputs(),
dynamic_shapes=model.get_dynamic_shapes(),
strict=False,
)

# aoti_path = torch._inductor.aot_compile(
# exported_model.module(),
# model.get_example_inputs(),
# )

edge_program = to_edge(
exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
executorch_model = edge_program.to_executorch()

return {
"config": config,
"reference_model": reference_model,
"model": model,
"exported_model": exported_model,
# "aoti_path": aoti_path,
"executorch_model": executorch_model,
}

@classmethod
def setUpClass(cls):
cls.models_no_resize = TestImageTransform.initialize_models(
resize_to_max_canvas=False
)
cls.models_resize = TestImageTransform.initialize_models(
resize_to_max_canvas=True
)

def setUp(self):
np.random.seed(0)

@@ -121,51 +159,7 @@ def prepare_inputs(

return image_tensor, inscribed_size, best_resolution

# This test setup mirrors the one in torchtune:
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
# The values are slightly different, as torchtune uses antialias=True,
# and this test uses antialias=False, which is exportable (has a portable kernel).
@parameterized.expand(
[
(
(100, 400, 3), # image_size
torch.Size([2, 3, 224, 224]), # expected shape
False, # resize_to_max_canvas
[0.2230, 0.1763], # expected_tile_means
[1.0, 1.0], # expected_tile_max
[0.0, 0.0], # expected_tile_min
[1, 2], # expected_aspect_ratio
),
(
(1000, 300, 3), # image_size
torch.Size([4, 3, 224, 224]), # expected shape
True, # resize_to_max_canvas
[0.5005, 0.4992, 0.5004, 0.1651], # expected_tile_means
[0.9976, 0.9940, 0.9936, 0.9906], # expected_tile_max
[0.0037, 0.0047, 0.0039, 0.0], # expected_tile_min
[4, 1], # expected_aspect_ratio
),
(
(200, 200, 3), # image_size
torch.Size([4, 3, 224, 224]), # expected shape
True, # resize_to_max_canvas
[0.5012, 0.5020, 0.5010, 0.4991], # expected_tile_means
[0.9921, 0.9925, 0.9969, 0.9908], # expected_tile_max
[0.0056, 0.0069, 0.0059, 0.0032], # expected_tile_min
[2, 2], # expected_aspect_ratio
),
(
(600, 200, 3), # image_size
torch.Size([3, 3, 224, 224]), # expected shape
False, # resize_to_max_canvas
[0.4472, 0.4468, 0.3031], # expected_tile_means
[1.0, 1.0, 1.0], # expected_tile_max
[0.0, 0.0, 0.0], # expected_tile_min
[3, 1], # expected_aspect_ratio
),
]
)
def test_preprocess(
def run_preprocess(
self,
image_size: Tuple[int],
expected_shape: torch.Size,
@@ -175,45 +169,7 @@ def test_preprocess(
expected_tile_min: List[float],
expected_ar: List[int],
) -> None:
config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)

reference_model = CLIPImageTransform(
image_mean=config.image_mean,
image_std=config.image_std,
resize_to_max_canvas=config.resize_to_max_canvas,
resample=config.resample,
antialias=config.antialias,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
possible_resolutions=None,
)

eager_model = _CLIPImageTransform(
image_mean=config.image_mean,
image_std=config.image_std,
resample=config.resample,
antialias=config.antialias,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
)

exported_model = export_preprocess(
image_mean=config.image_mean,
image_std=config.image_std,
resample=config.resample,
antialias=config.antialias,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
)

executorch_model = lower_to_executorch_preprocess(exported_model)
executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)

aoti_path = torch._inductor.aot_compile(
exported_model.module(),
get_example_inputs(),
)

models = self.models_resize if resize_to_max_canvas else self.models_no_resize
# Prepare image input.
image = (
np.random.randint(0, 256, np.prod(image_size))
@@ -223,6 +179,7 @@ def test_preprocess(
image = PIL.Image.fromarray(image)

# Run reference model.
reference_model = models["reference_model"]
reference_output = reference_model(image=image)
reference_image = reference_output["image"]
reference_ar = reference_output["aspect_ratio"].tolist()
@@ -249,10 +206,11 @@ def test_preprocess(
# Pre-work for eager and exported models. The reference model performs these
# calculations and passes the result to _CLIPImageTransform, the exportable model.
image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
image=image, config=config
image=image, config=models["config"]
)

# Run eager model and check it matches reference model.
eager_model = models["model"].get_eager_model()
eager_image, eager_ar = eager_model(
image_tensor, inscribed_size, best_resolution
)
@@ -261,6 +219,7 @@
self.assertEqual(reference_ar, eager_ar)

# Run exported model and check it matches reference model.
exported_model = models["exported_model"]
exported_image, exported_ar = exported_model.module()(
image_tensor, inscribed_size, best_resolution
)
@@ -269,14 +228,65 @@
self.assertEqual(reference_ar, exported_ar)

# Run executorch model and check it matches reference model.
executorch_model = models["executorch_model"]
executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
et_image, et_ar = executorch_module.forward(
(image_tensor, inscribed_size, best_resolution)
)
self.assertTrue(torch.allclose(reference_image, et_image))
self.assertEqual(reference_ar, et_ar.tolist())

# Run aoti model and check it matches reference model.
aoti_model = torch._export.aot_load(aoti_path, "cpu")
aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
self.assertTrue(torch.allclose(reference_image, aoti_image))
self.assertEqual(reference_ar, aoti_ar.tolist())
# aoti_path = models["aoti_path"]
# aoti_model = torch._export.aot_load(aoti_path, "cpu")
# aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
# self.assertTrue(torch.allclose(reference_image, aoti_image))
# self.assertEqual(reference_ar, aoti_ar.tolist())

# This test setup mirrors the one in torchtune:
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
# The values are slightly different, as torchtune uses antialias=True,
# and this test uses antialias=False, which is exportable (has a portable kernel).
def test_preprocess1(self):
self.run_preprocess(
(100, 400, 3), # image_size
torch.Size([2, 3, 224, 224]), # expected shape
False, # resize_to_max_canvas
[0.2230, 0.1763], # expected_tile_means
[1.0, 1.0], # expected_tile_max
[0.0, 0.0], # expected_tile_min
[1, 2], # expected_aspect_ratio
)

def test_preprocess2(self):
self.run_preprocess(
(1000, 300, 3), # image_size
torch.Size([4, 3, 224, 224]), # expected shape
True, # resize_to_max_canvas
[0.5005, 0.4992, 0.5004, 0.1651], # expected_tile_means
[0.9976, 0.9940, 0.9936, 0.9906], # expected_tile_max
[0.0037, 0.0047, 0.0039, 0.0], # expected_tile_min
[4, 1], # expected_aspect_ratio
)

def test_preprocess3(self):
self.run_preprocess(
(200, 200, 3), # image_size
torch.Size([4, 3, 224, 224]), # expected shape
True, # resize_to_max_canvas
[0.5012, 0.5020, 0.5010, 0.4991], # expected_tile_means
[0.9921, 0.9925, 0.9969, 0.9908], # expected_tile_max
[0.0056, 0.0069, 0.0059, 0.0032], # expected_tile_min
[2, 2], # expected_aspect_ratio
)

def test_preprocess4(self):
self.run_preprocess(
(600, 200, 3), # image_size
torch.Size([3, 3, 224, 224]), # expected shape
False, # resize_to_max_canvas
[0.4472, 0.4468, 0.3031], # expected_tile_means
[1.0, 1.0, 1.0], # expected_tile_max
[0.0, 0.0, 0.0], # expected_tile_min
[3, 1], # expected_aspect_ratio
)