Skip to content

Commit fd62045

Browse files
committed
Fix distributions tests
1 parent 014c590 commit fd62045

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

homepy/blocks/distributions.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,8 @@ def rng_fn(cls, rng, dim, alpha, size):
130130
return v / np.sqrt(np.sum(v**2, axis=-1, keepdims=True) + W[..., None])
131131

132132
def _supp_shape_from_params(self, dist_params, param_shapes=None, rep_param_idx=1):
133-
dim = dist_params[0]
134-
ref_param = dist_params[rep_param_idx]
135-
if param_shapes is not None:
136-
ref_param_shape = param_shapes[rep_param_idx]
137-
return (*ref_param_shape, dim.flatten()[0])
138-
else:
139-
return (*ref_param.shape, dim.flatten()[0])
133+
dim = dist_params[0].flatten()[0]
134+
return (dim,)
140135

141136

142137
hyperballUniformRV = HyperballUniformRV()
@@ -156,9 +151,10 @@ def dist(cls, dim, alpha=1.0, no_assert: bool = False, **kwargs):
156151
@staticmethod
157152
def support_point(rv, size, dim, alpha):
158153
"""Define the moment (initial point) for the RV"""
154+
dim = dim.flatten()[0]
159155
moment = pt.ones((dim,), dtype=rv.dtype) * 0.5 / pt.sqrt(dim)
160156
if not rv_size_is_none(size):
161-
moment_size = pt.concatenate([size, [dim.flatten()[0]]])
157+
moment_size = pt.concatenate([size, [dim]])
162158
moment = pt.full(moment_size, moment)
163159
return moment
164160

@@ -168,8 +164,17 @@ def ball_transform(op, rv):
168164
return ballTransform
169165

170166

167+
def log1mexp_numpy(x):
168+
return np.where(
169+
x < np.log(0.5),
170+
np.log1p(-np.exp(x)),
171+
np.log(-np.expm1(x)),
172+
)
173+
174+
171175
@_logprob.register(HyperballUniformRV)
172176
def logp(op, values, rng, size, dim, alpha, **kwargs):
177+
dim = dim.flatten()[0]
173178
value = values[0]
174179
norm = pt.sqrt(pt.sum(value**2, axis=-1))
175180
return check_parameters(
@@ -239,7 +244,7 @@ def _inverse_rng_fn(cls, rng, theta, lam, dist_size, idxs_mask):
239244
log1p_lam_m_C = np.where(
240245
pos_lam,
241246
np.log1p(np.exp(abs_log_lam - log_c)),
242-
pm.math.log1mexp_numpy(abs_log_lam - log_c, negative_input=True),
247+
log1mexp_numpy(abs_log_lam - log_c),
243248
)
244249
log_p = log_c + log1p_lam_m_C * (x_ - 1) + log_p - np.log(x_) - lam
245250
log_s = np.logaddexp(log_s, log_p)

homepy/blocks/likelihoods.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,10 @@ def get_link(link_function=None):
3838
try:
3939
link = getattr(pm.math, link_function)
4040
except AttributeError:
41-
link = getattr(pt, link_function)
41+
try:
42+
link = getattr(pt, link_function)
43+
except AttributeError:
44+
link = getattr(pt.special, link_function)
4245
else:
4346
raise TypeError(
4447
f"Cannot understand supplied link function type {type(link_function)}."

homepy/tests/blocks/test_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def test_build_correlator_matrix(dims):
552552
assert m.coords["feature_correlator_"] == (*range(dims - 1), "ignored")
553553
ball_rv = m["prefix_correlator"]
554554
assert isinstance(ball_rv.owner.op, HyperballUniformRV)
555-
_, size, _, dim, alpha = m["prefix_correlator"].owner.inputs
555+
_, size, dim, alpha = m["prefix_correlator"].owner.inputs
556556
assert tuple(size.eval()) == ()
557557
assert dim.eval() == dims
558558
assert alpha.eval() == 1.5

homepy/tests/blocks/test_likelihoods.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_parse_observed_features_and_dims_error():
141141
(
142142
NegBinLikelihood,
143143
dict(concentration_prior=[1.0]),
144-
("add", "exp", "exponential_rv", "nbinom_rv"),
144+
("add", "exp", "exponential_rv", "negative_binomial_rv"),
145145
("Abs", "sqr", "halfnormal_rv"),
146146
),
147147
(
@@ -151,7 +151,7 @@ def test_parse_observed_features_and_dims_error():
151151
concentration_dist=pm.Exponential,
152152
concentration_prior=dict(lam=pt.sqr(1.0)),
153153
),
154-
("add", "exp", "sqr", "exponential_rv", "nbinom_rv"),
154+
("add", "exp", "sqr", "exponential_rv", "negative_binomial_rv"),
155155
("Abs", "halfnormal_rv"),
156156
),
157157
(
@@ -162,7 +162,7 @@ def test_parse_observed_features_and_dims_error():
162162
concentration_name="concentration",
163163
concentration_prior=None,
164164
),
165-
("add", "Abs", "halfnormal_rv", "nbinom_rv"),
165+
("add", "Abs", "halfnormal_rv", "negative_binomial_rv"),
166166
("exp", "sqr", "exponential_rv"),
167167
),
168168
(

0 commit comments

Comments
 (0)