Skip to content

Commit 3aad6a9

Browse files
committed
Update on "Introduce public MergedDataMap"
Add public merged data map. Module can use this to resolve multiple named data maps. Differential Revision: [D83527299](https://our.internmc.facebook.com/intern/diff/D83527299/) [ghstack-poisoned]
2 parents a00bab7 + e96e35c commit 3aad6a9

File tree

75 files changed

+1479
-643
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+1479
-643
lines changed

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,11 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
630630
list(APPEND _executorch_extensions extension_module_static)
631631
endif()
632632

633+
if(EXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP)
634+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/named_data_map)
635+
list(APPEND _executorch_extensions extension_named_data_map)
636+
endif()
637+
633638
if(EXECUTORCH_BUILD_EXTENSION_LLM)
634639
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
635640
set(SUPPORT_REGEX_LOOKAHEAD ON)

backends/arm/operators/op_bmm.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ def define_node(
7979
input1_zp = input_qparams[1].get_zp_per_tensor()
8080
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
8181
bmm_output_name = bmm_result.name
82+
elif inputs[0].dtype == ts.DType.INT16:
83+
input_qparams = get_input_qparams(node)
84+
input0_zp = input_qparams[0].get_zp_per_tensor()
85+
input1_zp = input_qparams[1].get_zp_per_tensor()
86+
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48)
87+
bmm_output_name = bmm_result.name
8288
else:
8389
bmm_output_name = output.name
8490
input0_zp, input1_zp = 0, 0
@@ -118,3 +124,20 @@ def define_node(
118124
output_zp=[output_qparams.get_zp_per_tensor()],
119125
rounding_mode=RoundingMode.SINGLE_ROUND,
120126
)
127+
elif output.dtype == ts.DType.INT16:
128+
output_qparams = get_output_qparams(node)[0]
129+
final_output_scale = (
130+
input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61]
131+
) / output_qparams.get_scale_per_tensor()
132+
133+
build_rescale(
134+
tosa_fb=tosa_graph,
135+
scale=[final_output_scale],
136+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
137+
input_node=bmm_result, # type: ignore[possibly-undefined]
138+
output_name=output.name,
139+
output_type=ts.DType.INT16,
140+
input_zp=[0],
141+
output_zp=[output_qparams.get_zp_per_tensor()],
142+
rounding_mode=RoundingMode.SINGLE_ROUND,
143+
)

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from executorch.backends.arm._passes import (
@@ -17,11 +17,17 @@
1717
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1818
CLIP_text_encoder_config,
1919
)
20-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
20+
from executorch.backends.arm.test.tester.test_pipeline import (
21+
TosaPipelineFP,
22+
TosaPipelineINT,
23+
VgfPipeline,
24+
)
2125
from transformers import CLIPTextModelWithProjection
2226

27+
input_t = Tuple[torch.Tensor]
28+
2329

24-
class TestCLIPTextModelWithProjection(unittest.TestCase):
30+
class TestCLIPTextModelWithProjection:
2531
"""
2632
Test class of CLIPTextModelWithProjection.
2733
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
@@ -69,47 +75,93 @@ def prepare_model_and_inputs(self):
6975

7076
return text_encoder_model, text_encoder_model_inputs
7177

72-
def test_CLIPTextModelWithProjection_tosa_FP(self):
73-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
74-
with torch.no_grad():
75-
(
76-
ArmTester(
77-
text_encoder_model,
78-
example_inputs=text_encoder_model_inputs,
79-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
80-
transform_passes=[
81-
ConvertInt64ConstOpsToInt32Pass(),
82-
ConvertInt64OutputOpsToInt32Pass(),
83-
InsertInt32CastsAfterInt64PlaceholdersPass(),
84-
],
85-
)
86-
.export()
87-
.to_edge_transform_and_lower()
88-
.dump_operator_distribution()
89-
.check_count(self.ops_after_partitioner_FP)
90-
.to_executorch()
91-
.run_method_and_compare_outputs(
92-
inputs=text_encoder_model_inputs,
93-
)
94-
)
95-
96-
def test_CLIPTextModelWithProjection_tosa_INT(self):
97-
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
98-
with torch.no_grad():
99-
(
100-
ArmTester(
101-
text_encoder_model,
102-
example_inputs=text_encoder_model_inputs,
103-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
104-
)
105-
.quantize()
106-
.export()
107-
.to_edge_transform_and_lower()
108-
.dump_operator_distribution()
109-
.check_count(self.ops_after_partitioner_INT)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=text_encoder_model_inputs,
113-
atol=0.8,
114-
)
115-
)
78+
79+
def test_CLIPTextModelWithProjection_tosa_FP():
80+
text_encoder_model, text_encoder_model_inputs = (
81+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
82+
)
83+
with torch.no_grad():
84+
pipeline = TosaPipelineFP[input_t](
85+
text_encoder_model,
86+
text_encoder_model_inputs,
87+
aten_op=[],
88+
exir_op=[],
89+
use_to_edge_transform_and_lower=True,
90+
transform_passes=[
91+
ConvertInt64ConstOpsToInt32Pass(),
92+
ConvertInt64OutputOpsToInt32Pass(),
93+
InsertInt32CastsAfterInt64PlaceholdersPass(),
94+
],
95+
)
96+
pipeline.change_args(
97+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
98+
)
99+
pipeline.run()
100+
101+
102+
def test_CLIPTextModelWithProjection_tosa_INT():
103+
text_encoder_model, text_encoder_model_inputs = (
104+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
105+
)
106+
with torch.no_grad():
107+
pipeline = TosaPipelineINT[input_t](
108+
text_encoder_model,
109+
text_encoder_model_inputs,
110+
aten_op=[],
111+
exir_op=[],
112+
use_to_edge_transform_and_lower=True,
113+
atol=0.8,
114+
)
115+
pipeline.change_args(
116+
"check_count.exir",
117+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
118+
)
119+
pipeline.run()
120+
121+
122+
@common.SkipIfNoModelConverter
123+
def test_CLIPTextModelWithProjection_vgf_FP():
124+
text_encoder_model, text_encoder_model_inputs = (
125+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = VgfPipeline[input_t](
129+
text_encoder_model,
130+
text_encoder_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
tosa_version="TOSA-1.0+FP",
134+
use_to_edge_transform_and_lower=True,
135+
atol=4,  # TODO: Investigate numerical issue: MAX Diff ~50%
136+
transform_passes=[
137+
ConvertInt64ConstOpsToInt32Pass(),
138+
ConvertInt64OutputOpsToInt32Pass(),
139+
InsertInt32CastsAfterInt64PlaceholdersPass(),
140+
],
141+
)
142+
pipeline.change_args(
143+
"check_count.exir", TestCLIPTextModelWithProjection.ops_after_partitioner_FP
144+
)
145+
pipeline.run()
146+
147+
148+
@common.SkipIfNoModelConverter
149+
def test_CLIPTextModelWithProjection_vgf_INT():
150+
text_encoder_model, text_encoder_model_inputs = (
151+
TestCLIPTextModelWithProjection().prepare_model_and_inputs()
152+
)
153+
with torch.no_grad():
154+
pipeline = VgfPipeline[input_t](
155+
text_encoder_model,
156+
text_encoder_model_inputs,
157+
aten_op=[],
158+
exir_op=[],
159+
tosa_version="TOSA-1.0+INT",
160+
use_to_edge_transform_and_lower=True,
161+
atol=0.8,
162+
)
163+
pipeline.change_args(
164+
"check_count.exir",
165+
TestCLIPTextModelWithProjection.ops_after_partitioner_INT,
166+
)
167+
pipeline.run()

backends/arm/test/models/stable_diffusion/test_SD3Transformer2DModel.py

Lines changed: 92 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import unittest
7+
from typing import Tuple
88

99
import torch
1010
from diffusers.models.transformers import SD3Transformer2DModel
@@ -13,10 +13,16 @@
1313
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
1414
SD3Transformer2DModel_init_dict,
1515
)
16-
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.backends.arm.test.tester.test_pipeline import (
17+
TosaPipelineFP,
18+
TosaPipelineINT,
19+
VgfPipeline,
20+
)
21+
22+
input_t4 = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
1723

1824

19-
class TestSD3Transformer2DModel(unittest.TestCase):
25+
class TestSD3Transformer2DModel:
2026
"""
2127
Test class of SD3Transformer2DModel.
2228
SD3Transformer2DModel is the transformer model used by Stable Diffusion 3.5 Medium
@@ -93,48 +99,88 @@ def forward(self, *args, **kwargs):
9399

94100
return sd35_transformer2D_model, sd35_transformer2D_model_inputs
95101

96-
def test_SD3Transformer2DModel_tosa_FP(self):
97-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
98-
self.prepare_model_and_inputs()
99-
)
100-
with torch.no_grad():
101-
(
102-
ArmTester(
103-
sd35_transformer2D_model,
104-
example_inputs=sd35_transformer2D_model_inputs,
105-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
106-
)
107-
.export()
108-
.to_edge_transform_and_lower()
109-
.check_count(self.ops_after_partitioner_FP)
110-
.to_executorch()
111-
.run_method_and_compare_outputs(
112-
inputs=sd35_transformer2D_model_inputs,
113-
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
114-
atol=4.0,
115-
)
116-
)
117102

118-
def test_SD3Transformer2DModel_tosa_INT(self):
119-
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
120-
self.prepare_model_and_inputs()
103+
def test_SD3Transformer2DModel_tosa_FP():
104+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
105+
TestSD3Transformer2DModel().prepare_model_and_inputs()
106+
)
107+
with torch.no_grad():
108+
pipeline = TosaPipelineFP[input_t4](
109+
sd35_transformer2D_model,
110+
sd35_transformer2D_model_inputs,
111+
aten_op=[],
112+
exir_op=[],
113+
use_to_edge_transform_and_lower=True,
114+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
115+
atol=4.0,
121116
)
122-
with torch.no_grad():
123-
(
124-
ArmTester(
125-
sd35_transformer2D_model,
126-
example_inputs=sd35_transformer2D_model_inputs,
127-
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
128-
)
129-
.quantize()
130-
.export()
131-
.to_edge_transform_and_lower()
132-
.check_count(self.ops_after_partitioner_INT)
133-
.to_executorch()
134-
.run_method_and_compare_outputs(
135-
inputs=sd35_transformer2D_model_inputs,
136-
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
137-
rtol=1.0,
138-
atol=4.0,
139-
)
140-
)
117+
pipeline.change_args(
118+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
119+
)
120+
pipeline.run()
121+
122+
123+
def test_SD3Transformer2DModel_tosa_INT():
124+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
125+
TestSD3Transformer2DModel().prepare_model_and_inputs()
126+
)
127+
with torch.no_grad():
128+
pipeline = TosaPipelineINT[input_t4](
129+
sd35_transformer2D_model,
130+
sd35_transformer2D_model_inputs,
131+
aten_op=[],
132+
exir_op=[],
133+
use_to_edge_transform_and_lower=True,
134+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
135+
rtol=1.0,
136+
atol=4.0,
137+
)
138+
pipeline.change_args(
139+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
140+
)
141+
pipeline.run()
142+
143+
144+
@common.SkipIfNoModelConverter
145+
def test_SD3Transformer2DModel_vgf_FP():
146+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
147+
TestSD3Transformer2DModel().prepare_model_and_inputs()
148+
)
149+
with torch.no_grad():
150+
pipeline = VgfPipeline[input_t4](
151+
sd35_transformer2D_model,
152+
sd35_transformer2D_model_inputs,
153+
aten_op=[],
154+
exir_op=[],
155+
tosa_version="TOSA-1.0+FP",
156+
use_to_edge_transform_and_lower=True,
157+
rtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
158+
atol=4.0,
159+
)
160+
pipeline.change_args(
161+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_FP
162+
)
163+
pipeline.run()
164+
165+
166+
@common.SkipIfNoModelConverter
167+
def test_SD3Transformer2DModel_vgf_INT():
168+
sd35_transformer2D_model, sd35_transformer2D_model_inputs = (
169+
TestSD3Transformer2DModel().prepare_model_and_inputs()
170+
)
171+
with torch.no_grad():
172+
pipeline = VgfPipeline[input_t4](
173+
sd35_transformer2D_model,
174+
sd35_transformer2D_model_inputs,
175+
aten_op=[],
176+
exir_op=[],
177+
tosa_version="TOSA-1.0+INT",
178+
use_to_edge_transform_and_lower=True,
179+
qtol=1.0, # TODO: MLETORCH-875: Reduce tolerance of SD3Transformer2DModel with FP and INT
180+
rtol=1.0,
181+
atol=4.0,
182+
)
183+
pipeline.change_args(
184+
"check_count.exir", TestSD3Transformer2DModel.ops_after_partitioner_INT
185+
)
186+
pipeline.run()

0 commit comments

Comments
 (0)