44from typing import Dict , Optional , Union , Callable
55
66import jax
7+ import numpy as np
78import jax .numpy as jnp
89
910from brainpy import math as bm
@@ -63,8 +64,8 @@ def __init__(
6364 num_out : int ,
6465 W_initializer : Union [Initializer , Callable , ArrayType ] = XavierNormal (),
6566 b_initializer : Optional [Union [Initializer , Callable , ArrayType ]] = ZeroInit (),
66- mode : bm .Mode = None ,
67- name : str = None ,
67+ mode : Optional [ bm .Mode ] = None ,
68+ name : Optional [ str ] = None ,
6869 ):
6970 super (Dense , self ).__init__ (mode = mode , name = name )
7071
@@ -642,7 +643,7 @@ def __init__(
642643 num_out : int ,
643644 prob : float ,
644645 weight : float ,
645- seed : int ,
646+ seed : Optional [ int ] = None ,
646647 sharding : Optional [Sharding ] = None ,
647648 mode : Optional [bm .Mode ] = None ,
648649 name : Optional [str ] = None ,
@@ -654,7 +655,7 @@ def __init__(
654655 self .prob = prob
655656 self .sharding = sharding
656657 self .transpose = transpose
657- self .seed = seed
658+ self .seed = np . random . randint ( 0 , 100000 ) if seed is None else seed
658659 self .atomic = atomic
659660 self .num_in = num_in
660661 self .num_out = num_out
@@ -723,7 +724,7 @@ def __init__(
723724 prob : float ,
724725 w_low : float ,
725726 w_high : float ,
726- seed : int ,
727+ seed : Optional [ int ] = None ,
727728 sharding : Optional [Sharding ] = None ,
728729 mode : Optional [bm .Mode ] = None ,
729730 name : Optional [str ] = None ,
@@ -735,7 +736,7 @@ def __init__(
735736 self .prob = prob
736737 self .sharding = sharding
737738 self .transpose = transpose
738- self .seed = seed
739+ self .seed = np . random . randint ( 0 , 100000 ) if seed is None else seed
739740 self .atomic = atomic
740741 self .num_in = num_in
741742 self .num_out = num_out
@@ -803,7 +804,7 @@ def __init__(
803804 prob : float ,
804805 w_mu : float ,
805806 w_sigma : float ,
806- seed : int ,
807+ seed : Optional [ int ] = None ,
807808 sharding : Optional [Sharding ] = None ,
808809 transpose : bool = False ,
809810 atomic : bool = False ,
@@ -815,7 +816,7 @@ def __init__(
815816 self .prob = prob
816817 self .sharding = sharding
817818 self .transpose = transpose
818- self .seed = seed
819+ self .seed = np . random . randint ( 0 , 100000 ) if seed is None else seed
819820 self .atomic = atomic
820821 self .num_in = num_in
821822 self .num_out = num_out
@@ -881,7 +882,7 @@ def __init__(
881882 num_out : int ,
882883 prob : float ,
883884 weight : float ,
884- seed : int ,
885+ seed : Optional [ int ] = None ,
885886 sharding : Optional [Sharding ] = None ,
886887 mode : Optional [bm .Mode ] = None ,
887888 name : Optional [str ] = None ,
@@ -893,7 +894,7 @@ def __init__(
893894 self .prob = prob
894895 self .sharding = sharding
895896 self .transpose = transpose
896- self .seed = seed
897+ self .seed = np . random . randint ( 0 , 1000000 ) if seed is None else seed
897898 self .atomic = atomic
898899 self .num_in = num_in
899900 self .num_out = num_out
@@ -962,7 +963,7 @@ def __init__(
962963 prob : float ,
963964 w_low : float ,
964965 w_high : float ,
965- seed : int ,
966+ seed : Optional [ int ] = None ,
966967 sharding : Optional [Sharding ] = None ,
967968 mode : Optional [bm .Mode ] = None ,
968969 name : Optional [str ] = None ,
@@ -974,7 +975,7 @@ def __init__(
974975 self .prob = prob
975976 self .sharding = sharding
976977 self .transpose = transpose
977- self .seed = seed
978+ self .seed = np . random . randint ( 0 , 100000 ) if seed is None else seed
978979 self .atomic = atomic
979980 self .num_in = num_in
980981 self .num_out = num_out
@@ -1042,7 +1043,7 @@ def __init__(
10421043 prob : float ,
10431044 w_mu : float ,
10441045 w_sigma : float ,
1045- seed : int ,
1046+ seed : Optional [ int ] = None ,
10461047 sharding : Optional [Sharding ] = None ,
10471048 transpose : bool = False ,
10481049 atomic : bool = False ,
@@ -1054,7 +1055,7 @@ def __init__(
10541055 self .prob = prob
10551056 self .sharding = sharding
10561057 self .transpose = transpose
1057- self .seed = seed
1058+ self .seed = np . random . randint ( 0 , 100000 ) if seed is None else seed
10581059 self .atomic = atomic
10591060 self .num_in = num_in
10601061 self .num_out = num_out
0 commit comments