Commit 52c7104

switch to smaller model + test inference
1 parent a078342 · commit 52c7104

2 files changed (+31, -3 lines)

src/diffusers/models/modeling_utils.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -203,7 +203,6 @@ def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
     checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
     parameters.
 
-    Note: We fully disable this if we are using `deepspeed`
     """
     if model_to_load.device.type == "meta":
        return False
```
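The docstring touched here describes how the loader decides whether state-dict tensors can be assigned directly into the model's parameter buffers: an explicit opt-out on the model wins, a model on the meta device has no storage to assign into, and the state-dict keys must be a subset of the model's parameters. Below is a minimal sketch of that kind of check; the helper name, the `_supports_param_buffer_assignment` opt-out flag, and the dtype comparison are illustrative assumptions, not the actual diffusers implementation.

```python
import torch


def supports_param_buffer_assignment(model: torch.nn.Module, state_dict: dict, start_prefix: str = "") -> bool:
    """Illustrative sketch of the check described above; not the actual diffusers helper."""
    # Hypothetical opt-out flag: a model class may explicitly disable assignment-based loading.
    if not getattr(model, "_supports_param_buffer_assignment", True):
        return False

    # Parameters materialized on the "meta" device have no real storage to assign into.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device.type == "meta":
        return False

    # The (prefix-stripped) state-dict keys must be a subset of the model's parameters,
    # and dtypes must line up for direct assignment to be safe.
    named_params = dict(model.named_parameters())
    for key, tensor in state_dict.items():
        name = key[len(start_prefix):] if key.startswith(start_prefix) else key
        if name not in named_params or named_params[name].dtype != tensor.dtype:
            return False
    return True
```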

tests/models/test_modeling_common.py

Lines changed: 31 additions & 2 deletions

```diff
@@ -345,15 +345,44 @@ def test_keep_modules_in_fp32(self):
         SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
 
         model = SD3Transformer2DModel.from_pretrained(
-            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch_dtype
-        )
+            "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
+        ).to("cuda")
 
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if name in model._keep_in_fp32_modules:
                     self.assertTrue(module.weight.dtype == torch.float32)
                 else:
                     self.assertTrue(module.weight.dtype == torch_dtype)
+
+        def get_dummy_inputs():
+            batch_size = 2
+            num_channels = 4
+            height = width = embedding_dim = 32
+            pooled_embedding_dim = embedding_dim * 2
+            sequence_length = 154
+
+            hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+            encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+            pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+            timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+            return {
+                "hidden_states": hidden_states,
+                "encoder_hidden_states": encoder_hidden_states,
+                "pooled_projections": pooled_prompt_embeds,
+                "timestep": timestep,
+            }
+
+        # test if inference works.
+        with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch_dtype):
+            input_dict_for_transformer = get_dummy_inputs()
+            model_inputs = {
+                k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+            }
+            model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+            _ = model(**model_inputs)
+
         SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
 
 
```
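The new test block loads the tiny transformer, builds random dummy inputs, and runs one forward pass to confirm that inference still works with the mixed fp32/low-precision weights. Below is a standalone sketch of the same smoke test outside the test harness, assuming a CUDA device, Hub access to `hf-internal-testing/tiny-sd3-pipe`, and `torch.float16` as the low-precision dtype (the dtype value is not shown in this hunk). The two context managers are nested explicitly here, since `with a() and b():` only enters the second context manager.

```python
import torch
from diffusers import SD3Transformer2DModel

# Assumptions: a CUDA device is available and the tiny checkpoint can be downloaded from the Hub.
device = "cuda"
torch_dtype = torch.float16

# Keep proj_out in fp32 while the rest of the model loads in fp16, as in the test.
SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
model = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer", torch_dtype=torch_dtype
).to(device)

# Dummy inputs mirroring the shapes used in the test above.
batch_size, num_channels = 2, 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154

inputs = {
    "hidden_states": torch.randn(batch_size, num_channels, height, width, device=device),
    "encoder_hidden_states": torch.randn(batch_size, sequence_length, embedding_dim, device=device),
    "pooled_projections": torch.randn(batch_size, pooled_embedding_dim, device=device),
    "timestep": torch.randint(0, 1000, (batch_size,), device=device),
}

# Nest the context managers explicitly: `with a() and b():` only enters `b()`.
with torch.no_grad():
    with torch.autocast(device, dtype=torch_dtype):
        out = model(**inputs)

print(out.sample.dtype, out.sample.shape)
```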