@@ -68,7 +68,7 @@ def __init__(
         betas: BETAS = (0.9, 0.999),
         density: float = 1.0,
         weight_decay: float = 0.0,
-        warmup_epoch: int = 150,
+        warmup_epoch: int = 50,
         threshold: int = 5000,
         grad_accu_steps: int = 20,
         update_proj_gap: int = 500,
@@ -90,11 +90,12 @@ def __init__(
         self.threshold = threshold
         self.grad_accu_steps = grad_accu_steps
         self.update_proj_gap = update_proj_gap
-        self.warmup = CosineDecay(0.99, warmup_epoch)
 
         defaults: DEFAULTS = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay, 'eps': eps, **kwargs}
         super().__init__(params, defaults)
 
+        self.warmup = CosineDecay(0.99, self.warmup_epoch)
+
         self.init_masks()
 
         self.state['total_step'] = 0
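Building the warmup scheduler after super().__init__() lets it read the stored self.warmup_epoch attribute rather than the raw constructor argument. The snippet below is a minimal, hypothetical stand-in for what CosineDecay(0.99, warmup_epoch) is assumed to compute (a factor annealed from 0.99 towards 0 over the warmup window); the actual class in this repository may differ.

```python
import math

# Hypothetical stand-in for CosineDecay(0.99, warmup_epoch); assumes a plain cosine
# anneal from `death_rate` down to ~0 over `total_steps` steps.
def cosine_decay_factor(step: int, death_rate: float = 0.99, total_steps: int = 50) -> float:
    step = min(step, total_steps)
    return death_rate * 0.5 * (1.0 + math.cos(math.pi * step / total_steps))

print(round(cosine_decay_factor(0), 2))   # 0.99 at the first step
print(round(cosine_decay_factor(25), 2))  # roughly halfway through the anneal
print(round(cosine_decay_factor(50), 2))  # ~0.0 once warmup_epoch (now 50) steps have passed
```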
@@ -119,17 +120,16 @@ def initialize_random_rank_boolean_tensor(m: int, n: int, density: float, device
 
         return tensor.view(m, n)
 
-    def update_mask_random(self, density: float, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
+    def update_mask_random(self, p: torch.Tensor, old_mask: torch.Tensor) -> torch.Tensor:
         r"""Update a random mask.
 
         Create a new random mask with the same density, compute overlap ratio with old_mask, and update the EMA for
         the overlap region.
 
-        :param density: float. fraction of elements to keep.
         :param p: torch.Tensor. parameter to which the mask is applied.
         :param old_mask: torch.Tensor. previous binary mask.
         """
-        new_mask: torch.Tensor = torch.rand_like(p) < density
+        new_mask: torch.Tensor = torch.rand_like(p) < self.density
 
         exp_avg = torch.zeros_like(p[new_mask])
         exp_avg_sq = torch.zeros_like(p[new_mask])
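Since the density is now read from self.density, the method no longer needs the extra argument. The standalone sketch below (a hypothetical refresh_mask helper, not the class method itself) illustrates the underlying idea: draw a fresh Bernoulli(density) mask and carry the EMA statistics over only for coordinates present in both the old and the new mask.

```python
import torch

# Hypothetical helper: refresh a random mask at the same density and keep the EMA
# values only where the old and new masks overlap, zero-initialising everything else.
def refresh_mask(p: torch.Tensor, old_mask: torch.Tensor, exp_avg: torch.Tensor, density: float):
    new_mask = torch.rand_like(p) < density                      # same expected density as before
    new_exp_avg = torch.zeros_like(p[new_mask])                  # flat EMA buffer for the kept entries

    overlap = old_mask & new_mask                                # coordinates surviving the refresh
    new_exp_avg[overlap[new_mask]] = exp_avg[overlap[old_mask]]  # copy their running averages over
    return new_mask, new_exp_avg

p = torch.randn(4, 4)
old_mask = torch.rand_like(p) < 0.5
exp_avg = torch.randn(int(old_mask.sum()))
mask, ema = refresh_mask(p, old_mask, exp_avg, density=0.5)
```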
@@ -155,8 +155,8 @@ def update_masks(self) -> None:
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                if 'mask' in state:
-                    state['mask'] = self.update_mask_random(self.density, p, state['mask'])
+                if p.dim() == 2 and 'mask' in state:
+                    state['mask'] = self.update_mask_random(p, state['mask'])
                 p.mask = state['mask']
 
     def init_masks(self) -> None:
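The added p.dim() == 2 guard means masks are only refreshed for matrix-shaped parameters. A tiny illustration of which tensors pass the check (the shapes here are arbitrary examples, not taken from the repository):

```python
import torch

# Only 2-D weight matrices keep a sparse mask; 1-D tensors (biases, norm scales) are skipped.
weight = torch.randn(8, 4)
bias = torch.randn(4)

for p in (weight, bias):
    if p.dim() == 2:
        print(tuple(p.shape), '-> eligible for a sparse mask')
    else:
        print(tuple(p.shape), '-> left dense, no mask update')
```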
@@ -177,13 +177,7 @@ def __str__(self) -> str:
 
     @torch.no_grad()
     def reset(self):
-        for group in self.param_groups:
-            group['step'] = 0
-            for p in group['params']:
-                state = self.state[p]
-
-                state['exp_avg'] = torch.zeros_like(p)
-                state['exp_avg_sq'] = torch.zeros_like(p)
+        pass
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
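Emptying reset() removes the dense zeros_like(p) buffers it used to allocate; with the change in the next hunk, the EMA buffers are instead created lazily inside step() at the masked-gradient shape. The snippet below illustrates the shape mismatch the old code would have produced (this rationale is an assumption, not stated in the commit):

```python
import torch

# The old reset() allocated buffers shaped like the full parameter, while step() works
# on the masked (flattened) gradient, so the two shapes do not line up.
p = torch.randn(4, 4)
mask = torch.rand_like(p) < 0.5
dense_buffer = torch.zeros_like(p)        # shape (4, 4): what reset() used to create
masked_grad = torch.randn_like(p)[mask]   # shape (mask.sum(),): what step() updates
print(dense_buffer.shape, masked_grad.shape)
```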
@@ -220,11 +214,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if 'mask' in state:
                     grad = grad[state['mask']]
 
-                if 'exp_avg' not in state:
-                    state['exp_avg'] = torch.zeros_like(grad)
-                    state['exp_avg_sq'] = torch.zeros_like(grad)
-
-                if (self.state['total_step'] + 1) % self.update_proj_gap == 0:
+                if ('exp_avg' not in state) or (self.state['total_step'] + 1) % self.update_proj_gap == 0:
                     state['exp_avg'] = torch.zeros_like(grad)
                     state['exp_avg_sq'] = torch.zeros_like(grad)
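The two branches collapse into one condition: the masked EMA buffers are (re)created both on the first step for a parameter and at every update_proj_gap boundary thereafter. A self-contained sketch of that condition with placeholder values (the variable names mirror the diff, the numbers are arbitrary):

```python
import torch

update_proj_gap = 500
total_step = 0                     # mirrors self.state['total_step']
state: dict = {}                   # per-parameter state, initially empty
grad = torch.randn(10)             # stands in for the masked gradient slice

# Lazily create the buffers on the first step, and rebuild them whenever the
# projection-gap boundary is hit.
if ('exp_avg' not in state) or (total_step + 1) % update_proj_gap == 0:
    state['exp_avg'] = torch.zeros_like(grad)
    state['exp_avg_sq'] = torch.zeros_like(grad)

print(sorted(state.keys()))        # ['exp_avg', 'exp_avg_sq'] on the first step
```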