Skip to content

Commit 0a16826

Browse files
committed
skip tests for AutoencoderOobleckTests
1 parent 59e04c3 commit 0a16826

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,24 @@ def test_forward_with_norm_groups(self):
114114
def test_set_attn_processor_for_determinism(self):
115115
return
116116

117+
@unittest.skip(
118+
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
119+
"cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
120+
"1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
121+
"2. Unskip this test."
122+
)
123+
def test_layerwise_upcasting_inference(self):
124+
pass
125+
126+
@unittest.skip(
127+
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
128+
"cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
129+
"1. Make sure `nn::Module::to` works with `torch.nn.utils.weight_norm` wrapped convolution layer.\n"
130+
"2. Unskip this test."
131+
)
132+
def test_layerwise_upcasting_memory(self):
133+
pass
134+
117135

118136
@slow
119137
class AutoencoderOobleckIntegrationTests(unittest.TestCase):

tests/models/test_modeling_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,6 +1366,7 @@ def check_linear_dtype(module, storage_dtype, compute_dtype):
13661366
if any(re.search(pattern, name) for pattern in patterns_to_check):
13671367
dtype_to_check = compute_dtype
13681368
if getattr(submodule, "weight", None) is not None:
1369+
print(name, submodule.weight.dtype, dtype_to_check, patterns_to_check)
13691370
self.assertEqual(submodule.weight.dtype, dtype_to_check)
13701371
if getattr(submodule, "bias", None) is not None:
13711372
self.assertEqual(submodule.bias.dtype, dtype_to_check)
@@ -1377,6 +1378,7 @@ def test_layerwise_upcasting(storage_dtype, compute_dtype):
13771378
model = self.model_class(**config).eval()
13781379
model = model.to(torch_device, dtype=compute_dtype)
13791380
model.enable_layerwise_upcasting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
1381+
13801382
check_linear_dtype(model, storage_dtype, compute_dtype)
13811383
output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
13821384

0 commit comments

Comments
 (0)