tests/models/autoencoders/test_models_autoencoder_oobleck.py (6 additions, 0 deletions)

@@ -114,6 +114,12 @@ def test_forward_with_norm_groups(self):
    def test_set_attn_processor_for_determinism(self):
        return

    @unittest.skip(
        "Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
    )
    def test_layerwise_casting_training(self):
        return super().test_layerwise_casting_training()

    @unittest.skip(
        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
tests/models/test_modeling_common.py (30 additions, 0 deletions)

@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
            # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
            assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

    def test_layerwise_casting_training(self):
        def test_fn(storage_dtype, compute_dtype):
            if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
                return
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

            model = self.model_class(**init_dict)
            model = model.to(torch_device, dtype=compute_dtype)
            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
            model.train()

            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
            with torch.amp.autocast(device_type=torch.device(torch_device).type):
                output = model(**inputs_dict)

                if isinstance(output, dict):
                    output = output.to_tuple()[0]

                input_tensor = inputs_dict[self.main_input_name]
                noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
                noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
                loss = torch.nn.functional.mse_loss(output, noise)

            loss.backward()

        test_fn(torch.float16, torch.float32)
        test_fn(torch.float8_e4m3fn, torch.float32)
        test_fn(torch.float8_e5m2, torch.float32)
        test_fn(torch.float8_e4m3fn, torch.bfloat16)

    def test_layerwise_casting_inference(self):
        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

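For context on what the new common test exercises, here is a minimal standalone sketch of layerwise casting during training. It relies only on the `enable_layerwise_casting(storage_dtype=..., compute_dtype=...)` API already used in the diff above; the choice of `AutoencoderKL`, the CUDA device, and the input shape are illustrative assumptions, not part of this PR.

```python
import torch
from diffusers import AutoencoderKL

# Illustrative model choice (assumption): any diffusers ModelMixin subclass
# that supports enable_layerwise_casting is exercised the same way by the test.
model = AutoencoderKL().to("cuda", dtype=torch.bfloat16)

# Parameters are stored in float8 and up-cast to bfloat16 only while each
# layer's forward runs, which is what the new test verifies is trainable.
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
model.train()

sample = torch.randn(1, 3, 64, 64, device="cuda", dtype=torch.bfloat16)  # shape is an assumption
with torch.amp.autocast(device_type="cuda"):
    recon = model(sample).sample
    loss = torch.nn.functional.mse_loss(recon, sample)

# The common test's core assertion is simply that this backward pass completes
# without dtype errors for each (storage_dtype, compute_dtype) pair.
loss.backward()
```

Storing weights in float8 roughly halves parameter memory relative to bfloat16, while the forward and backward numerics stay in the higher-precision compute dtype.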
tests/models/unets/test_models_unet_1d.py (8 additions, 0 deletions)

@@ -60,6 +60,10 @@ def test_ema_training(self):
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def test_determinism(self):
        super().test_determinism()

@@ -239,6 +243,10 @@ def test_ema_training(self):
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 14,