Adam with weight decay per parameter + some activation functions #452
base: master
Changes from 1 commit
```diff
@@ -1011,7 +1011,7 @@ def _step(
         v_tp: dr.ArrayBase # Second moment EMA state from previous iteration

         # Unpack optimizer state
-        t_p, m_tp, v_tp = extra
+        t_p, m_tp, v_tp = extra[:3]

         # Increase the iteration count
         t = t_p + 1
```
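The slice is needed because this PR adds a fourth entry to the per-parameter state tuple; a minimal sketch of the assumed layout, with made-up values for illustration only:

```python
# Assumed per-parameter optimizer state after this change (toy values):
extra = (3, 0.1, 0.01, 0.05)   # (t, m_t, v_t, decay)

t_p, m_tp, v_tp = extra[:3]    # the base Adam update only needs the first three
decay = extra[3]               # consumed later by the weight-decay step
```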
```diff
@@ -1146,7 +1146,7 @@ def __init__(
         beta_1: float = 0.9,
         beta_2: float = 0.999,
         epsilon: float = 1e-8,
-        weight_decay: float = 0.01,
+        weight_decay: Optional[float | Mapping[str, float]] = None,
         mask_updates: bool = False,
         promote_fp16: bool = True,
         uniform: bool = False,
```
```diff
@@ -1172,10 +1172,12 @@ def __init__(
             will cause past gradients to persist for a longer amount of
             time.

-        weight_decay (float):
+        weight_decay (float | Mapping[str, float]):
             Weight decay coefficient for L2 regularization. Unlike Adam,
             this is applied directly to parameters rather than gradients,
             providing better regularization with adaptive learning rates.
+            You may also provide a dictionary mapping parameter names to
+            individual weight decay values.

         uniform (bool):
             If enabled, the optimizer will use the *UniformAdam* variant of
```
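For reference, a sketch of the update this paragraph describes, assuming the standard decoupled (AdamW-style) formulation: with learning rate $\eta$, decay coefficient $\lambda$ (the per-name value if the parameter appears in the mapping, otherwise the global one), and bias-corrected moments $\hat m_t$, $\hat v_t$,

$$\theta_{t+1} \;=\; \underbrace{\theta_t - \eta\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}}_{\text{plain Adam step}} \;-\; \eta\,\lambda\,\theta_t.$$

When only a mapping is supplied, parameters not listed in it fall back to the global coefficient, which is 0.0 in that case.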
```diff
@@ -1195,7 +1197,18 @@ def __init__(
             Optional dictionary-like object containing an initial set of
             parameters.
         """
+        # Store per-parameter weight decay
+        self.weight_decay_dict: Mapping[str, float] = {}
+
+        if (weight_decay is None) or (isinstance(weight_decay, Mapping)):
+            self.global_weight_decay = 0.0
+        elif not isinstance(weight_decay, float):
+            raise TypeError(
+                "'weight_decay' must be None, a float, or a mapping from parameter names to floats"
+            )
+        else:
+            self.global_weight_decay = weight_decay
+
         super().__init__(
             lr,
             params,
```
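A hedged usage sketch of the new constructor signature; the class name `AdamW`, the import path, and the parameter names below are placeholders for illustration, not taken from this diff:

```python
from drjit.opt import AdamW   # assumed import path

# Global decay of 0.01 applied to every parameter
opt = AdamW(lr=1e-2, weight_decay=0.01)

# Per-parameter decay: listed names get their own coefficient,
# all others fall back to the global value (0.0 when only a mapping is given)
opt = AdamW(lr=1e-2, weight_decay={'albedo': 0.0, 'roughness': 1e-3})

# Default: no weight decay
opt = AdamW(lr=1e-2)
```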
```diff
@@ -1206,11 +1219,48 @@ def __init__(
             promote_fp16=promote_fp16,
             uniform=uniform
         )

-        if weight_decay < 0:
-            raise RuntimeError("'weight_decay' must be >= 0")
-
-        self.weight_decay = weight_decay
+        self.set_weight_decay(weight_decay)
+
+    def set_weight_decay(self, value: Union[float, Mapping[str, float], None] = None):
+        """
+        Set the weight decay globally or per parameter.
+
+        Args:
+            value: float to set globally, or dict mapping parameter names to decay.
+        """
+        if value is None:
+            pass  # keep the current settings (constructor default)
+        elif isinstance(value, float):
+            # Global weight decay
+            if value < 0:
+                raise ValueError("weight_decay must be non-negative.")
+            self.global_weight_decay = value
+        elif isinstance(value, Mapping):
+            for k, wd in value.items():
+                self.weight_decay_dict[k] = wd
+        else:
+            raise ValueError("weight_decay must be a float or a mapping")
+
+        # Propagate the decay values into the per-parameter optimizer state
+        for k in self.state:
+            decay = self.weight_decay_dict.get(k, self.global_weight_decay)
+
+            Float = dr.float32_array_t(dr.leaf_t(self.state[k][0]))
+            self.state[k] = self.state[k][0], self.state[k][1], self.state[k][2], (
+                self.state[k][3][0],
+                self.state[k][3][1],
+                self.state[k][3][2],
+                Float(decay),
+            )
+
+    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
+        valarr = value.array
+        tp = type(valarr)
+        UInt = dr.uint32_array_t(dr.leaf_t(tp))
+        Float = dr.float32_array_t(dr.leaf_t(tp))
+        t = UInt(0)
+        m_t = dr.opaque(tp, 0, valarr.shape)
+        v_t = dr.opaque(tp, 0, valarr.shape)
+        decay = self.weight_decay_dict.get(key, self.global_weight_decay)
+        self.state[key] = value, promoted, None, (t, m_t, v_t, Float(decay))

     def _step(
         self,
```
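Because `set_weight_decay` rewrites the stored state tuples, it can also be called after parameters have been registered to adjust individual entries; a minimal sketch, assuming the usual dictionary-style parameter registration and a parameter named 'roughness':

```python
opt['roughness'] = roughness                 # register a parameter (Dr.Jit array)
opt.set_weight_decay({'roughness': 5e-4})    # override decay for this entry only
opt.set_weight_decay(0.02)                   # later: change the global coefficient
```

Entries present in the mapping keep their override when the global value changes, since the state rebuild uses `self.weight_decay_dict.get(k, self.global_weight_decay)`.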
```diff
@@ -1221,9 +1271,33 @@ def _step(
         extra: Tuple[int, dr.ArrayBase, dr.ArrayBase],
         /,
     ) -> Tuple[dr.ArrayBase, Tuple[int, dr.ArrayBase, dr.ArrayBase]]:
+        decay = extra[3]
+
+        # Take Adam step
         new_value, new_extra = super()._step(cache, value, grad, lr, extra)
-        scaled_value = dr.fma(value, -lr * self.weight_decay, new_value)
+
+        # Take weight decay step
+        scaled_value = dr.fma(value, -lr * decay, new_value)
+
+        new_extra = (new_extra[0], new_extra[1], new_extra[2], decay)
         return scaled_value, new_extra

+    def _select(
+        self,
+        mask: dr.ArrayBase,
+        extra: Tuple[int, dr.ArrayBase, dr.ArrayBase],
+        new_extra: Tuple[int, dr.ArrayBase, dr.ArrayBase, float],
+        /,
+    ) -> Tuple[int, dr.ArrayBase, dr.ArrayBase]:
+        # Known issue: we don't mask the update to 't' here. That would
+        # require moving this parameter to the GPU, with a whole bunch
+        # of downsides. It is only relevant for AMP training. Oh well.
+        return (
+            new_extra[0],
+            dr.select(mask, extra[1], new_extra[1]),
+            dr.select(mask, extra[2], new_extra[2]),
+            new_extra[3]
+        )

     def __repr__(self):
         """Return a human-readable string representation"""
```
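A minimal sketch (plain Python floats with made-up values, rather than Dr.Jit arrays) of what the decay step computes: `dr.fma(value, -lr * decay, new_value)` evaluates to `new_value - lr * decay * value`, so the decay acts on the original parameter value on top of the ordinary Adam step:

```python
def fma(a, b, c):
    # mirrors dr.fma(a, b, c) = a * b + c
    return a * b + c

value, new_value = 1.0, 0.95   # parameter before / after the Adam step
lr, decay = 0.01, 0.1          # hypothetical learning rate and decay coefficient

scaled_value = fma(value, -lr * decay, new_value)
assert abs(scaled_value - (new_value - lr * decay * value)) < 1e-12
```

The decay coefficient also travels with the state as `extra[3]`, which is why the new `_select` forwards `new_extra[3]` unchanged when masked (AMP) updates are resolved.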
```diff
@@ -1242,7 +1316,7 @@ def __repr__(self):
             " lr = %s,\n"
             " beta = (%g, %g),\n"
             " epsilon = %g,\n"
-            " weight_decay = %g\n"
+            " weight_decay = %s\n"
             "]"
             % (
                 list(self.keys()),
@@ -1251,7 +1325,7 @@ def __repr__(self):
                 self.beta_1,
                 self.beta_2,
                 self.epsilon,
-                self.weight_decay,
+                self.weight_decay_dict,
             )
         )
```
Review comment: Could you document the expression here as well?