Skip to content

Commit 42a4656

Browse files
Arm backend: Add initial module tests for Stable Diffusion 3.5 Medium
- Add module tests for two text encoders: CLIP and T5 - Add module tests for VAE autoencoder - Add module tests for SD3Transformer2DModel - Add flag to exmaples/arm/setup.sh for installing Stable Diffusion dependencies - Handle int64 inputs to aten.slice_copy.Tensor using pass InsertCastForOpsWithInt64InputPass Change-Id: I4389e87749cfb4e40f837cc6bfa8a59a73fb3a3a Signed-off-by: Yufeng Shi <[email protected]>
1 parent f332196 commit 42a4656

File tree

10 files changed

+618
-13
lines changed

10 files changed

+618
-13
lines changed

backends/arm/README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ The you can run the tests with
104104
pytest -c /dev/null -v -n auto backends/arm/test
105105
```
106106

107+
### Model test dependencies
108+
Some model tests in Arm backend require third-party libraries or packages. To run these tests, you need to install the required dependencies by running the script `examples/arm/setup.sh` with the flag `--setup-test-dependency`.
109+
110+
Please note that installing model test dependencies is a standalone process. When using the `--setup-test-dependency` flag, the script will install only the necessary dependencies for model tests, skipping all other setup procedures.
111+
112+
List of models with specific dependencies:
113+
- Stable Diffusion: [diffusers](https://github.com/huggingface/diffusers/tree/main)
114+
107115
## Passes
108116

109117
With the default passes in the Arm Ethos-U backend, assuming the model lowers fully to the
@@ -189,7 +197,14 @@ Configuration of the EthosUBackend export flow is controlled by CompileSpec info
189197
As this is in active development see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324)
190198

191199
## Model specific and optional passes
192-
The current TOSA version does not support int64. For LLMs for example LLama, often aten.emedding is the first operator and it requires int64 indicies.
193-
In order to lower this to TOSA and int64->int32 cast need to be injected. This pass need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter. See example in: backends/arm/test/models/test_llama.py.
194-
By doing this aten.embedding will be decomposed into to aten.index_select which can handle int32 indices.
195-
Note that this additional step is only needed for pure float models. With quantization this is automatically handled during annotation before the export stage.
200+
The current TOSA version does not support int64. However, int64 is commonly used in many models. In order to lower the operators with int64 inputs and/or outputs to TOSA, a few passes have been developed to handle the int64-related issues. The main idea behind these passes is to replace the uses of int64 with int32 where feasible.
201+
- For floating-point models, these passes need to run very early in the lowering process and can be passed in to the to_edge_transform_and_lower() function call as an optional parameter.
202+
- For quantized models, these transformations will be automatically handled during annotation before the export stage.
203+
204+
List of model specific and optional passes:
205+
- InsertCastForOpsWithInt64InputPass
206+
- Functionality:
207+
- For LLMs such as LLama, some opeartors like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
208+
- Example usage: backends/arm/test/models/test_llama.py
209+
- Supported Ops:
210+
- aten.embedding.default, aten.slice_copy.Tensor

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,14 @@
2020

2121
class InsertCastForOpsWithInt64InputPass(ExportPass):
2222

23-
aten_ops = (torch.ops.aten.embedding.default,)
24-
edge_ops = (exir_ops.edge.aten.embedding.default,)
23+
aten_ops = (
24+
torch.ops.aten.embedding.default,
25+
torch.ops.aten.slice_copy.Tensor,
26+
)
27+
edge_ops = (
28+
exir_ops.edge.aten.embedding.default,
29+
exir_ops.edge.aten.slice_copy.Tensor,
30+
)
2531

2632
def get_decomposition(self, op):
2733
if op in self.edge_ops:
@@ -60,35 +66,59 @@ def call(self, graph_module):
6066
continue
6167

6268
args = node.args
63-
weights = args[0]
64-
indices = args[1]
6569

66-
valid_for_insert = False
6770
if node.target in (
6871
exir_ops.edge.aten.embedding.default,
6972
torch.ops.aten.embedding.default,
7073
):
74+
weights = args[0]
75+
indices = args[1]
7176
valid_for_insert = self._check_aten_embedding_within_int32(
7277
weights, indices, node
7378
)
7479

75-
if valid_for_insert:
80+
if valid_for_insert:
81+
to_copy_op = self.get_decomposition(node.target)
82+
with graph.inserting_before(node):
83+
cast_before = create_node(
84+
graph,
85+
to_copy_op,
86+
args=(indices,),
87+
kwargs={
88+
"dtype": torch.int32,
89+
"memory_format": torch.preserve_format,
90+
},
91+
)
92+
node.replace_input_with(indices, cast_before)
93+
94+
modified_graph = True
95+
96+
elif node.target in (
97+
exir_ops.edge.aten.slice_copy.Tensor,
98+
torch.ops.aten.slice_copy.Tensor,
99+
):
100+
# MLETORCH-829: Add range check for slice_copy
101+
input_tensor = args[0]
102+
fake_tensor = input_tensor.meta["val"]
103+
if fake_tensor.dtype != torch.int64:
104+
continue
105+
76106
to_copy_op = self.get_decomposition(node.target)
77107
with graph.inserting_before(node):
78108
cast_before = create_node(
79109
graph,
80110
to_copy_op,
81-
args=(indices,),
111+
args=(input_tensor,),
82112
kwargs={
83113
"dtype": torch.int32,
84114
"memory_format": torch.preserve_format,
85115
},
86116
)
87-
node.replace_input_with(indices, cast_before)
117+
node.replace_input_with(input_tensor, cast_before)
88118

89119
modified_graph = True
90120

91121
if modified_graph:
92122
graph_module.recompile()
93123
graph_module = super().call(graph_module).graph_module
94-
return PassResult(graph_module, True)
124+
return PassResult(graph_module, modified_graph)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env bash
2+
# Copyright 2025 Arm Limited and/or its affiliates.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
set -e
8+
9+
# Install diffusers for Stable Diffusion model test
10+
pip install "diffusers[torch]==0.33.1"
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
# Adapted from Hugging Face's diffusers library:
7+
# https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
8+
#
9+
# Licensed under the Apache License, Version 2.0
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
19+
from transformers import CLIPTextConfig, T5Config
20+
21+
22+
"""
23+
This file defines test configs used to initialize Stable Diffusion module tests.
24+
Module tests in the same directory will import these configs.
25+
26+
To stay aligned with the Stable Diffusion implementation in the HuggingFace Diffusers library,
27+
the configs here are either directly copied from corresponding test files or exported from
28+
pre-trained models used in the Diffusers library.
29+
30+
Licenses:
31+
The test parameters are from Hugging Face's diffusers library and under the Apache 2.0 License,
32+
while the remainder of the code is under the BSD-style license found in the LICENSE file in the
33+
root directory of this source tree.
34+
"""
35+
36+
37+
# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L56
38+
CLIP_text_encoder_config = CLIPTextConfig(
39+
bos_token_id=0,
40+
eos_token_id=2,
41+
hidden_size=32,
42+
intermediate_size=37,
43+
layer_norm_eps=1e-05,
44+
num_attention_heads=4,
45+
num_hidden_layers=5,
46+
pad_token_id=1,
47+
vocab_size=1000,
48+
hidden_act="gelu",
49+
projection_dim=32,
50+
)
51+
52+
53+
# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L76
54+
# Exported from: T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5").config
55+
T5_encoder_config = T5Config(
56+
bos_token_id=0,
57+
classifier_dropout=0.0,
58+
d_ff=37,
59+
d_kv=8,
60+
d_model=32,
61+
decoder_start_token_id=0,
62+
dense_act_fn="relu",
63+
dropout_rate=0.1,
64+
eos_token_id=1,
65+
feed_forward_proj="relu",
66+
gradient_checkpointing=False,
67+
initializer_factor=0.002,
68+
is_encoder_decoder=True,
69+
is_gated_act=False,
70+
layer_norm_epsilon=1e-06,
71+
model_type="t5",
72+
num_decoder_layers=5,
73+
num_heads=4,
74+
num_layers=5,
75+
pad_token_id=0,
76+
relative_attention_max_distance=128,
77+
relative_attention_num_buckets=8,
78+
transformers_version="4.47.1",
79+
vocab_size=1000,
80+
)
81+
82+
83+
# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/models/transformers/test_models_transformer_sd3.py#L142
84+
SD3Transformer2DModel_init_dict = {
85+
"sample_size": 32,
86+
"patch_size": 1,
87+
"in_channels": 4,
88+
"num_layers": 4,
89+
"attention_head_dim": 8,
90+
"num_attention_heads": 4,
91+
"caption_projection_dim": 32,
92+
"joint_attention_dim": 32,
93+
"pooled_projection_dim": 64,
94+
"out_channels": 4,
95+
"pos_embed_max_size": 96,
96+
"dual_attention_layers": (0,),
97+
"qk_norm": "rms_norm",
98+
}
99+
100+
101+
# Source: https://github.com/huggingface/diffusers/blob/v0.33.1/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py#L83
102+
AutoencoderKL_config = {
103+
"sample_size": 32,
104+
"in_channels": 3,
105+
"out_channels": 3,
106+
"block_out_channels": (4,),
107+
"layers_per_block": 1,
108+
"latent_channels": 4,
109+
"norm_num_groups": 1,
110+
"use_quant_conv": False,
111+
"use_post_quant_conv": False,
112+
"shift_factor": 0.0609,
113+
"scaling_factor": 1.5035,
114+
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm._passes import InsertCastForOpsWithInt64InputPass
11+
12+
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.test.models.stable_diffusion.stable_diffusion_module_test_configs import (
14+
CLIP_text_encoder_config,
15+
)
16+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
17+
from transformers import CLIPTextModelWithProjection
18+
19+
20+
class TestCLIPTextModelWithProjection(unittest.TestCase):
21+
"""
22+
Test class of CLIPTextModelWithProjection.
23+
CLIPTextModelWithProjection is one of the text_encoder used by Stable Diffusion 3.5 Medium
24+
"""
25+
26+
# Adjust nbr below as we increase op support. Note: most of the delegates
27+
# calls are directly consecutive to each other in the .pte. The reason
28+
# for that is some assert ops are removed by passes in the
29+
# .to_executorch step, i.e. after Arm partitioner.
30+
ops_after_partitioner = {
31+
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 3,
32+
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
33+
"executorch_exir_dialects_edge__ops_aten_index_Tensor": 1,
34+
"executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1,
35+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
36+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
37+
"torch.ops.higher_order.executorch_call_delegate": 3,
38+
}
39+
40+
def _prepare_inputs(
41+
self,
42+
batch_size=12,
43+
seq_length=7,
44+
vocab_size=1000,
45+
):
46+
input_ids = torch.randint(
47+
low=0,
48+
high=vocab_size,
49+
size=(batch_size, seq_length),
50+
dtype=torch.long,
51+
)
52+
return (input_ids,)
53+
54+
def prepare_model_and_inputs(self):
55+
clip_text_encoder_config = CLIP_text_encoder_config
56+
57+
text_encoder_model = CLIPTextModelWithProjection(clip_text_encoder_config)
58+
text_encoder_model.eval()
59+
text_encoder_model_inputs = self._prepare_inputs()
60+
61+
return text_encoder_model, text_encoder_model_inputs
62+
63+
def test_CLIPTextModelWithProjection_tosa_MI(self):
64+
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
65+
with torch.no_grad():
66+
(
67+
ArmTester(
68+
text_encoder_model,
69+
example_inputs=text_encoder_model_inputs,
70+
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
71+
transform_passes=[InsertCastForOpsWithInt64InputPass()],
72+
)
73+
.export()
74+
.to_edge_transform_and_lower()
75+
.dump_operator_distribution()
76+
.check_count(self.ops_after_partitioner)
77+
.to_executorch()
78+
.run_method_and_compare_outputs(
79+
inputs=text_encoder_model_inputs,
80+
)
81+
)
82+
83+
# MLETORCH-867, MLETORCH-1059
84+
# Failures: "Fatal Python error: Aborted, Dependency cycles, KeyError in CastInt64BuffersToInt32Pass")
85+
@unittest.expectedFailure
86+
def test_CLIPTextModelWithProjection_tosa_BI(self):
87+
text_encoder_model, text_encoder_model_inputs = self.prepare_model_and_inputs()
88+
with torch.no_grad():
89+
(
90+
ArmTester(
91+
text_encoder_model,
92+
example_inputs=text_encoder_model_inputs,
93+
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+INT"),
94+
)
95+
.quantize()
96+
.export()
97+
.to_edge_transform_and_lower()
98+
.dump_operator_distribution()
99+
.to_executorch()
100+
.run_method_and_compare_outputs(
101+
inputs=text_encoder_model_inputs,
102+
)
103+
)

0 commit comments

Comments
 (0)