Skip to content

Commit d108c18

Browse files
committed
update
1 parent e2d2650 commit d108c18

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

tests/models/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,10 +1179,11 @@ def test_disk_offload_without_safetensors(self):
11791179
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
11801180

11811181
max_size = int(self.model_split_percents[0] * model_size)
1182-
max_memory = {0: max_size, "cpu": max_size}
1182+
max_memory = {0: max_size, "disk": max_size}
11831183
new_model = self.model_class.from_pretrained(
11841184
tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
11851185
)
1186+
__import__("ipdb").set_trace()
11861187

11871188
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
11881189
torch.manual_seed(0)

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
3434
model_class = SD3Transformer2DModel
3535
main_input_name = "hidden_states"
36+
model_split_percents = [0.4, 0.8, 0.9]
3637

3738
@property
3839
def dummy_input(self):
@@ -67,7 +68,7 @@ def prepare_init_args_and_inputs_for_common(self):
6768
"sample_size": 32,
6869
"patch_size": 1,
6970
"in_channels": 4,
70-
"num_layers": 1,
71+
"num_layers": 4,
7172
"attention_head_dim": 8,
7273
"num_attention_heads": 4,
7374
"caption_projection_dim": 32,
@@ -107,6 +108,7 @@ def test_gradient_checkpointing_is_applied(self):
107108
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
108109
model_class = SD3Transformer2DModel
109110
main_input_name = "hidden_states"
111+
model_split_percents = [0.4, 0.8, 0.9]
110112

111113
@property
112114
def dummy_input(self):
@@ -141,7 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
141143
"sample_size": 32,
142144
"patch_size": 1,
143145
"in_channels": 4,
144-
"num_layers": 2,
146+
"num_layers": 5,
145147
"attention_head_dim": 8,
146148
"num_attention_heads": 4,
147149
"caption_projection_dim": 32,

0 commit comments

Comments
 (0)