Commit af9bf9c

Merge pull request #1580 from chao1995/new-optimizer-for-clone

PyTorchClassifier: use a new optimizer for the cloned classifier

2 parents 8549e0e + d9969d3


art/estimators/classification/pytorch.py

Lines changed: 11 additions & 3 deletions
@@ -497,13 +497,21 @@ def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs
 
     def clone_for_refitting(self) -> "PyTorchClassifier":  # lgtm [py/inheritance/incorrect-overridden-signature]
         """
-        Create a copy of the classifier that can be refit from scratch. Will inherit same architecture, optimizer and
-        initialization as cloned model, but without weights.
+        Create a copy of the classifier that can be refit from scratch. Will inherit same architecture, same type of
+        optimizer and initialization as the original classifier, but without weights.
 
         :return: new estimator
         """
         model = copy.deepcopy(self.model)
-        clone = type(self)(model, self._loss, self.input_shape, self.nb_classes, optimizer=self._optimizer)
+
+        if self._optimizer is None:  # pragma: no cover
+            raise ValueError("An optimizer is needed to train the model, but none is provided.")
+
+        # create a new optimizer that binds to the cloned model's parameters and uses original optimizer's defaults
+        new_optimizer = type(self._optimizer)(model.parameters(), **self._optimizer.defaults)  # type: ignore
+
+        clone = type(self)(model, self._loss, self.input_shape, self.nb_classes, optimizer=new_optimizer)
+
         # reset weights
         clone.reset()
         params = self.get_params()