44 changes: 44 additions & 0 deletions drjit/nn.py
@@ -232,6 +232,50 @@ class Tanh(Module):
    DRJIT_STRUCT = { }
    def __call__(self, arg: CoopVec, /) -> CoopVec:
        return drjit.tanh(arg)


class Sigmoid(Module):
    r"""
    Sigmoid activation function.

    .. math::
        \mathrm{Sigmoid}(x) = \frac{1}{1 + e^{-x}} = 0.5 + 0.5 \cdot \tanh(x/2)
    """
    DRJIT_STRUCT = {}

    def __call__(self, arg: CoopVec, /) -> CoopVec:
        # Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
        half_x = arg * 0.5
        tanh_half_x = drjit.tanh(half_x)
        return drjit.fma(0.5, tanh_half_x, 0.5)  # 0.5 * tanh + 0.5

class SiLU(Module):
    r"""
    SiLU activation function. Also known as the "swish" function.
@wjakob (Member), Sep 25, 2025:
Could you document the expression here as well?

"""
DRJIT_STRUCT = {}

def __call__(self, arg: CoopVec, /) -> CoopVec:
# Use the identity: sigmoid(x) = 0.5 + 0.5 * tanh(x/2)
half_x = arg * 0.5
tanh_half_x = drjit.tanh(half_x)
sigmoid = drjit.fma(0.5, tanh_half_x, 0.5) # 0.5 * tanh + 0.5
return arg * sigmoid

class Softplus(Module):
    r"""
    Softplus activation function.

    .. math::
        \mathrm{Softplus}(x) = \log(1 + e^x)
    """
    DRJIT_STRUCT = {}

    def __call__(self, arg: CoopVec, /) -> CoopVec:
        # Evaluate log(1 + e^x) using base-2 operations:
        # log(1 + e^x) = ln(2) * log2(1 + 2^(x / ln(2)))
        x_log2 = arg * (1 / drjit.log(2))
        return drjit.log(2) * drjit.log2(1.0 + drjit.exp2(x_log2))

class ScaleAdd(Module):
    r"""
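The new activation modules above are built entirely from tanh, fma, exp2, and log2. Below is a quick plain-Python sanity check of the identities they rely on; this is an illustrative sketch, independent of Dr.Jit and not part of the diff:

import math

def reference(x):
    sigmoid = 1.0 / (1.0 + math.exp(-x))
    return sigmoid, x * sigmoid, math.log1p(math.exp(x))

def via_identities(x):
    # sigmoid(x) = 0.5 + 0.5 * tanh(x / 2)
    sigmoid = 0.5 + 0.5 * math.tanh(0.5 * x)
    # SiLU(x) = x * sigmoid(x)
    silu = x * sigmoid
    # softplus(x) = ln(2) * log2(1 + 2^(x / ln 2))
    softplus = math.log(2) * math.log2(1.0 + 2.0 ** (x / math.log(2)))
    return sigmoid, silu, softplus

for x in (-3.0, -0.5, 0.0, 1.25, 4.0):
    for a, b in zip(reference(x), via_identities(x)):
        assert abs(a - b) < 1e-12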
92 changes: 83 additions & 9 deletions drjit/opt.py
@@ -1011,7 +1011,7 @@ def _step(
        v_tp: dr.ArrayBase  # Second moment EMA state from previous iteration

        # Unpack optimizer state
-        t_p, m_tp, v_tp = extra
+        t_p, m_tp, v_tp = extra[:3]

        # Increase the iteration count
        t = t_p + 1
Expand Down Expand Up @@ -1146,7 +1146,7 @@ def __init__(
        beta_1: float = 0.9,
        beta_2: float = 0.999,
        epsilon: float = 1e-8,
-        weight_decay: float = 0.01,
+        weight_decay: Optional[float | Mapping[str, float]] = None,
Member:
Weight decay is an AdamW-specific feature. I would prefer to keep this 100% in the AdamW subclass so that we don't have to pay any cost (like the unpacking above) in the general optimizer.

When optimizing simple calculations (e.g. neural nets) without any ray tracing, the Python code in the optimizer can actually inhibit saturating the GPU, hence this focus on keeping the code here as simple/efficient as possible. It was also the motivation for a larger rewrite when moving the optimizer code from Mitsuba to Dr.Jit.

AdamW can store an optional per-parameter weight decay override using the extra field (setting this value to None by default).

Author:
I did not get your comment here. The other optimizers do not have any access to the weight_decay parameter. I just have an additional dictionary in AdamW and increase the size of the extra tuple in the set_weight_decay function. Can you elaborate on your comment about keeping everything in AdamW?

Member:
Sorry, I misread the diff and had thought that you were modifying the base optimizer. I think it would be better if the set_weight_decay function worked like the set_learning_rate function; please check out the docstring of that function. Right now, all of the optimizers are written so that they don't use key-based dictionary lookups in opt.step(). They iterate over dictionaries but never explicitly search a dictionary for a string-based key (which involves string hashing, etc.).
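For reference, the pattern suggested in the last comment, i.e. resolving per-parameter values once when they are set and storing them alongside the parameter state so that the step loop never performs a string-keyed lookup, looks roughly like the sketch below. All names are illustrative placeholders, not the actual Dr.Jit optimizer internals:

class SketchOptimizer:
    # Minimal sketch of the "resolve at set time, not at step time" pattern
    def __init__(self, lr=1e-3):
        self.lr = lr
        self.global_decay = 0.0
        self.decay_overrides = {}   # parameter name -> decay value
        self.state = {}             # parameter name -> (value, decay)

    def set_weight_decay(self, value):
        if isinstance(value, dict):
            self.decay_overrides.update(value)
        else:
            self.global_decay = float(value)
        # Resolve the per-parameter decay now, while the key is at hand,
        # and store it next to the parameter it belongs to
        for key, (param, _) in self.state.items():
            self.state[key] = (param, self.decay_overrides.get(key, self.global_decay))

    def step(self):
        # Hot loop: iterates over the state entries and never searches a
        # dictionary for a string-based key
        for key, (param, decay) in self.state.items():
            self.state[key] = (param - self.lr * decay * param, decay)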

        mask_updates: bool = False,
        promote_fp16: bool = True,
        uniform: bool = False,
@@ -1172,10 +1172,12 @@ def __init__(
                will cause past gradients to persist for a longer amount of
                time.

-            weight_decay (float):
+            weight_decay (float | Mapping[str, float]):
                Weight decay coefficient for L2 regularization. Unlike Adam,
                this is applied directly to parameters rather than gradients,
                providing better regularization with adaptive learning rates.
                You may also provide a dictionary mapping parameter names to
                individual weight decay values.

            uniform (bool):
                If enabled, the optimizer will use the *UniformAdam* variant of
@@ -1195,7 +1197,18 @@
                Optional dictionary-like object containing an initial set of
                parameters.
        """
        # Store per-parameter weight decay
        self.weight_decay_dict: Mapping[str, float] = {}

        if (weight_decay is None) or (isinstance(weight_decay, Mapping)):
            self.global_weight_decay = 0.0
        elif not isinstance(weight_decay, float):
            raise TypeError(
                "'weight_decay' must be None, a float, or a mapping from parameter names to floats"
            )
        else:
            self.global_weight_decay = weight_decay

        super().__init__(
            lr,
            params,
@@ -1206,11 +1219,48 @@
            promote_fp16=promote_fp16,
            uniform=uniform
        )

        self.set_weight_decay(weight_decay)

-        if weight_decay < 0:
-            raise RuntimeError("'weight_decay' must be >= 0")
    def set_weight_decay(self, value: Union[float, Mapping[str, float], None] = None):
        """
        Set the weight decay globally or per parameter.

        Args:
            value: float to set globally, or a dict mapping parameter names to
                individual decay values. ``None`` keeps the current settings
                and merely re-applies them to the optimizer state.
        """
        if value is None:
            # Nothing new to record; the loop below re-applies the current settings
            pass
        elif isinstance(value, float):
            # Global weight decay
            if value < 0:
                raise ValueError("weight_decay must be non-negative.")
            self.global_weight_decay = value
        elif isinstance(value, Mapping):
            for k, wd in value.items():
                self.weight_decay_dict[k] = wd
        else:
            raise ValueError("weight_decay must be None, a float, or a mapping")

        for k in self.state:
            decay = self.weight_decay_dict.get(k, self.global_weight_decay)

            Float = dr.float32_array_t(dr.leaf_t(self.state[k][0]))
            self.state[k] = self.state[k][0], self.state[k][1], self.state[k][2], (
                self.state[k][3][0],
                self.state[k][3][1],
                self.state[k][3][2],
                Float(decay),
            )

-        self.weight_decay = weight_decay
    def _reset(self, key: str, value: dr.ArrayBase, promoted: bool, /) -> None:
        valarr = value.array
        tp = type(valarr)
        UInt = dr.uint32_array_t(dr.leaf_t(tp))
        Float = dr.float32_array_t(dr.leaf_t(tp))
        t = UInt(0)
        m_t = dr.opaque(tp, 0, valarr.shape)
        v_t = dr.opaque(tp, 0, valarr.shape)
        decay = self.weight_decay_dict.get(key, self.global_weight_decay)
        self.state[key] = value, promoted, None, (t, m_t, v_t, Float(decay))

    def _step(
        self,
@@ -1221,9 +1271,33 @@ def _step(
        extra: Tuple[int, dr.ArrayBase, dr.ArrayBase],
        /,
    ) -> Tuple[dr.ArrayBase, Tuple[int, dr.ArrayBase, dr.ArrayBase]]:

        decay = extra[3]
        # Take Adam step
        new_value, new_extra = super()._step(cache, value, grad, lr, extra)
-        scaled_value = dr.fma(value, -lr * self.weight_decay, new_value)

        # Take weight decay step
        scaled_value = dr.fma(value, -lr * decay, new_value)

        new_extra = (new_extra[0], new_extra[1], new_extra[2], decay)
        return scaled_value, new_extra
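As a plain-float illustration of what the update above computes (decoupled weight decay: the decay term uses the parameter value from before the Adam step and is added after it), here is a sketch independent of Dr.Jit:

lr, decay = 0.01, 0.1
value = 2.0              # parameter before the step
adam_delta = -0.05       # whatever the Adam update computes from the gradient

new_value = value + adam_delta                      # result of super()._step(...)
scaled_value = new_value + value * (-lr * decay)    # dr.fma(value, -lr * decay, new_value)

assert abs(scaled_value - 1.948) < 1e-12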

    def _select(
        self,
        mask: dr.ArrayBase,
        extra: Tuple[int, dr.ArrayBase, dr.ArrayBase],
        new_extra: Tuple[int, dr.ArrayBase, dr.ArrayBase, float],
        /,
    ) -> Tuple[int, dr.ArrayBase, dr.ArrayBase]:
        # Known issue: we don't mask the update to 't' here. That would
        # require moving this parameter to the GPU, with a whole bunch
        # of downsides. It is only relevant for AMP training. Oh well.
        return (
            new_extra[0],
            dr.select(mask, extra[1], new_extra[1]),
            dr.select(mask, extra[2], new_extra[2]),
            new_extra[3]
        )

    def __repr__(self):
        """Return a human-readable string representation"""
@@ -1242,7 +1316,7 @@ def __repr__(self):
" lr = %s,\\n"
" beta = (%g, %g),\\n"
" epsilon = %g,\\n"
" weight_decay = %g\\n"
" weight_decay = %s\\n"
"]"
            % (
                list(self.keys()),
@@ -1251,7 +1325,7 @@ def __repr__(self):
                self.beta_1,
                self.beta_2,
                self.epsilon,
-                self.weight_decay,
+                self.weight_decay_dict,
            )
        )

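Assuming the interface proposed in this diff, a training loop with a per-parameter decay override might look like the sketch below. The parameter names and the toy loss are hypothetical, and the surrounding usage follows the usual Dr.Jit optimizer workflow (dictionary-style parameter access, dr.backward(), opt.step()):

import drjit as dr
from drjit.auto.ad import Float   # any AD-enabled Dr.Jit backend
from drjit.opt import AdamW

# Toy parameters; 'w' is decayed, 'b' is exempted via a per-parameter override
opt = AdamW(lr=1e-3, weight_decay=1e-2,
            params={'w': Float(1.0, 2.0, 3.0), 'b': Float(0.5)})
opt.set_weight_decay({'b': 0.0})

for _ in range(10):
    w, b = opt['w'], opt['b']
    loss = dr.sum(w * w) + dr.sum(b * b)
    dr.backward(loss)
    opt.step()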