Skip to content

Commit c1e2ca1

Browse files
Bug fix: DictStateful fails to load state_dict (#1028)
Summary: Pull Request resolved: #1028 ### Bug Fix: DictStateful Fails to Load State Dict #### Overview This diff addresses a bug in the `DictStateful` class where it fails to load the state dictionary. The fix involves modifying the `state_dict` method to return a copy of the dictionary instead of the original dictionary. Additionally, the `load_state_dict` method is updated to clear the dictionary before loading the new state. #### Changes * `fbcode/torchtnt/utils/stateful.py`: * The `state_dict` method is updated to return a copy of the dictionary (`self.copy()`) instead of the original dictionary (`self`). * `fbcode/torchtnt/tests/utils/meta/test_stateful.py`: * Test cases are added to verify the correctness of the `DictStateful` class, including a test for checkpointing. #### Impact This bug fix ensures that the `DictStateful` class functions correctly when saving and loading its state dictionary. This is crucial for maintaining the accuracy and reliability of machine learning models that rely on this class. Reviewed By: JKSenthil Differential Revision: D81674115 fbshipit-source-id: 6ee8e8a5d91733dd46a59198316554cff68f8b7f
1 parent a3f20b0 commit c1e2ca1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchtnt/utils/stateful.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class DictStateful(Stateful, Dict[str, Any]):
7979
"""A dictionary that implements the stateful interface that can be saved and loaded from checkpoints."""
8080

8181
def state_dict(self) -> Dict[str, Any]:
82-
return self
82+
return self.copy()
8383

8484
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
8585
self.clear()

0 commit comments

Comments
 (0)