@@ -151,6 +151,47 @@ def gaussian_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> j
151151
152152 return gaussian_generator
153153
154+ @staticmethod
155+ def log_gaussian (sigma : float = 1.0 , ** params : Unpack [DistributionParams ]) -> DistributionInitializer :
156+ """
157+ Produces a distribution initializer for a log-Gaussian/normal distribution. Note that this
158+ distribution is constrained to be centered (zero-mean); thus, only a scale/standard-devation
159+ `sigma` can be provided as argument. This is a useful distribution to produce non-negative/
160+ positive-valued sample values.
161+
162+ Args:
163+ sigma: standard deviation of the underlying normal distribution (Default: 1.)
164+ **params: the extra distribution parameters
165+
166+ Returns:
167+ a distribution initializer
168+ """
169+ using_np = params .get ("use_numpy" , False )
170+
171+ if using_np :
172+ def log_gaussian_generator (shape : Sequence [int ], seed : int | None = None ) -> numpy .ndarray :
173+ rng = numpy .random .default_rng (seed )
174+ matrix = rng .lognormal (mean = 0.0 , sigma = sigma , size = shape ).astype (
175+ params .get ("dtype" , numpy .float32 ))
176+ matrix = DistributionGenerator ._process_params_numpy (matrix , params , seed )
177+ return matrix
178+ else :
179+ def log_gaussian_generator (shape : Sequence [int ], dKey : jax .Array | None = None ) -> jax .Array :
180+ if dKey is None :
181+ dKey = jax .random .PRNGKey (time .time_ns ())
182+ dKey , subKey = jax .random .split (dKey , 2 )
183+
184+ matrix = jax .random .lognormal (
185+ dkey ,
186+ sigma = sigma ,
187+ shape = shape ,
188+ dtype = params .get ("dtype" , jax .numpy .float32 )
189+ )
190+ matrix = DistributionGenerator ._process_params_jax (matrix , params , subKey )
191+ return matrix
192+
193+ return log_gaussian_generator
194+
154195 @staticmethod
155196 def fan_in_uniform (** params : Unpack [DistributionParams ]) -> DistributionInitializer :
156197 """
0 commit comments