
Commit 119c28c

Arm backend: Add initial module tests for Stable Diffusion 3.5 Medium (#12242)
- Add module tests for two text encoders: CLIP and T5
- Add module tests for the VAE autoencoder
- Add module tests for SD3Transformer2DModel
- Add a flag to examples/arm/setup.sh for installing Stable Diffusion dependencies
- Handle int64 inputs to aten.slice_copy.Tensor using the pass InsertCastForOpsWithInt64InputPass

Signed-off-by: Yufeng Shi <[email protected]>
1 parent 5338708 commit 119c28c

10 files changed, +618 -27 lines

backends/arm/README.md

Lines changed: 19 additions & 4 deletions
@@ -104,6 +104,14 @@ Then you can run the tests with
 pytest -c /dev/null -v -n auto backends/arm/test
 ```
 
+### Model test dependencies
+Some model tests in the Arm backend require third-party libraries or packages. To run these tests, install the required dependencies by running the script `examples/arm/setup.sh` with the flag `--setup-test-dependency`.
+
+Please note that installing model test dependencies is a standalone process. When using the `--setup-test-dependency` flag, the script installs only the dependencies needed for model tests and skips all other setup procedures.
+
+List of models with specific dependencies:
+- Stable Diffusion: [diffusers](https://github.com/huggingface/diffusers/tree/main)
+
 ## Passes
 
 With the default passes in the Arm Ethos-U backend, assuming the model lowers fully to the
@@ -189,7 +197,14 @@ Configuration of the EthosUBackend export flow is controlled by CompileSpec info
 As this is in active development see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324)
 
 ## Model specific and optional passes
-The current TOSA version does not support int64. For LLMs for example LLama, often aten.emedding is the first operator and it requires int64 indicies.
-In order to lower this to TOSA and int64->int32 cast need to be injected. This pass need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. See example in: backends/arm/test/models/test_llama.py.
-By doing this aten.embedding will be decomposed into to aten.index_select which can handle int32 indices.
-Note that this additional step is only needed for pure float models. With quantization this is automatically handled during annotation before the export stage.
+The current TOSA version does not support int64. However, int64 is commonly used in many models. To lower operators with int64 inputs and/or outputs to TOSA, a few passes have been developed to handle the int64-related issues. The main idea behind these passes is to replace the uses of int64 with int32 where feasible.
+- For floating-point models, these passes need to run very early in the lowering process and can be passed to the to_edge_transform_and_lower() function call as an optional parameter (see the sketch after this diff).
+- For quantized models, these transformations are handled automatically during annotation, before the export stage.
+
+List of model specific and optional passes:
+- InsertCastForOpsWithInt64InputPass
+  - Functionality:
+    - For LLMs such as LLama, some operators like aten.embedding take int64 input. To lower these operators to TOSA, this pass inserts a casting node that converts the input from int64 to int32.
+  - Example usage: backends/arm/test/models/test_llama.py
+  - Supported Ops:
+    - aten.embedding.default, aten.slice_copy.Tensor
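
The README text above says the cast pass is handed to to_edge_transform_and_lower() for floating-point models. Below is a minimal sketch of that wiring, not taken from the commit: the toy EmbeddingModule, its shapes, and the omission of a partitioner are illustrative assumptions; the import paths are the ones used elsewhere in this commit.

```python
import torch

from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
from executorch.exir import to_edge_transform_and_lower


class EmbeddingModule(torch.nn.Module):
    """Toy float model whose first op is aten.embedding with int64 indices."""

    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings=1000, embedding_dim=32)

    def forward(self, ids):
        return self.emb(ids)


# Tokenizers produce int64 ids, which TOSA cannot represent.
example_inputs = (torch.randint(0, 1000, (1, 7), dtype=torch.long),)
exported = torch.export.export(EmbeddingModule().eval(), example_inputs)

# The int64 -> int32 cast must be injected before partitioning, so the pass
# is supplied to to_edge_transform_and_lower() rather than run afterwards.
edge = to_edge_transform_and_lower(
    exported,
    transform_passes=[InsertCastForOpsWithInt64InputPass()],
)
```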

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 39 additions & 23 deletions
@@ -20,8 +20,14 @@
 
 class InsertCastForOpsWithInt64InputPass(ExportPass):
 
-    aten_ops = (torch.ops.aten.embedding.default,)
-    edge_ops = (exir_ops.edge.aten.embedding.default,)
+    aten_ops = (
+        torch.ops.aten.embedding.default,
+        torch.ops.aten.slice_copy.Tensor,
+    )
+    edge_ops = (
+        exir_ops.edge.aten.embedding.default,
+        exir_ops.edge.aten.slice_copy.Tensor,
+    )
 
     def get_decomposition(self, op):
         if op in self.edge_ops:
@@ -49,6 +55,20 @@ def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.Node)
 
         return True
 
+    def _insert_int32_cast_before_node(self, graph, node, original_input):
+        to_copy_op = self.get_decomposition(node.target)
+        with graph.inserting_before(node):
+            cast_before = create_node(
+                graph,
+                to_copy_op,
+                args=(original_input,),
+                kwargs={
+                    "dtype": torch.int32,
+                    "memory_format": torch.preserve_format,
+                },
+            )
+            node.replace_input_with(original_input, cast_before)
+
     def call(self, graph_module):
         graph = graph_module.graph
         modified_graph = False
@@ -60,35 +80,31 @@ def call(self, graph_module):
                 continue
 
             args = node.args
-            weights = args[0]
-            indices = args[1]
 
-            valid_for_insert = False
             if node.target in (
                 exir_ops.edge.aten.embedding.default,
                 torch.ops.aten.embedding.default,
             ):
-                valid_for_insert = self._check_aten_embedding_within_int32(
-                    weights, indices, node
-                )
-
-            if valid_for_insert:
-                to_copy_op = self.get_decomposition(node.target)
-                with graph.inserting_before(node):
-                    cast_before = create_node(
-                        graph,
-                        to_copy_op,
-                        args=(indices,),
-                        kwargs={
-                            "dtype": torch.int32,
-                            "memory_format": torch.preserve_format,
-                        },
-                    )
-                    node.replace_input_with(indices, cast_before)
+                weights = args[0]
+                indices = args[1]
+                if self._check_aten_embedding_within_int32(weights, indices, node):
+                    self._insert_int32_cast_before_node(graph, node, indices)
+                    modified_graph = True
+
+            elif node.target in (
+                exir_ops.edge.aten.slice_copy.Tensor,
+                torch.ops.aten.slice_copy.Tensor,
+            ):
+                # MLETORCH-829: Add range check for slice_copy
+                input_tensor = args[0]
+                fake_tensor = input_tensor.meta["val"]
+                if fake_tensor.dtype != torch.int64:
+                    continue
 
+                self._insert_int32_cast_before_node(graph, node, input_tensor)
                 modified_graph = True
 
         if modified_graph:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
-        return PassResult(graph_module, True)
+        return PassResult(graph_module, modified_graph)
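
To see what the rewritten pass does in isolation, here is a minimal sketch (an assumption, not part of the commit) that exports a toy module with an int64 slice_copy and invokes the pass directly on its graph module; the PassResult's modified flag now reflects whether a cast was actually inserted, matching the new return statement above.

```python
import torch

from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass


class SliceModule(torch.nn.Module):
    def forward(self, x):
        # slice_copy on an int64 tensor: the pass should cast x to int32 first.
        return torch.slice_copy(x, dim=1, start=0, end=4)


exported = torch.export.export(SliceModule(), (torch.zeros(2, 8, dtype=torch.int64),))

# ExportPass subclasses are callable on a GraphModule and return a PassResult.
result = InsertCastForOpsWithInt64InputPass()(exported.graph_module)
print(result.modified)  # expected True: an int64 -> int32 _to_copy was inserted
result.graph_module.graph.print_tabular()
```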
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -e
+
+# Install diffusers for Stable Diffusion model test
+pip install "diffusers[torch]==0.33.1"
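
The new script is wired in behind the `--setup-test-dependency` flag described in the README change above, so the intended entry point is the setup script rather than this file directly:

```bash
# Install only the model-test dependencies; all other setup steps are skipped.
./examples/arm/setup.sh --setup-test-dependency
```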
backends/arm/test/models/stable_diffusion/stable_diffusion_module_test_configs.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Adapted from Hugging Face's diffusers library:
+# https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+#
+# Licensed under the Apache License, Version 2.0
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from transformers import CLIPTextConfig, T5Config
+
+
+"""
+This file defines test configs used to initialize Stable Diffusion module tests.
+Module tests in the same directory will import these configs.
+
+To stay aligned with the Stable Diffusion implementation in the HuggingFace Diffusers library,
+the configs here are either directly copied from corresponding test files or exported from
+pre-trained models used in the Diffusers library.
+
+Licenses:
+The test parameters are from Hugging Face's diffusers library and under the Apache 2.0 License,
+while the remainder of the code is under the BSD-style license found in the LICENSE file in the
+root directory of this source tree.
+"""
+
+
+# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L56
+CLIP_text_encoder_config = CLIPTextConfig(
+    bos_token_id=0,
+    eos_token_id=2,
+    hidden_size=32,
+    intermediate_size=37,
+    layer_norm_eps=1e-05,
+    num_attention_heads=4,
+    num_hidden_layers=5,
+    pad_token_id=1,
+    vocab_size=1000,
+    hidden_act="gelu",
+    projection_dim=32,
+)
+
+
+# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L76
+# Exported from: T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5").config
+T5_encoder_config = T5Config(
+    bos_token_id=0,
+    classifier_dropout=0.0,
+    d_ff=37,
+    d_kv=8,
+    d_model=32,
+    decoder_start_token_id=0,
+    dense_act_fn="relu",
+    dropout_rate=0.1,
+    eos_token_id=1,
+    feed_forward_proj="relu",
+    gradient_checkpointing=False,
+    initializer_factor=0.002,
+    is_encoder_decoder=True,
+    is_gated_act=False,
+    layer_norm_epsilon=1e-06,
+    model_type="t5",
+    num_decoder_layers=5,
+    num_heads=4,
+    num_layers=5,
+    pad_token_id=0,
+    relative_attention_max_distance=128,
+    relative_attention_num_buckets=8,
+    transformers_version="4.47.1",
+    vocab_size=1000,
+)
+
+
+# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/models/transformers/test_models_transformer_sd3.py#L142
+SD3Transformer2DModel_init_dict = {
+    "sample_size": 32,
+    "patch_size": 1,
+    "in_channels": 4,
+    "num_layers": 4,
+    "attention_head_dim": 8,
+    "num_attention_heads": 4,
+    "caption_projection_dim": 32,
+    "joint_attention_dim": 32,
+    "pooled_projection_dim": 64,
+    "out_channels": 4,
+    "pos_embed_max_size": 96,
+    "dual_attention_layers": (0,),
+    "qk_norm": "rms_norm",
+}
+
+
+# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L83
+AutoencoderKL_config = {
+    "sample_size": 32,
+    "in_channels": 3,
+    "out_channels": 3,
+    "block_out_channels": (4,),
+    "layers_per_block": 1,
+    "latent_channels": 4,
+    "norm_num_groups": 1,
+    "use_quant_conv": False,
+    "use_post_quant_conv": False,
+    "shift_factor": 0.0609,
+    "scaling_factor": 1.5035,
+}
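
As a usage sketch (not part of the commit): the two encoder configs are consumed as transformers config objects and the two dicts as diffusers constructor kwargs. The CLIPTextModelWithProjection call mirrors the test file below; the T5EncoderModel, SD3Transformer2DModel, and AutoencoderKL calls are assumptions based on the transformers/diffusers APIs at the pinned versions.

```python
from diffusers import AutoencoderKL, SD3Transformer2DModel
from transformers import CLIPTextModelWithProjection, T5EncoderModel

from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
    AutoencoderKL_config,
    CLIP_text_encoder_config,
    SD3Transformer2DModel_init_dict,
    T5_encoder_config,
)

# transformers models take a config object; diffusers models take kwargs.
clip_encoder = CLIPTextModelWithProjection(CLIP_text_encoder_config).eval()
t5_encoder = T5EncoderModel(T5_encoder_config).eval()
transformer = SD3Transformer2DModel(**SD3Transformer2DModel_init_dict).eval()
vae = AutoencoderKL(**AutoencoderKL_config).eval()
```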
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
+    CLIP_text_encoder_config,
+)
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from transformers import CLIPTextModelWithProjection
+
+
+class TestCLIPTextModelWithProjection(unittest.TestCase):
+    """
+    Test class of CLIPTextModelWithProjection.
+    CLIPTextModelWithProjection is one of the text encoders used by Stable Diffusion 3.5 Medium.
+    """
+
+    # Adjust the numbers below as we increase op support. Note: most of the
+    # delegate calls are directly consecutive to each other in the .pte. The
+    # reason for that is some assert ops are removed by passes in the
+    # .to_executorch step, i.e. after the Arm partitioner.
+    ops_after_partitioner = {
+        "executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
+        "executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
+        "executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
+    }
+
+    def _prepare_inputs(
+        self,
+        batch_size=12,
+        seq_length=7,
+        vocab_size=1000,
+    ):
+        input_ids = torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(batch_size, seq_length),
+            dtype=torch.long,
+        )
+        return (input_ids,)
+
+    def prepare_model_and_inputs(self):
+        clip_text_encoder_config = CLIP_text_encoder_config
+
+        text_encoder_model = CLIPTextModelWithProjection(clip_text_encoder_config)
+        text_encoder_model.eval()
+        text_encoder_model_inputs = self._prepare_inputs()
+
+        return text_encoder_model, text_encoder_model_inputs
+
+    def test_CLIPTextModelWithProjection_tosa_MI(self):
+        text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
+        with torch.no_grad():
+            (
+                ArmTester(
+                    text_encoder_model,
+                    example_inputs=text_encoder_model_inputs,
+                    compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
+                    transform_passes=[InsertCastForOpsWithInt64InputPass()],
+                )
+                .export()
+                .to_edge_transform_and_lower()
+                .dump_operator_distribution()
+                .check_count(self.ops_after_partitioner)
+                .to_executorch()
+                .run_method_and_compare_outputs(
+                    inputs=text_encoder_model_inputs,
+                )
+            )
+
+    # MLETORCH-867, MLETORCH-1059
+    # Failures: "Fatal Python error: Aborted, Dependency cycles, KeyError in CastInt64BuffersToInt32Pass"
+    @unittest.expectedFailure
+    def test_CLIPTextModelWithProjection_tosa_BI(self):
+        text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
+        with torch.no_grad():
+            (
+                ArmTester(
+                    text_encoder_model,
+                    example_inputs=text_encoder_model_inputs,
+                    compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
+                )
+                .quantize()
+                .export()
+                .to_edge_transform_and_lower()
+                .dump_operator_distribution()
+                .to_executorch()
+                .run_method_and_compare_outputs(
+                    inputs=text_encoder_model_inputs,
+                )
+            )
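
Assuming the module tests live in the backends/arm/test/models/stable_diffusion package used by the config import above, they can be run with the pytest invocation the README documents, once the test dependencies are installed:

```bash
pytest -c /dev/null -v backends/arm/test/models/stable_diffusion
```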
