Skip to content

Commit e96e35c

Browse files
committed
Update base for Update on "Introduce public MergedDataMap"
Add public merged data map. Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned]
2 parents 3a0d66f + 70ea661 commit e96e35c

File tree

62 files changed

+1307
-610
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1307
-610
lines changed

backends/arm/operators/op_bmm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def define_node(
7979
input1_zp = input_qparams[1].get_zp_per_tensor()
8080
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
8181
bmm_output_name = bmm_result.name
82+
elif inputs[0].dtype == ts.DType.INT16:
83+
input_qparams = get_input_qparams(node)
84+
input0_zp = input_qparams[0].get_zp_per_tensor()
85+
input1_zp = input_qparams[1].get_zp_per_tensor()
86+
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48)
87+
bmm_output_name = bmm_result.name
8288
else:
8389
bmm_output_name = output.name
8490
input0_zp, input1_zp = 0, 0
@@ -118,3 +124,20 @@ def define_node(
118124
output_zp=[output_qparams.get_zp_per_tensor()],
119125
rounding_mode=RoundingMode.SINGLE_ROUND,
120126
)
127+
elif output.dtype == ts.DType.INT16:
128+
output_qparams = get_output_qparams(node)[0]
129+
final_output_scale = (
130+
input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61]
131+
) / output_qparams.get_scale_per_tensor()
132+
133+
build_rescale(
134+
tosa_fb=tosa_graph,
135+
scale=[final_output_scale],
136+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
137+
input_node=bmm_result, # type: ignore[possibly-undefined]
138+
output_name=output.name,
139+
output_type=ts.DType.INT16,
140+
input_zp=[0],
141+
output_zp=[output_qparams.get_zp_per_tensor()],
142+
rounding_mode=RoundingMode.SINGLE_ROUND,
143+
)

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from executorch.backends.arm._passes import (
@@ -17,11 +17,17 @@
1717
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1818
CLIP_text_encoder_config,
1919
)
20-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.arm.test.tester.test_pipeline import (
21+
TosaPipelineFP,
22+
TosaPipelineINT,
23+
VgfPipeline,
24+
)
2125
from transformers import CLIPTextModelWithProjection
2226

27+
input_t = Tuple[torch.Tensor]
28+
2329

24-
class TestCLIPTextModelWithProjection(unittest.TestCase):
30+
class TestCLIPTextModelWithProjection:
2531
"""
2632
Test class of CLIPTextModelWithProjection.
2733
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
@@ -69,47 +75,93 @@ def prepare_model_and_inputs(self):
6975

7076
return text_encoder_model, text_encoder_model_inputs
7177

72-
def test_CLIPTextModelWithProjection_tosa_FP(self):
73-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
74-
with torch.no_grad():
75-
(
76-
ArmTester(
77-
text_encoder_model,
78-
example_inputs=text_encoder_model_inputs,
79-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
80-
transform_passes=[
81-
ConvertInt64ConstOpsToInt32Pass(),
82-
ConvertInt64OutputOpsToInt32Pass(),
83-
InsertInt32CastsAfterInt64PlaceholdersPass(),
84-
],
85-
)
86-
.export()
87-
.to_edge_transform_and_lower()
88-
.dump_operator_distribution()
89-
.check_count(self.ops_after_partitioner_FP)
90-
.to_executorch()
91-
.run_method_and_compare_outputs(
92-
inputs=text_encoder_model_inputs,
93-
)
94-
)
95-
96-
def test_CLIPTextModelWithProjection_tosa_INT(self):
97-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
98-
with torch.no_grad():
99-
(
100-
ArmTester(
101-
text_encoder_model,
102-
example_inputs=text_encoder_model_inputs,
103-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
104-
)
105-
.quantize()
106-
.export()
107-
.to_edge_transform_and_lower()
108-
.dump_operator_distribution()
109-
.check_count(self.ops_after_partitioner_INT)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=text_encoder_model_inputs,
113-
atol=0.8,
114-
)
115-
)
78+
79+
def test_CLIPTextModelWithProjection_tosa_FP():
80+
text_encoder_model, text_encoder_model_inputs = (
81+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
82+
)
83+
with torch.no_grad():
84+
pipeline = TosaPipelineFP[input_t](
85+
text_encoder_model,
86+
text_encoder_model_inputs,
87+
aten_op=[],
88+
exir_op=[],
89+
use_to_edge_transform_and_lower=True,
90+
transform_passes=[
91+
ConvertInt64ConstOpsToInt32Pass(),
92+
ConvertInt64OutputOpsToInt32Pass(),
93+
InsertInt32CastsAfterInt64PlaceholdersPass(),
94+
],
95+
)
96+
pipeline.change_args(
97+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
98+
)
99+
pipeline.run()
100+
101+
102+
def test_CLIPTextModelWithProjection_tosa_INT():
103+
text_encoder_model, text_encoder_model_inputs = (
104+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
105+
)
106+
with torch.no_grad():
107+
pipeline = TosaPipelineINT[input_t](
108+
text_encoder_model,
109+
text_encoder_model_inputs,
110+
aten_op=[],
111+
exir_op=[],
112+
use_to_edge_transform_and_lower=True,
113+
atol=0.8,
114+
)
115+
pipeline.change_args(
116+
"check_count.exir",
117+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
118+
)
119+
pipeline.run()
120+
121+
122+
@common.SkipIfNoModelConverter
123+
def test_CLIPTextModelWithProjection_vgf_FP():
124+
text_encoder_model, text_encoder_model_inputs = (
125+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = VgfPipeline[input_t](
129+
text_encoder_model,
130+
text_encoder_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
tosa_version="TOSA-1.0+FP",
134+
use_to_edge_transform_and_lower=True,
135+
atol=4, # TODO: Investigate numerical issue: MAX Diff ~50%
136+
transform_passes=[
137+
ConvertInt64ConstOpsToInt32Pass(),
138+
ConvertInt64OutputOpsToInt32Pass(),
139+
InsertInt32CastsAfterInt64PlaceholdersPass(),
140+
],
141+
)
142+
pipeline.change_args(
143+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
144+
)
145+
pipeline.run()
146+
147+
148+
@common.SkipIfNoModelConverter
149+
def test_CLIPTextModelWithProjection_vgf_INT():
150+
text_encoder_model, text_encoder_model_inputs = (
151+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
152+
)
153+
with torch.no_grad():
154+
pipeline = VgfPipeline[input_t](
155+
text_encoder_model,
156+
text_encoder_model_inputs,
157+
aten_op=[],
158+
exir_op=[],
159+
tosa_version="TOSA-1.0+INT",
160+
use_to_edge_transform_and_lower=True,
161+
atol=0.8,
162+
)
163+
pipeline.change_args(
164+
"check_count.exir",
165+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
166+
)
167+
pipeline.run()

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 92 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from diffusers.models.transformers import SD3Transformer2DModel
@@ -13,10 +13,16 @@
1313
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1414
SD3Transformer2DModel_init_dict,
1515
)
16-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.backends.arm.test.tester.test_pipeline import (
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
VgfPipeline,
20+
)
21+
22+
input_t4 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
1723

1824

19-
class TestSD3Transformer2DModel(unittest.TestCase):
25+
class TestSD3Transformer2DModel:
2026
"""
2127
Test class of SD3Transformer2DModel.
2228
SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
@@ -93,48 +99,88 @@ def forward(self, *args, **kwargs):
9399

94100
return sd35_transformer2D_model, sd35_transformer2D_model_inputs
95101

96-
def test_SD3Transformer2DModel_tosa_FP(self):
97-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
98-
self.prepare_model_and_inputs()
99-
)
100-
with torch.no_grad():
101-
(
102-
ArmTester(
103-
sd35_transformer2D_model,
104-
example_inputs=sd35_transformer2D_model_inputs,
105-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
106-
)
107-
.export()
108-
.to_edge_transform_and_lower()
109-
.check_count(self.ops_after_partitioner_FP)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=sd35_transformer2D_model_inputs,
113-
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
114-
atol=4.0,
115-
)
116-
)
117102

118-
def test_SD3Transformer2DModel_tosa_INT(self):
119-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
120-
self.prepare_model_and_inputs()
103+
def test_SD3Transformer2DModel_tosa_FP():
104+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
105+
TestSD3Transformer2DModel().prepare_model_and_inputs()
106+
)
107+
with torch.no_grad():
108+
pipeline = TosaPipelineFP[input_t4](
109+
sd35_transformer2D_model,
110+
sd35_transformer2D_model_inputs,
111+
aten_op=[],
112+
exir_op=[],
113+
use_to_edge_transform_and_lower=True,
114+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
115+
atol=4.0,
121116
)
122-
with torch.no_grad():
123-
(
124-
ArmTester(
125-
sd35_transformer2D_model,
126-
example_inputs=sd35_transformer2D_model_inputs,
127-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
128-
)
129-
.quantize()
130-
.export()
131-
.to_edge_transform_and_lower()
132-
.check_count(self.ops_after_partitioner_INT)
133-
.to_executorch()
134-
.run_method_and_compare_outputs(
135-
inputs=sd35_transformer2D_model_inputs,
136-
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
137-
rtol=1.0,
138-
atol=4.0,
139-
)
140-
)
117+
pipeline.change_args(
118+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
119+
)
120+
pipeline.run()
121+
122+
123+
def test_SD3Transformer2DModel_tosa_INT():
124+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
125+
TestSD3Transformer2DModel().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = TosaPipelineINT[input_t4](
129+
sd35_transformer2D_model,
130+
sd35_transformer2D_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
use_to_edge_transform_and_lower=True,
134+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
135+
rtol=1.0,
136+
atol=4.0,
137+
)
138+
pipeline.change_args(
139+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
140+
)
141+
pipeline.run()
142+
143+
144+
@common.SkipIfNoModelConverter
145+
def test_SD3Transformer2DModel_vgf_FP():
146+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
147+
TestSD3Transformer2DModel().prepare_model_and_inputs()
148+
)
149+
with torch.no_grad():
150+
pipeline = VgfPipeline[input_t4](
151+
sd35_transformer2D_model,
152+
sd35_transformer2D_model_inputs,
153+
aten_op=[],
154+
exir_op=[],
155+
tosa_version="TOSA-1.0+FP",
156+
use_to_edge_transform_and_lower=True,
157+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
158+
atol=4.0,
159+
)
160+
pipeline.change_args(
161+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
162+
)
163+
pipeline.run()
164+
165+
166+
@common.SkipIfNoModelConverter
167+
def test_SD3Transformer2DModel_vgf_INT():
168+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
169+
TestSD3Transformer2DModel().prepare_model_and_inputs()
170+
)
171+
with torch.no_grad():
172+
pipeline = VgfPipeline[input_t4](
173+
sd35_transformer2D_model,
174+
sd35_transformer2D_model_inputs,
175+
aten_op=[],
176+
exir_op=[],
177+
tosa_version="TOSA-1.0+INT",
178+
use_to_edge_transform_and_lower=True,
179+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
180+
rtol=1.0,
181+
atol=4.0,
182+
)
183+
pipeline.change_args(
184+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
185+
)
186+
pipeline.run()

0 commit comments

Comments
 (0)