Skip to content

Commit d08de05

Browse files
aseyboldtpatrick-kidger
authored andcommitted
Allow zero shapes in nn.Linear
1 parent 550967c commit d08de05

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

equinox/nn/_linear.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ def __init__(
5252
wkey, bkey = jrandom.split(key, 2)
5353
in_features_ = 1 if in_features == "scalar" else in_features
5454
out_features_ = 1 if out_features == "scalar" else out_features
55-
lim = 1 / math.sqrt(in_features_)
55+
if in_features_ == 0:
56+
lim = 1.0
57+
else:
58+
lim = 1 / math.sqrt(in_features_)
5659
wshape = (out_features_, in_features_)
5760
self.weight = default_init(wkey, wshape, dtype, lim)
5861
bshape = (out_features_,)

tests/test_nn.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,16 @@ def test_custom_init():
2222

2323

2424
def test_linear(getkey):
25+
# Zero input shape
26+
linear = eqx.nn.Linear(0, 4, key=getkey())
27+
x = jrandom.normal(getkey(), (0,))
28+
assert linear(x).shape == (4,)
29+
30+
# Zero output shape
31+
linear = eqx.nn.Linear(4, 0, key=getkey())
32+
x = jrandom.normal(getkey(), (4,))
33+
assert linear(x).shape == (0,)
34+
2535
# Positional arguments
2636
linear = eqx.nn.Linear(3, 4, key=getkey())
2737
x = jrandom.normal(getkey(), (3,))

0 commit comments

Comments
 (0)