File tree Expand file tree Collapse file tree 2 files changed +6
-6
lines changed
tests/tests_fabric/strategies Expand file tree Collapse file tree 2 files changed +6
-6
lines changed Original file line number Diff line number Diff line change @@ -84,16 +84,16 @@ def test_reapply_compile():
8484 fabric .launch ()
8585
8686 model = BoringModel ()
87- compile_kwargs = {"mode" : "reduce-overhead" }
88- compiled_model = torch .compile (model , ** compile_kwargs )
87+ # compile_kwargs = {"mode": "reduce-overhead"}
88+ compiled_model = torch .compile (model ) # , **compile_kwargs
8989 torch .compile .reset_mock ()
9090
9191 fabric_model = fabric .setup (compiled_model , _reapply_compile = True )
9292
9393 assert isinstance (fabric_model ._forward_module , OptimizedModule )
9494 assert isinstance (fabric_model ._forward_module ._orig_mod , DistributedDataParallel )
9595 # Assert we called compile again with the same arguments, but on the DDP-wrapped module
96- torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod , ** compile_kwargs )
96+ torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod ) # , **compile_kwargs
9797
9898 assert fabric_model ._original_module == model
9999 assert fabric_model ._forward_module ._orig_mod .module == model
Original file line number Diff line number Diff line change @@ -411,8 +411,8 @@ def test_reapply_compile():
411411 fabric .launch ()
412412
413413 model = BoringModel ()
414- compile_kwargs = {"mode" : "reduce-overhead" }
415- compiled_model = torch .compile (model , ** compile_kwargs )
414+ # compile_kwargs = {"mode": "reduce-overhead"}
415+ compiled_model = torch .compile (model ) # , **compile_kwargs
416416 torch .compile .reset_mock ()
417417
418418 fabric_model = fabric .setup (compiled_model , _reapply_compile = True )
@@ -421,7 +421,7 @@ def test_reapply_compile():
421421 assert isinstance (fabric_model ._forward_module ._orig_mod , FullyShardedDataParallel )
422422
423423 # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
424- torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod , ** compile_kwargs )
424+ torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod ) # , **compile_kwargs
425425
426426 assert fabric_model ._original_module == model
427427 assert fabric_model ._forward_module ._orig_mod .module == model
You can’t perform that action at this time.
0 commit comments