Skip to content

Commit 009ab50

Browse files
author
Alexander Ororbia
committed
added log-gaussian initializer to distribution_generator
1 parent 2b418c1 commit 009ab50

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

ngclearn/utils/distribution_generator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,47 @@ def gaussian_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> j
151151

152152
return gaussian_generator
153153

154+
@staticmethod
155+
def log_gaussian(sigma: float = 1.0, **params: Unpack[DistributionParams]) -> DistributionInitializer:
156+
"""
157+
Produces a distribution initializer for a log-Gaussian/normal distribution. Note that this
158+
distribution is constrained to be centered (zero-mean); thus, only a scale/standard-devation
159+
`sigma` can be provided as argument. This is a useful distribution to produce non-negative/
160+
positive-valued sample values.
161+
162+
Args:
163+
sigma: standard deviation of the underlying normal distribution (Default: 1.)
164+
**params: the extra distribution parameters
165+
166+
Returns:
167+
a distribution initializer
168+
"""
169+
using_np = params.get("use_numpy", False)
170+
171+
if using_np:
172+
def log_gaussian_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
173+
rng = numpy.random.default_rng(seed)
174+
matrix = rng.lognormal(mean=0.0, sigma=sigma, size=shape).astype(
175+
params.get("dtype", numpy.float32))
176+
matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
177+
return matrix
178+
else:
179+
def log_gaussian_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> jax.Array:
180+
if dKey is None:
181+
dKey = jax.random.PRNGKey(time.time_ns())
182+
dKey, subKey = jax.random.split(dKey, 2)
183+
184+
matrix = jax.random.lognormal(
185+
dkey,
186+
sigma=sigma,
187+
shape=shape,
188+
dtype=params.get("dtype", jax.numpy.float32)
189+
)
190+
matrix = DistributionGenerator._process_params_jax(matrix, params, subKey)
191+
return matrix
192+
193+
return log_gaussian_generator
194+
154195
@staticmethod
155196
def fan_in_uniform(**params: Unpack[DistributionParams]) -> DistributionInitializer:
156197
"""

0 commit comments

Comments
 (0)