Skip to content

Commit 01eb2e5

Browse files
committed
updates
1 parent 2883fa9 commit 01eb2e5

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,8 @@ def test_variant_sharded_ckpt_right_format(self):
13401340

13411341
def test_layerwise_casting_training(self):
13421342
def test_fn(storage_dtype, compute_dtype):
1343+
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
1344+
return
13431345
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
13441346

13451347
model = self.model_class(**init_dict)
@@ -1355,6 +1357,7 @@ def test_fn(storage_dtype, compute_dtype):
13551357

13561358
input_tensor = inputs_dict[self.main_input_name]
13571359
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
1360+
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
13581361
loss = torch.nn.functional.mse_loss(output, noise)
13591362
loss.backward()
13601363

0 commit comments

Comments
 (0)