Skip to content

Commit 4756522

Browse files
committed
update
1 parent d108c18 commit 4756522

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

tests/models/test_modeling_common.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,21 +1169,19 @@ def test_disk_offload_without_safetensors(self):
11691169
base_output = model(**inputs_dict)
11701170

11711171
model_size = compute_module_sizes(model)[""]
1172+
max_size = int(self.model_split_percents[0] * model_size)
1173+
# Force disk offload by setting very small CPU memory
1174+
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
1175+
11721176
with tempfile.TemporaryDirectory() as tmp_dir:
11731177
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
1174-
11751178
with self.assertRaises(ValueError):
1176-
max_size = int(self.model_split_percents[0] * model_size)
1177-
max_memory = {0: max_size, "cpu": max_size}
11781179
# This errors out because it's missing an offload folder
11791180
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
11801181

1181-
max_size = int(self.model_split_percents[0] * model_size)
1182-
max_memory = {0: max_size, "disk": max_size}
11831182
new_model = self.model_class.from_pretrained(
11841183
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
11851184
)
1186-
__import__("ipdb").set_trace()
11871185

11881186
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
11891187
torch.manual_seed(0)

tests/models/transformers/test_models_transformer_omnigen.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class OmniGenTransformerTests(ModelTesterMixin, unittest.TestCase):
3030
model_class = OmniGenTransformer2DModel
3131
main_input_name = "hidden_states"
3232
uses_custom_attn_processor = True
33+
model_split_percents = [0.1, 0.1, 0.1]
3334

3435
@property
3536
def dummy_input(self):
@@ -73,9 +74,9 @@ def prepare_init_args_and_inputs_for_common(self):
7374
"num_attention_heads": 4,
7475
"num_key_value_heads": 4,
7576
"intermediate_size": 32,
76-
"num_layers": 1,
77+
"num_layers": 20,
7778
"pad_token_id": 0,
78-
"vocab_size": 100,
79+
"vocab_size": 1000,
7980
"in_channels": 4,
8081
"time_step_dim": 4,
8182
"rope_scaling": {"long_factor": list(range(1, 3)), "short_factor": list(range(1, 3))},

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
3434
model_class = SD3Transformer2DModel
3535
main_input_name = "hidden_states"
36-
model_split_percents = [0.4, 0.8, 0.9]
36+
model_split_percents = [0.8, 0.8, 0.9]
3737

3838
@property
3939
def dummy_input(self):
@@ -108,7 +108,7 @@ def test_gradient_checkpointing_is_applied(self):
108108
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
109109
model_class = SD3Transformer2DModel
110110
main_input_name = "hidden_states"
111-
model_split_percents = [0.4, 0.8, 0.9]
111+
model_split_percents = [0.8, 0.8, 0.9]
112112

113113
@property
114114
def dummy_input(self):
@@ -143,7 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
143143
"sample_size": 32,
144144
"patch_size": 1,
145145
"in_channels": 4,
146-
"num_layers": 5,
146+
"num_layers": 4,
147147
"attention_head_dim": 8,
148148
"num_attention_heads": 4,
149149
"caption_projection_dim": 32,

tests/single_file/test_model_flux_transformer_single_file.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
FluxTransformer2DModel,
2323
)
2424
from diffusers.utils.testing_utils import (
25+
slow,
2526
backend_empty_cache,
2627
enable_full_determinism,
2728
require_torch_accelerator,
@@ -32,6 +33,7 @@
3233
enable_full_determinism()
3334

3435

36+
@slow
3537
@require_torch_accelerator
3638
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
3739
model_class = FluxTransformer2DModel

0 commit comments

Comments
 (0)