Skip to content

Commit bc83cb8

Browse files
committed
add test for checking compile on different shapes.
1 parent cd81349 commit bc83cb8

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/models/test_modeling_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,6 +1961,22 @@ def test_compile_with_group_offloading(self):
19611961
_ = model(**inputs_dict)
19621962
_ = model(**inputs_dict)
19631963

1964+
def test_compile_on_different_shapes(self):
1965+
torch.fx.experimental._config.use_duck_shape = False
1966+
1967+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1968+
model = self.model_class(**init_dict).to(torch_device)
1969+
model = torch.compile(model, fullgraph=True, dynamic=True)
1970+
1971+
with (
1972+
torch._inductor.utils.fresh_inductor_cache(),
1973+
torch._dynamo.config.patch(error_on_recompile=True),
1974+
torch.no_grad(),
1975+
):
1976+
print(f"{inputs_dict.keys()=}")
1977+
out = model(**inputs_dict)
1978+
assert out is None
1979+
19641980

19651981
@slow
19661982
@require_torch_2

0 commit comments

Comments
 (0)