File tree Expand file tree Collapse file tree 2 files changed +14
-1
lines changed
Expand file tree Collapse file tree 2 files changed +14
-1
lines changed Original file line number Diff line number Diff line change @@ -52,7 +52,10 @@ def __init__(
5252 wkey , bkey = jrandom .split (key , 2 )
5353 in_features_ = 1 if in_features == "scalar" else in_features
5454 out_features_ = 1 if out_features == "scalar" else out_features
55- lim = 1 / math .sqrt (in_features_ )
55+ if in_features_ == 0 :
56+ lim = 1.0
57+ else :
58+ lim = 1 / math .sqrt (in_features_ )
5659 wshape = (out_features_ , in_features_ )
5760 self .weight = default_init (wkey , wshape , dtype , lim )
5861 bshape = (out_features_ ,)
Original file line number Diff line number Diff line change @@ -22,6 +22,16 @@ def test_custom_init():
2222
2323
2424def test_linear (getkey ):
25+ # Zero input shape
26+ linear = eqx .nn .Linear (0 , 4 , key = getkey ())
27+ x = jrandom .normal (getkey (), (0 ,))
28+ assert linear (x ).shape == (4 ,)
29+
30+ # Zero output shape
31+ linear = eqx .nn .Linear (4 , 0 , key = getkey ())
32+ x = jrandom .normal (getkey (), (4 ,))
33+ assert linear (x ).shape == (0 ,)
34+
2535 # Positional arguments
2636 linear = eqx .nn .Linear (3 , 4 , key = getkey ())
2737 x = jrandom .normal (getkey (), (3 ,))
You can’t perform that action at this time.
0 commit comments