From 899583d9ef58a66b6498b3155faef93db14f3988 Mon Sep 17 00:00:00 2001
From: efylmzr
Date: Thu, 25 Sep 2025 10:44:11 +0200
Subject: [PATCH 1/2] AdamW: per-parameter weight decay + new activation
 functions

---
 drjit/nn.py  | 44 +++++++++++++++++++++++++
 drjit/opt.py | 92 +++++++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 127 insertions(+), 9 deletions(-)

diff --git a/drjit/nn.py b/drjit/nn.py
index 25a27e919..a19840f3b 100644
--- a/drjit/nn.py
+++ b/drjit/nn.py
@@ -232,6 +232,50 @@ class Tanh(Module):
     DRJIT_STRUCT = { }
     def __call__(self, arg: CoopVec, /) -> CoopVec:
         return drjit.tanh(arg)
+
+
+class Sigmoid(Module):
+    r"""
+    Sigmoid activation function.
+
+    .. math::
+        \mathrm{Sigmoid}(x) = \frac{1}{1 + e^{-x}} = 0.5 + 0.5 \cdot \tanh(x/2)
+    """
+    DRJIT_STRUCT = {}
+
+    def __call__(self, arg: CoopVec, /) -> CoopVec:
+        # Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
+        half_x = arg * 0.5
+        tanh_half_x = drjit.tanh(half_x)
+        return drjit.fma(0.5, tanh_half_x, 0.5)  # 0.5 * tanh + 0.5
+
+class SiLU(Module):
+    r"""
+    SiLU activation function. Also known as the "swish" function.
+    """
+    DRJIT_STRUCT = {}
+
+    def __call__(self, arg: CoopVec, /) -> CoopVec:
+        # Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
+        half_x = arg * 0.5
+        tanh_half_x = drjit.tanh(half_x)
+        sigmoid = drjit.fma(0.5, tanh_half_x, 0.5)  # 0.5 * tanh + 0.5
+        return arg * sigmoid
+
+class Softplus(Module):
+    r"""
+    Softplus activation function.

+
+    .. math::
+        \mathrm{Softplus}(x) = \log(1 + e^x)
+    """
+    DRJIT_STRUCT = {}
+
+    def __call__(self, arg: CoopVec, /) -> CoopVec:
+        # Evaluate log(1 + exp(x)) via the base-2 functions exp2/log2:
+        #   log(1 + exp(x)) = ln(2) * log2(1 + 2^(x / ln(2)))
+        x_log2 = arg * (1 / drjit.log(2))
+        return drjit.log(2) * drjit.log2(1.0 + drjit.exp2(x_log2))
 
 class ScaleAdd(Module):
     r"""
diff --git a/drjit/opt.py b/drjit/opt.py
index 3c6197b40..e3146d30d 100644
--- a/drjit/opt.py
+++ b/drjit/opt.py
@@ -1011,7 +1011,7 @@ def _step(
         v_tp: dr.ArrayBase  # Second moment EMA state from previous iteration
 
         # Unpack optimizer state
-        t_p, m_tp, v_tp = extra
+        t_p, m_tp, v_tp = extra[:3]
 
         # Increase the iteration count
         t = t_p + 1
@@ -1146,7 +1146,7 @@ def __init__(
         beta_1: float = 0.9,
         beta_2: float = 0.999,
         epsilon: float = 1e-8,
-        weight_decay: float = 0.01,
+        weight_decay: Optional[float | Mapping[str, float]] = None,
         mask_updates: bool = False,
         promote_fp16: bool = True,
         uniform: bool = False,
@@ -1172,10 +1172,12 @@ def __init__(
               will cause past gradients to persist for a longer amount of time.
 
-            weight_decay (float):
+            weight_decay (float | Mapping[str, float]):
              Weight decay coefficient for L2 regularization. Unlike Adam, this
              is applied directly to parameters rather than gradients, providing
              better regularization with adaptive learning rates.
+             You may also provide a dictionary mapping parameter names to
+             individual weight decay values.
 
            uniform (bool):
              If enabled, the optimizer will use the *UniformAdam* variant of
@@ -1195,7 +1197,18 @@ def __init__(
             Optional dictionary-like object containing an initial set of
             parameters.
""" + # Store per-parameter weight decay + self.weight_decay_dict: Mapping[str, float] = {} + if (weight_decay is None) or (isinstance(weight_decay, Mapping)): + self.global_weight_decay = 0.0 + elif not isinstance(weight_decay, float): + raise TypeError( + "'weight_decay' must be None, a float, or a mapping from parameter names to floats" + ) + else: + self.global_weight_decay = weight_decay + super().__init__( lr, params, @@ -1206,11 +1219,48 @@ def __init__( promote_fp16=promote_fp16, uniform=uniform ) + + self.set_weight_decay(weight_decay) - if weight_decay < 0: - raise RuntimeError("'weight_decay' must be >= 0") + def set_weight_decay(self, value: Union[float, Mapping[str, float], None] = None): + """ + Set the weight decay globally or per parameter. + Args: + value: float to set globally, or dict mapping parameter names to decay. + """ + if isinstance(value, float): + # Global weight decay + if value < 0: + raise ValueError("weight_decay must be non-negative.") + self.global_weight_decay = value + + elif isinstance(value, Mapping): + for k, wd in value.items(): + self.weight_decay_dict[k] = wd + else: + raise ValueError("weight_decay must be a float or a mapping") + + for k in self.state: + decay = self.weight_decay_dict.get(k, self.global_weight_decay) + + Float = dr.float32_array_t(dr.leaf_t(self.state[k][0])) + self.state[k] = self.state[k][0], self.state[k][1], self.state[k][2], ( + self.state[k][3][0], + self.state[k][3][1], + self.state[k][3][2], + Float(decay), + ) - self.weight_decay = weight_decay + def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None: + valarr = value.array + tp = type(valarr) + UInt = dr.uint32_array_t(dr.leaf_t(tp)) + Float = dr.float32_array_t(dr.leaf_t(tp)) + t = UInt(0) + m_t = dr.opaque(tp, 0, valarr.shape) + v_t = dr.opaque(tp, 0, valarr.shape) + decay = self.weight_decay_dict.get(key, self.global_weight_decay) + self.state[key] = value, promoted, None, (t, m_t, v_t, Float(decay)) def _step( self, @@ -1221,9 +1271,33 @@ def _step( extra: Tuple[int, dr.ArrayBase, dr.ArrayBase], /, ) -> Tuple[dr.ArrayBase, Tuple[int, dr.ArrayBase, dr.ArrayBase]]: + + decay = extra[3] + #Take Adam step new_value, new_extra = super()._step(cache, value, grad, lr, extra) - scaled_value = dr.fma(value, -lr * self.weight_decay, new_value) + + # Take weight decay step + scaled_value = dr.fma(value, -lr * decay, new_value) + + new_extra = (new_extra[0], new_extra[1], new_extra[2], decay) return scaled_value, new_extra + + def _select( + self, + mask: dr.ArrayBase, + extra: Tuple[int, dr.ArrayBase, dr.ArrayBase], + new_extra: Tuple[int, dr.ArrayBase, dr.ArrayBase, float], + /, + ) -> Tuple[int, dr.ArrayBase, dr.ArrayBase]: + # Known issue: we don't mask the update to 't' here. That would + # require moving this parameter to the GPU, with a whole bunch + # of downsides. It is only relevant for AMP training. Oh well. 
+        return (
+            new_extra[0],
+            dr.select(mask, extra[1], new_extra[1]),
+            dr.select(mask, extra[2], new_extra[2]),
+            new_extra[3],
+        )
 
     def __repr__(self):
         """Return a human-readable string representation"""
@@ -1242,7 +1316,7 @@ def __repr__(self):
             "  lr = %s,\n"
             "  beta = (%g, %g),\n"
             "  epsilon = %g,\n"
-            "  weight_decay = %g\n"
+            "  weight_decay = %s\n"
             "]"
             % (
                 list(self.keys()),
@@ -1251,7 +1325,7 @@ def __repr__(self):
                 self.beta_1,
                 self.beta_2,
                 self.epsilon,
-                self.weight_decay,
+                self.weight_decay_dict or self.global_weight_decay,
             )
         )

From 163e423515da54e1d2a0284d9428f6a02e2f494e Mon Sep 17 00:00:00 2001
From: efylmzr
Date: Fri, 26 Sep 2025 11:25:44 +0200
Subject: [PATCH 2/2] Small cleanups for the PR

---
 drjit/nn.py  | 3 +++
 drjit/opt.py | 8 ++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/drjit/nn.py b/drjit/nn.py
index a19840f3b..8f1ab5cb5 100644
--- a/drjit/nn.py
+++ b/drjit/nn.py
@@ -252,6 +252,9 @@ def __call__(self, arg: CoopVec, /) -> CoopVec:
 class SiLU(Module):
     r"""
     SiLU activation function. Also known as the "swish" function.
+
+    .. math::
+        \mathrm{SiLU}(x) = x \cdot \mathrm{Sigmoid}(x) = \frac{x}{1 + e^{-x}}
     """
     DRJIT_STRUCT = {}
 
diff --git a/drjit/opt.py b/drjit/opt.py
index e3146d30d..a4ccfd4cb 100644
--- a/drjit/opt.py
+++ b/drjit/opt.py
@@ -1011,7 +1011,7 @@ def _step(
         v_tp: dr.ArrayBase  # Second moment EMA state from previous iteration
 
         # Unpack optimizer state
-        t_p, m_tp, v_tp = extra[:3]
+        t_p, m_tp, v_tp = extra
 
         # Increase the iteration count
         t = t_p + 1
@@ -1136,7 +1136,7 @@ class AdamW(Adam):
     """
 
     # Weight decay coefficient
-    weight_decay: float
+    global_weight_decay: float
 
     def __init__(
         self,
@@ -1146,7 +1146,7 @@ def __init__(
         beta_1: float = 0.9,
         beta_2: float = 0.999,
         epsilon: float = 1e-8,
-        weight_decay: Optional[float | Mapping[str, float]] = None,
+        weight_decay: Optional[Union[float, Mapping[str, float]]] = None,
         mask_updates: bool = False,
         promote_fp16: bool = True,
         uniform: bool = False,
@@ -1274,7 +1274,7 @@ def _step(
 
         decay = extra[3]
         # Take Adam step
-        new_value, new_extra = super()._step(cache, value, grad, lr, extra)
+        new_value, new_extra = super()._step(cache, value, grad, lr, extra[:3])
 
         # Take weight decay step
         scaled_value = dr.fma(value, -lr * decay, new_value)
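Not part of the patch series itself: a minimal usage sketch of the per-parameter
weight decay API added above, for reviewers. It assumes the dictionary-style
parameter registration and `opt.step()` loop that drjit.opt optimizers already
expose; the parameter names ("scale", "offset") and the toy least-squares
objective are made up for illustration.

    import drjit as dr
    from drjit.auto.ad import Float
    from drjit.opt import AdamW

    # Decay only 'scale'; 'offset' has no entry and falls back to the global
    # default, which is 0 when a mapping (or None) is passed for weight_decay
    opt = AdamW(lr=1e-2, weight_decay={"scale": 1e-2})
    opt["scale"] = Float(1, 1, 1)    # hypothetical parameter names
    opt["offset"] = Float(0, 0, 0)

    x, y = Float(1, 2, 3), Float(3, 5, 7)
    for _ in range(100):
        # Re-fetch the latest parameter values each iteration
        scale, offset = opt["scale"], opt["offset"]
        loss = dr.mean((scale * x + offset - y) ** 2)
        dr.backward(loss)
        opt.step()

    # The decay can be adjusted later, globally or per parameter
    opt.set_weight_decay(1e-3)
    opt.set_weight_decay({"offset": 0.0})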