Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.


import unittest
from typing import Tuple

import torch
from executorch.backends.arm._passes import (
Expand All @@ -17,11 +17,17 @@
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
CLIP_text_encoder_config,
)
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import (
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)
from transformers import CLIPTextModelWithProjection

input_t = Tuple[torch.Tensor]


class TestCLIPTextModelWithProjection(unittest.TestCase):
class TestCLIPTextModelWithProjection:
"""
Test class of CLIPTextModelWithProjection.
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
Expand Down Expand Up @@ -69,47 +75,93 @@ def prepare_model_and_inputs(self):

return text_encoder_model, text_encoder_model_inputs

def test_CLIPTextModelWithProjection_tosa_FP(self):
    """Export the CLIP text encoder, lower it for TOSA-1.0+FP, verify the
    partitioned operator counts, and compare lowered outputs to eager mode."""
    encoder, encoder_inputs = self.prepare_model_and_inputs()
    with torch.no_grad():
        # These passes rewrite int64 constants/outputs/placeholders to int32
        # (per their names) before lowering on the FP path.
        int64_to_int32_passes = [
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
            InsertInt32CastsAfterInt64PlaceholdersPass(),
        ]
        tester = ArmTester(
            encoder,
            example_inputs=encoder_inputs,
            compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
            transform_passes=int64_to_int32_passes,
        )
        (
            tester.export()
            .to_edge_transform_and_lower()
            .dump_operator_distribution()
            .check_count(self.ops_after_partitioner_FP)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=encoder_inputs)
        )

def test_CLIPTextModelWithProjection_tosa_INT(self):
    """Quantize and lower the CLIP text encoder for TOSA-1.0+INT, verify
    partitioned operator counts, and compare outputs with a relaxed
    absolute tolerance (atol=0.8) for the quantized path."""
    encoder, encoder_inputs = self.prepare_model_and_inputs()
    with torch.no_grad():
        tester = ArmTester(
            encoder,
            example_inputs=encoder_inputs,
            compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
        )
        (
            tester.quantize()
            .export()
            .to_edge_transform_and_lower()
            .dump_operator_distribution()
            .check_count(self.ops_after_partitioner_INT)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=encoder_inputs, atol=0.8)
        )

def test_CLIPTextModelWithProjection_tosa_FP():
    """Run the TOSA FP pipeline on the CLIP text encoder and check the
    partitioned operator counts recorded on the test-config class."""
    harness = TestCLIPTextModelWithProjection()
    encoder, encoder_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        # Rewrites int64 constants/outputs/placeholders to int32 (per the
        # pass names) before lowering on the FP path.
        int64_to_int32_passes = [
            ConvertInt64ConstOpsToInt32Pass(),
            ConvertInt64OutputOpsToInt32Pass(),
            InsertInt32CastsAfterInt64PlaceholdersPass(),
        ]
        pipeline = TosaPipelineFP[input_t](
            encoder,
            encoder_inputs,
            aten_op=[],
            exir_op=[],
            use_to_edge_transform_and_lower=True,
            transform_passes=int64_to_int32_passes,
        )
        pipeline.change_args(
            "check_count.exir",
            TestCLIPTextModelWithProjection.ops_after_partitioner_FP,
        )
        pipeline.run()


def test_CLIPTextModelWithProjection_tosa_INT():
    """Run the quantized TOSA INT pipeline on the CLIP text encoder and
    check the partitioned operator counts."""
    harness = TestCLIPTextModelWithProjection()
    encoder, encoder_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        pipeline = TosaPipelineINT[input_t](
            encoder,
            encoder_inputs,
            aten_op=[],
            exir_op=[],
            use_to_edge_transform_and_lower=True,
            atol=0.8,  # relaxed tolerance for the quantized path
        )
        pipeline.change_args(
            "check_count.exir",
            TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
        )
        pipeline.run()


@common.SkipIfNoModelConverter
def test_CLIPTextModelWithProjection_vgf_FP():
text_encoder_model, text_encoder_model_inputs = (
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
)
with torch.no_grad():
pipeline = VgfPipeline[input_t](
text_encoder_model,
text_encoder_model_inputs,
aten_op=[],
exir_op=[],
tosa_version="TOSA-1.0+FP",
use_to_edge_transform_and_lower=True,
atol=4,  # TODO: Investigate numerical issue: MAX Diff ~50%
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make sure this test isn't flaky.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's a good point. I tested it locally and 4 should be enough. But we can increase it later if we see flaky issue upstream.

transform_passes=[
ConvertInt64ConstOpsToInt32Pass(),
ConvertInt64OutputOpsToInt32Pass(),
InsertInt32CastsAfterInt64PlaceholdersPass(),
],
)
pipeline.change_args(
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
)
pipeline.run()


@common.SkipIfNoModelConverter
def test_CLIPTextModelWithProjection_vgf_INT():
    """Run the VGF pipeline (TOSA-1.0+INT) on the CLIP text encoder and
    check partitioned operator counts; skipped when no model converter
    is available."""
    harness = TestCLIPTextModelWithProjection()
    encoder, encoder_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        vgf_pipeline = VgfPipeline[input_t](
            encoder,
            encoder_inputs,
            aten_op=[],
            exir_op=[],
            tosa_version="TOSA-1.0+INT",
            use_to_edge_transform_and_lower=True,
            atol=0.8,  # relaxed tolerance for the quantized path
        )
        vgf_pipeline.change_args(
            "check_count.exir",
            TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
        )
        vgf_pipeline.run()
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.


import unittest
from typing import Tuple

import torch
from diffusers.models.transformers import SD3Transformer2DModel
Expand All @@ -13,10 +13,16 @@
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
SD3Transformer2DModel_init_dict,
)
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import (
TosaPipelineFP,
TosaPipelineINT,
VgfPipeline,
)

input_t4 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]


class TestSD3Transformer2DModel(unittest.TestCase):
class TestSD3Transformer2DModel:
"""
Test class of SD3Transformer2DModel.
SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
Expand Down Expand Up @@ -93,48 +99,88 @@ def forward(self, *args, **kwargs):

return sd35_transformer2D_model, sd35_transformer2D_model_inputs

def test_SD3Transformer2DModel_tosa_FP(self):
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
self.prepare_model_and_inputs()
)
with torch.no_grad():
(
ArmTester(
sd35_transformer2D_model,
example_inputs=sd35_transformer2D_model_inputs,
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
)
.export()
.to_edge_transform_and_lower()
.check_count(self.ops_after_partitioner_FP)
.to_executorch()
.run_method_and_compare_outputs(
inputs=sd35_transformer2D_model_inputs,
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
atol=4.0,
)
)

def test_SD3Transformer2DModel_tosa_INT(self):
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
self.prepare_model_and_inputs()
def test_SD3Transformer2DModel_tosa_FP():
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
TestSD3Transformer2DModel().prepare_model_and_inputs()
)
with torch.no_grad():
pipeline = TosaPipelineFP[input_t4](
sd35_transformer2D_model,
sd35_transformer2D_model_inputs,
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
atol=4.0,
)
with torch.no_grad():
(
ArmTester(
sd35_transformer2D_model,
example_inputs=sd35_transformer2D_model_inputs,
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
)
.quantize()
.export()
.to_edge_transform_and_lower()
.check_count(self.ops_after_partitioner_INT)
.to_executorch()
.run_method_and_compare_outputs(
inputs=sd35_transformer2D_model_inputs,
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
rtol=1.0,
atol=4.0,
)
)
pipeline.change_args(
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
)
pipeline.run()


def test_SD3Transformer2DModel_tosa_INT():
    """Run the quantized TOSA INT pipeline on SD3Transformer2DModel and
    check the partitioned operator counts."""
    harness = TestSD3Transformer2DModel()
    transformer, transformer_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        pipeline = TosaPipelineINT[input_t4](
            transformer,
            transformer_inputs,
            aten_op=[],
            exir_op=[],
            use_to_edge_transform_and_lower=True,
            qtol=1.0,  # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
            rtol=1.0,
            atol=4.0,
        )
        pipeline.change_args(
            "check_count.exir",
            TestSD3Transformer2DModel.ops_after_partitioner_INT,
        )
        pipeline.run()


@common.SkipIfNoModelConverter
def test_SD3Transformer2DModel_vgf_FP():
    """Run the VGF pipeline (TOSA-1.0+FP) on SD3Transformer2DModel and
    check partitioned operator counts; skipped when no model converter
    is available."""
    harness = TestSD3Transformer2DModel()
    transformer, transformer_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        vgf_pipeline = VgfPipeline[input_t4](
            transformer,
            transformer_inputs,
            aten_op=[],
            exir_op=[],
            tosa_version="TOSA-1.0+FP",
            use_to_edge_transform_and_lower=True,
            rtol=1.0,  # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
            atol=4.0,
        )
        vgf_pipeline.change_args(
            "check_count.exir",
            TestSD3Transformer2DModel.ops_after_partitioner_FP,
        )
        vgf_pipeline.run()


@common.SkipIfNoModelConverter
def test_SD3Transformer2DModel_vgf_INT():
    """Run the quantized VGF pipeline (TOSA-1.0+INT) on SD3Transformer2DModel
    and check partitioned operator counts; skipped when no model converter
    is available."""
    harness = TestSD3Transformer2DModel()
    transformer, transformer_inputs = harness.prepare_model_and_inputs()
    with torch.no_grad():
        vgf_pipeline = VgfPipeline[input_t4](
            transformer,
            transformer_inputs,
            aten_op=[],
            exir_op=[],
            tosa_version="TOSA-1.0+INT",
            use_to_edge_transform_and_lower=True,
            qtol=1.0,  # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
            rtol=1.0,
            atol=4.0,
        )
        vgf_pipeline.change_args(
            "check_count.exir",
            TestSD3Transformer2DModel.ops_after_partitioner_INT,
        )
        vgf_pipeline.run()
Loading
Loading