Skip to content

Commit 2883fa9

Browse files
committed
add a test to check if we can train with layerwise casting.
1 parent 464374f commit 2883fa9

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,33 @@ def test_variant_sharded_ckpt_right_format(self):
13381338
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
13391339
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
13401340

1341+
def test_layerwise_casting_training(self):
    """Smoke-test that a forward/backward pass runs with layerwise casting enabled.

    For each (storage_dtype, compute_dtype) pair a fresh model is built, put in
    training mode with layerwise casting enabled, run on the common test inputs,
    and back-propagated through an MSE loss against random noise. The test makes
    no numerical assertions; it passes as long as no step raises.
    """

    def test_fn(storage_dtype, compute_dtype):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict)
        model = model.to(torch_device, dtype=compute_dtype)
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        model.train()

        # Presumably converts float32 tensors in the inputs to `compute_dtype` so they
        # match the model -- confirm against the helper's definition.
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        input_tensor = inputs_dict[self.main_input_name]
        noise = torch.randn((input_tensor.shape[0],) + self.output_shape)
        # The target must share the prediction's dtype: `torch.randn` yields float32,
        # which mismatches the output whenever `compute_dtype` is not float32
        # (e.g. bfloat16) and can make the loss/backward fail on a dtype mismatch.
        noise = noise.to(torch_device, dtype=output.dtype)
        loss = torch.nn.functional.mse_loss(output, noise)
        loss.backward()

    # (storage_dtype, compute_dtype) configurations exercised by the test.
    dtype_pairs = (
        (torch.float16, torch.float32),
        (torch.float8_e4m3fn, torch.float32),
        (torch.float8_e5m2, torch.float32),
        (torch.float8_e4m3fn, torch.bfloat16),
    )
    for storage_dtype, compute_dtype in dtype_pairs:
        test_fn(storage_dtype, compute_dtype)
1366+
1367+
13411368
def test_layerwise_casting_inference(self):
13421369
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
13431370

0 commit comments

Comments
 (0)