Skip to content

Commit 20539c9

Browse files
Rename NNCheckpoint to ModelCheckpoint as Model can be NN or ONNX (#4540)
1 parent 1377421 commit 20539c9

File tree

5 files changed: 31 additions (+31) and 27 deletions (-27).

ml-agents/mlagents/trainers/policy/checkpoint_manager.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,14 @@
99

1010

1111
@attr.s(auto_attribs=True)
12-
class NNCheckpoint:
12+
class ModelCheckpoint:
1313
steps: int
1414
file_path: str
1515
reward: Optional[float]
1616
creation_time: float
1717

1818

19-
class NNCheckpointManager:
19+
class ModelCheckpointManager:
2020
@staticmethod
2121
def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
2222
checkpoint_list = GlobalTrainingStatus.get_parameter_state(
@@ -60,12 +60,12 @@ def _cleanup_extra_checkpoints(
6060
while len(checkpoints) > keep_checkpoints:
6161
if keep_checkpoints <= 0 or len(checkpoints) == 0:
6262
break
63-
NNCheckpointManager.remove_checkpoint(checkpoints.pop(0))
63+
ModelCheckpointManager.remove_checkpoint(checkpoints.pop(0))
6464
return checkpoints
6565

6666
@classmethod
6767
def add_checkpoint(
68-
cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
68+
cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
6969
) -> None:
7070
"""
7171
Make room for new checkpoint if needed and insert new checkpoint information.
@@ -83,7 +83,7 @@ def add_checkpoint(
8383

8484
@classmethod
8585
def track_final_checkpoint(
86-
cls, behavior_name: str, final_checkpoint: NNCheckpoint
86+
cls, behavior_name: str, final_checkpoint: ModelCheckpoint
8787
) -> None:
8888
"""
8989
Ensures number of checkpoints stored is within the max number of checkpoints

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import os
88

99
import numpy as np
10-
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
10+
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
1111

1212
from mlagents_envs.logging_util import get_logger
1313
from mlagents_envs.timers import timed
@@ -88,7 +88,7 @@ def __init__(
8888

8989
self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
9090

91-
def _checkpoint(self) -> NNCheckpoint:
91+
def _checkpoint(self) -> ModelCheckpoint:
9292
"""
9393
Writes a checkpoint model to memory
9494
Overrides the default to save the replay buffer.

ml-agents/mlagents/trainers/tests/test_rl_trainer.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from unittest import mock
33
import pytest
44
import mlagents.trainers.tests.mock_brain as mb
5-
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
5+
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
66
from mlagents.trainers.trainer.rl_trainer import RLTrainer
77
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
88
from mlagents.trainers.agent_processor import AgentManagerQueue
@@ -126,7 +126,9 @@ def test_advance(mocked_clear_update_buffer, mocked_save_model):
126126
"framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"]
127127
)
128128
@mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
129-
@mock.patch("mlagents.trainers.trainer.rl_trainer.NNCheckpointManager.add_checkpoint")
129+
@mock.patch(
130+
"mlagents.trainers.trainer.rl_trainer.ModelCheckpointManager.add_checkpoint"
131+
)
130132
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
131133
trainer = create_rl_trainer(framework)
132134
mock_policy = mock.Mock()
@@ -170,7 +172,7 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
170172
add_checkpoint_calls = [
171173
mock.call(
172174
trainer.brain_name,
173-
NNCheckpoint(
175+
ModelCheckpoint(
174176
step,
175177
f"{trainer.model_saver.model_path}/{trainer.brain_name}-{step}.{export_ext}",
176178
None,

ml-agents/mlagents/trainers/tests/test_training_status.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
GlobalTrainingStatus,
1010
)
1111
from mlagents.trainers.policy.checkpoint_manager import (
12-
NNCheckpointManager,
13-
NNCheckpoint,
12+
ModelCheckpointManager,
13+
ModelCheckpoint,
1414
)
1515

1616

@@ -78,25 +78,27 @@ def test_model_management(tmpdir):
7878
brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
7979
)
8080

81-
new_checkpoint_4 = NNCheckpoint(
81+
new_checkpoint_4 = ModelCheckpoint(
8282
4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
8383
)
84-
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
85-
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
84+
ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
85+
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
8686

87-
new_checkpoint_5 = NNCheckpoint(
87+
new_checkpoint_5 = ModelCheckpoint(
8888
5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
8989
)
90-
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
91-
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
90+
ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
91+
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
9292

9393
final_model_path = f"{final_model_path}.nn"
9494
final_model_time = time.time()
9595
current_step = 6
96-
final_model = NNCheckpoint(current_step, final_model_path, 3.294, final_model_time)
96+
final_model = ModelCheckpoint(
97+
current_step, final_model_path, 3.294, final_model_time
98+
)
9799

98-
NNCheckpointManager.track_final_checkpoint(brain_name, final_model)
99-
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
100+
ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
101+
assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
100102

101103
check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
102104
StatusType.CHECKPOINTS.value

ml-agents/mlagents/trainers/trainer/rl_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55
import time
66
import attr
77
from mlagents.trainers.policy.checkpoint_manager import (
8-
NNCheckpoint,
9-
NNCheckpointManager,
8+
ModelCheckpoint,
9+
ModelCheckpointManager,
1010
)
1111
from mlagents_envs.logging_util import get_logger
1212
from mlagents_envs.timers import timed
@@ -176,7 +176,7 @@ def _policy_mean_reward(self) -> Optional[float]:
176176
return sum(rewards) / len(rewards)
177177

178178
@timed
179-
def _checkpoint(self) -> NNCheckpoint:
179+
def _checkpoint(self) -> ModelCheckpoint:
180180
"""
181181
Checkpoints the policy associated with this trainer.
182182
"""
@@ -187,13 +187,13 @@ def _checkpoint(self) -> NNCheckpoint:
187187
)
188188
checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
189189
export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
190-
new_checkpoint = NNCheckpoint(
190+
new_checkpoint = ModelCheckpoint(
191191
int(self.step),
192192
f"{checkpoint_path}.{export_ext}",
193193
self._policy_mean_reward(),
194194
time.time(),
195195
)
196-
NNCheckpointManager.add_checkpoint(
196+
ModelCheckpointManager.add_checkpoint(
197197
self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
198198
)
199199
return new_checkpoint
@@ -217,7 +217,7 @@ def save_model(self) -> None:
217217
final_checkpoint = attr.evolve(
218218
model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
219219
)
220-
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
220+
ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
221221

222222
@abc.abstractmethod
223223
def _update_policy(self) -> bool:

0 commit comments

Comments
 (0)