Skip to content

Commit 499ce50

Browse files
Arm backend: Add VGF tests to StableDiffusion module tests (#14655)
Also refactor the StableDiffusion module tests to use test_pipeline instead of ArmTester directly. Signed-off-by: Yufeng Shi <[email protected]>
1 parent f24351a commit 499ce50

File tree

4 files changed

+359
-179
lines changed

4 files changed

+359
-179
lines changed

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from executorch.backends.arm._passes import (
@@ -17,11 +17,17 @@
1717
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1818
CLIP_text_encoder_config,
1919
)
20-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.arm.test.tester.test_pipeline import (
21+
TosaPipelineFP,
22+
TosaPipelineINT,
23+
VgfPipeline,
24+
)
2125
from transformers import CLIPTextModelWithProjection
2226

27+
input_t = Tuple[torch.Tensor]
28+
2329

24-
class TestCLIPTextModelWithProjection(unittest.TestCase):
30+
class TestCLIPTextModelWithProjection:
2531
"""
2632
Test class of CLIPTextModelWithProjection.
2733
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
@@ -69,47 +75,93 @@ def prepare_model_and_inputs(self):
6975

7076
return text_encoder_model, text_encoder_model_inputs
7177

72-
def test_CLIPTextModelWithProjection_tosa_FP(self):
73-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
74-
with torch.no_grad():
75-
(
76-
ArmTester(
77-
text_encoder_model,
78-
example_inputs=text_encoder_model_inputs,
79-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
80-
transform_passes=[
81-
ConvertInt64ConstOpsToInt32Pass(),
82-
ConvertInt64OutputOpsToInt32Pass(),
83-
InsertInt32CastsAfterInt64PlaceholdersPass(),
84-
],
85-
)
86-
.export()
87-
.to_edge_transform_and_lower()
88-
.dump_operator_distribution()
89-
.check_count(self.ops_after_partitioner_FP)
90-
.to_executorch()
91-
.run_method_and_compare_outputs(
92-
inputs=text_encoder_model_inputs,
93-
)
94-
)
95-
96-
def test_CLIPTextModelWithProjection_tosa_INT(self):
97-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
98-
with torch.no_grad():
99-
(
100-
ArmTester(
101-
text_encoder_model,
102-
example_inputs=text_encoder_model_inputs,
103-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
104-
)
105-
.quantize()
106-
.export()
107-
.to_edge_transform_and_lower()
108-
.dump_operator_distribution()
109-
.check_count(self.ops_after_partitioner_INT)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=text_encoder_model_inputs,
113-
atol=0.8,
114-
)
115-
)
78+
79+
def test_CLIPTextModelWithProjection_tosa_FP():
80+
text_encoder_model, text_encoder_model_inputs = (
81+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
82+
)
83+
with torch.no_grad():
84+
pipeline = TosaPipelineFP[input_t](
85+
text_encoder_model,
86+
text_encoder_model_inputs,
87+
aten_op=[],
88+
exir_op=[],
89+
use_to_edge_transform_and_lower=True,
90+
transform_passes=[
91+
ConvertInt64ConstOpsToInt32Pass(),
92+
ConvertInt64OutputOpsToInt32Pass(),
93+
InsertInt32CastsAfterInt64PlaceholdersPass(),
94+
],
95+
)
96+
pipeline.change_args(
97+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
98+
)
99+
pipeline.run()
100+
101+
102+
def test_CLIPTextModelWithProjection_tosa_INT():
103+
text_encoder_model, text_encoder_model_inputs = (
104+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
105+
)
106+
with torch.no_grad():
107+
pipeline = TosaPipelineINT[input_t](
108+
text_encoder_model,
109+
text_encoder_model_inputs,
110+
aten_op=[],
111+
exir_op=[],
112+
use_to_edge_transform_and_lower=True,
113+
atol=0.8,
114+
)
115+
pipeline.change_args(
116+
"check_count.exir",
117+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
118+
)
119+
pipeline.run()
120+
121+
122+
@common.SkipIfNoModelConverter
123+
def test_CLIPTextModelWithProjection_vgf_FP():
124+
text_encoder_model, text_encoder_model_inputs = (
125+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = VgfPipeline[input_t](
129+
text_encoder_model,
130+
text_encoder_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
tosa_version="TOSA-1.0+FP",
134+
use_to_edge_transform_and_lower=True,
135+
atol=4,  # TODO: Investigate numerical issue: MAX Diff ~50%
136+
transform_passes=[
137+
ConvertInt64ConstOpsToInt32Pass(),
138+
ConvertInt64OutputOpsToInt32Pass(),
139+
InsertInt32CastsAfterInt64PlaceholdersPass(),
140+
],
141+
)
142+
pipeline.change_args(
143+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
144+
)
145+
pipeline.run()
146+
147+
148+
@common.SkipIfNoModelConverter
149+
def test_CLIPTextModelWithProjection_vgf_INT():
150+
text_encoder_model, text_encoder_model_inputs = (
151+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
152+
)
153+
with torch.no_grad():
154+
pipeline = VgfPipeline[input_t](
155+
text_encoder_model,
156+
text_encoder_model_inputs,
157+
aten_op=[],
158+
exir_op=[],
159+
tosa_version="TOSA-1.0+INT",
160+
use_to_edge_transform_and_lower=True,
161+
atol=0.8,
162+
)
163+
pipeline.change_args(
164+
"check_count.exir",
165+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
166+
)
167+
pipeline.run()

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 92 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from diffusers.models.transformers import SD3Transformer2DModel
@@ -13,10 +13,16 @@
1313
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1414
SD3Transformer2DModel_init_dict,
1515
)
16-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.backends.arm.test.tester.test_pipeline import (
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
VgfPipeline,
20+
)
21+
22+
input_t4 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
1723

1824

19-
class TestSD3Transformer2DModel(unittest.TestCase):
25+
class TestSD3Transformer2DModel:
2026
"""
2127
Test class of SD3Transformer2DModel.
2228
SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
@@ -93,48 +99,88 @@ def forward(self, *args, **kwargs):
9399

94100
return sd35_transformer2D_model, sd35_transformer2D_model_inputs
95101

96-
def test_SD3Transformer2DModel_tosa_FP(self):
97-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
98-
self.prepare_model_and_inputs()
99-
)
100-
with torch.no_grad():
101-
(
102-
ArmTester(
103-
sd35_transformer2D_model,
104-
example_inputs=sd35_transformer2D_model_inputs,
105-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
106-
)
107-
.export()
108-
.to_edge_transform_and_lower()
109-
.check_count(self.ops_after_partitioner_FP)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=sd35_transformer2D_model_inputs,
113-
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
114-
atol=4.0,
115-
)
116-
)
117102

118-
def test_SD3Transformer2DModel_tosa_INT(self):
119-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
120-
self.prepare_model_and_inputs()
103+
def test_SD3Transformer2DModel_tosa_FP():
104+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
105+
TestSD3Transformer2DModel().prepare_model_and_inputs()
106+
)
107+
with torch.no_grad():
108+
pipeline = TosaPipelineFP[input_t4](
109+
sd35_transformer2D_model,
110+
sd35_transformer2D_model_inputs,
111+
aten_op=[],
112+
exir_op=[],
113+
use_to_edge_transform_and_lower=True,
114+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
115+
atol=4.0,
121116
)
122-
with torch.no_grad():
123-
(
124-
ArmTester(
125-
sd35_transformer2D_model,
126-
example_inputs=sd35_transformer2D_model_inputs,
127-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
128-
)
129-
.quantize()
130-
.export()
131-
.to_edge_transform_and_lower()
132-
.check_count(self.ops_after_partitioner_INT)
133-
.to_executorch()
134-
.run_method_and_compare_outputs(
135-
inputs=sd35_transformer2D_model_inputs,
136-
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
137-
rtol=1.0,
138-
atol=4.0,
139-
)
140-
)
117+
pipeline.change_args(
118+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
119+
)
120+
pipeline.run()
121+
122+
123+
def test_SD3Transformer2DModel_tosa_INT():
124+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
125+
TestSD3Transformer2DModel().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = TosaPipelineINT[input_t4](
129+
sd35_transformer2D_model,
130+
sd35_transformer2D_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
use_to_edge_transform_and_lower=True,
134+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
135+
rtol=1.0,
136+
atol=4.0,
137+
)
138+
pipeline.change_args(
139+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
140+
)
141+
pipeline.run()
142+
143+
144+
@common.SkipIfNoModelConverter
145+
def test_SD3Transformer2DModel_vgf_FP():
146+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
147+
TestSD3Transformer2DModel().prepare_model_and_inputs()
148+
)
149+
with torch.no_grad():
150+
pipeline = VgfPipeline[input_t4](
151+
sd35_transformer2D_model,
152+
sd35_transformer2D_model_inputs,
153+
aten_op=[],
154+
exir_op=[],
155+
tosa_version="TOSA-1.0+FP",
156+
use_to_edge_transform_and_lower=True,
157+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
158+
atol=4.0,
159+
)
160+
pipeline.change_args(
161+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
162+
)
163+
pipeline.run()
164+
165+
166+
@common.SkipIfNoModelConverter
167+
def test_SD3Transformer2DModel_vgf_INT():
168+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
169+
TestSD3Transformer2DModel().prepare_model_and_inputs()
170+
)
171+
with torch.no_grad():
172+
pipeline = VgfPipeline[input_t4](
173+
sd35_transformer2D_model,
174+
sd35_transformer2D_model_inputs,
175+
aten_op=[],
176+
exir_op=[],
177+
tosa_version="TOSA-1.0+INT",
178+
use_to_edge_transform_and_lower=True,
179+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
180+
rtol=1.0,
181+
atol=4.0,
182+
)
183+
pipeline.change_args(
184+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
185+
)
186+
pipeline.run()

0 commit comments

Comments
 (0)