Skip to content

Commit 81109a5

Browse files
committed
propagate similar to the pipeline tests.
1 parent e9fee7c commit 81109a5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/pipelines/test_pipelines.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2334,21 +2334,21 @@ def test_hotswapping_pipeline(self, rank0, rank1):
23342334
def test_hotswapping_compiled_pipline_linear(self, rank0, rank1):
23352335
# It's important to add this context to raise an error on recompilation
23362336
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
2337-
with torch._dynamo.config.patch(error_on_recompile=True):
2337+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
23382338
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
23392339

23402340
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
23412341
def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1):
23422342
# It's important to add this context to raise an error on recompilation
23432343
target_modules = ["conv", "conv1", "conv2"]
2344-
with torch._dynamo.config.patch(error_on_recompile=True):
2344+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
23452345
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
23462346

23472347
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
23482348
def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1):
23492349
# It's important to add this context to raise an error on recompilation
23502350
target_modules = ["to_q", "conv"]
2351-
with torch._dynamo.config.patch(error_on_recompile=True):
2351+
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
23522352
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
23532353

23542354
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):

0 commit comments

Comments
 (0)