Skip to content

Commit b98ed10

Browse files
committed
reduce-overhead
1 parent be212ca commit b98ed10

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/tests_fabric/strategies/test_ddp_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,16 @@ def test_reapply_compile():
8484
fabric.launch()
8585

8686
model = BoringModel()
87-
compile_kwargs = {"mode": "reduce-overhead"}
88-
compiled_model = torch.compile(model, **compile_kwargs)
87+
# compile_kwargs = {"mode": "reduce-overhead"}
88+
compiled_model = torch.compile(model) # , **compile_kwargs
8989
torch.compile.reset_mock()
9090

9191
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
9292

9393
assert isinstance(fabric_model._forward_module, OptimizedModule)
9494
assert isinstance(fabric_model._forward_module._orig_mod, DistributedDataParallel)
9595
# Assert we called compile again with the same arguments, but on the DDP-wrapped module
96-
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
96+
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod) # , **compile_kwargs
9797

9898
assert fabric_model._original_module == model
9999
assert fabric_model._forward_module._orig_mod.module == model

tests/tests_fabric/strategies/test_fsdp_integration.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ def test_reapply_compile():
411411
fabric.launch()
412412

413413
model = BoringModel()
414-
compile_kwargs = {"mode": "reduce-overhead"}
415-
compiled_model = torch.compile(model, **compile_kwargs)
414+
# compile_kwargs = {"mode": "reduce-overhead"}
415+
compiled_model = torch.compile(model) # , **compile_kwargs
416416
torch.compile.reset_mock()
417417

418418
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
@@ -421,7 +421,7 @@ def test_reapply_compile():
421421
assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)
422422

423423
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
424-
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
424+
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod) # , **compile_kwargs
425425

426426
assert fabric_model._original_module == model
427427
assert fabric_model._forward_module._orig_mod.module == model

0 commit comments

Comments
 (0)