Skip to content

Commit eea0436

Browse files
committed
updates
1 parent 01eb2e5 commit eea0436

File tree

3 files changed

+21
-8
lines changed

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def test_forward_with_norm_groups(self):
114114
def test_set_attn_processor_for_determinism(self):
115115
return
116116

117+
@unittest.skip("Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'")
118+
def test_layerwise_casting_training(self):
119+
return super().test_layerwise_casting_training()
120+
117121
@unittest.skip(
118122
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
119123
"cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"

tests/models/test_modeling_common.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1350,18 +1350,19 @@ def test_fn(storage_dtype, compute_dtype):
13501350
model.train()
13511351

13521352
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
1353-
output = model(**inputs_dict)
1353+
with torch.amp.autocast(device_type=torch.device(torch_device).type):
1354+
output = model(**inputs_dict)
13541355

1355-
if isinstance(output, dict):
1356-
output = output.to_tuple()[0]
1356+
if isinstance(output, dict):
1357+
output = output.to_tuple()[0]
13571358

1358-
input_tensor = inputs_dict[self.main_input_name]
1359-
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
1360-
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
1361-
loss = torch.nn.functional.mse_loss(output, noise)
1359+
input_tensor = inputs_dict[self.main_input_name]
1360+
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
1361+
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
1362+
loss = torch.nn.functional.mse_loss(output, noise)
1363+
13621364
loss.backward()
13631365

1364-
13651366
test_fn(torch.float16, torch.float32)
13661367
test_fn(torch.float8_e4m3fn, torch.float32)
13671368
test_fn(torch.float8_e5m2, torch.float32)

tests/models/unets/test_models_unet_1d.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def test_ema_training(self):
6060
def test_training(self):
6161
pass
6262

63+
@unittest.skip("Test not supported.")
64+
def test_layerwise_casting_training(self):
65+
pass
66+
6367
def test_determinism(self):
6468
super().test_determinism()
6569

@@ -239,6 +243,10 @@ def test_ema_training(self):
239243
def test_training(self):
240244
pass
241245

246+
@unittest.skip("Test not supported.")
247+
def test_layerwise_casting_training(self):
248+
pass
249+
242250
def prepare_init_args_and_inputs_for_common(self):
243251
init_dict = {
244252
"in_channels": 14,

0 commit comments

Comments (0)