@@ -1,5 +1,4 @@
 import functools
-import warnings
 from typing import Tuple, Union, List, Dict, Any, Optional
 
 import torch
@@ -30,6 +29,7 @@ def __call__(self, mod: torch.optim.Optimizer):
 
 class LpWeightDecay(WeightDecayBase):
     def __init__(self, power: float):
+        super().__init__()
         self.power = power
 
     def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
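Both `__init__` hunks in this diff add a previously missing `super().__init__()` call. A minimal sketch of why the call matters, assuming the base class owns (or later gains) state of its own; the `call_count` attribute below is hypothetical, since only `WeightDecayBase`'s `__call__` signature is visible here:

```python
import torch
from torch import Tensor

class WeightDecayBase:
    def __init__(self):
        self.call_count = 0  # hypothetical base-class state

    def __call__(self, mod: torch.optim.Optimizer, p: Tensor, idx: int):
        raise NotImplementedError

class LpWeightDecay(WeightDecayBase):
    def __init__(self, power: float):
        super().__init__()  # without this line, call_count is never created
        self.power = power

wd = LpWeightDecay(2.0)
print(wd.call_count)  # 0; would raise AttributeError without super().__init__()
```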
@@ -46,12 +46,17 @@ def __init__(self):
         super().__init__(1)
 
 
-def _param_iterator(mod: torch.optim.Optimizer):
-    yield from (p.detach().clone() for group in mod.param_groups for p in group["params"])
+def _detach(x: Tensor) -> Tensor:
+    return x.detach().clone()
+
+
+def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
+    yield from (fn(p) for group in mod.param_groups for p in group["params"])
 
 
 class WeightDecayToValue(WeightDecayBase):
     def __init__(self):
+        super().__init__()
         self.target_values: List[Tensor] = ...
         self.global_step = 0
 
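For context, a short usage sketch of the refactored iterator; the toy SGD setup is an assumption, while `_detach` and `_param_iterator` are copied from the hunk above. The default yields frozen snapshots; passing an identity function yields the live parameter tensors, which the `step()` hunk below relies on:

```python
import torch
from torch import Tensor

def _detach(x: Tensor) -> Tensor:
    return x.detach().clone()

def _param_iterator(mod: torch.optim.Optimizer, fn=_detach):
    yield from (fn(p) for group in mod.param_groups for p in group["params"])

params = [torch.nn.Parameter(torch.randn(3))]
opt = torch.optim.SGD(params, lr=0.1)

snapshot = list(_param_iterator(opt))           # detached clones: values frozen now
live = list(_param_iterator(opt, lambda x: x))  # the live Parameter objects

assert live[0] is params[0]                     # identity fn yields the real tensor
assert snapshot[0] is not params[0]             # default yields an independent copy
assert torch.equal(snapshot[0], live[0])        # but with equal values
```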
@@ -261,6 +266,8 @@ def step(self, closure=None):
         original_params = list(_param_iterator(self))
 
         self.magnitude.step()
+        params_flat = list(_param_iterator(self, lambda x: x))
+
         magnitudes_flat = []
         for o, p in zip(original_params, params_flat):
             magnitudes_flat.append(torch.norm(o.double() - p.double()))
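A self-contained sketch of the measurement that `step()` now performs: snapshot the parameters, run the wrapped step, then take a per-tensor L2 norm of the change. The plain SGD step here stands in for `self.magnitude.step()` (an assumption), and the cast to `double()` presumably guards small updates against float32 rounding:

```python
import torch

def _param_iterator(mod, fn=lambda x: x.detach().clone()):
    yield from (fn(p) for group in mod.param_groups for p in group["params"])

p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.SGD([p], lr=0.5)

original_params = list(_param_iterator(opt))           # pre-step snapshots
p.grad = torch.full_like(p, 2.0)                       # fabricated gradient
opt.step()                                             # stand-in for magnitude.step()
params_flat = list(_param_iterator(opt, lambda x: x))  # live, updated tensors

magnitudes_flat = []
with torch.no_grad():  # the measurement should not build an autograd graph
    for o, q in zip(original_params, params_flat):
        # cast to float64 before the norm so tiny updates survive rounding
        magnitudes_flat.append(torch.norm(o.double() - q.double()))

print(magnitudes_flat[0])  # ||lr * grad|| = ||0.5 * 2 * ones(4)|| = 2.0
```

Passing `lambda x: x` for the second pass avoids a redundant round of detach-and-clone allocations when only references to the updated tensors are needed.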