Skip to content

Commit 9d0ed8a

Browse files
lockwopatrick-kidger
authored andcommitted
comment
1 parent c3c6504 commit 9d0ed8a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

equinox/nn/_batch_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def _norm(y, m, v, w, b):
244244
counter = state.get(self.batch_counter)
245245
hidden_mean, hidden_var = state.get(self.batch_state_index)
246246
if inference:
247-
# Zero-debias approach: average_ = hidden_ / (1 - decay^counter)
247+
# Zero-debias approach: mean = hidden_mean / (1 - momentum^counter)
248248
# For simplicity we do the minimal version here (no warmup).
249249
scale = 1 - self.momentum**counter
250250
mean = hidden_mean / scale

0 commit comments

Comments
 (0)