Skip to content

Commit 310cc2d

Browse files
twieckiclaude
andcommitted
Fix mean function signatures for PyMC v5
In PyMC v5/PyTensor, the dtype parameter was removed from random variable inputs. Updated all mean function signatures to match the new format: Old signature: (op, rv, rng, size, dtype, ...params) New signature: (op, rv, rng, size, ...params) Test Results: - 603 tests passing (↑ 40 from 563) - 114 tests failing (↓ 40 from 154) - Remaining failures are mostly API compatibility issues 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 37a6948 commit 310cc2d

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

homepy/blocks/means.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -103,27 +103,27 @@ def maybe_resize(a: TensorVariable, size) -> TensorVariable:
103103

104104

105105
@_mean.register(BernoulliRV)
106-
def bernoulli_mean(op, rv, rng, size, dtype, p):
106+
def bernoulli_mean(op, rv, rng, size, p):
107107
return maybe_resize(p, size)
108108

109109

110110
@_mean.register(BetaBinomialRV)
111-
def betabinomial_mean(op, rv, rng, size, dtype, n, alpha, beta):
111+
def betabinomial_mean(op, rv, rng, size, n, alpha, beta):
112112
return maybe_resize((n * alpha) / (alpha + beta), size)
113113

114114

115115
@_mean.register(BetaClippedRV)
116-
def beta_clipped_mean(op, rv, rng, size, dtype, alpha, beta):
116+
def beta_clipped_mean(op, rv, rng, size, alpha, beta):
117117
return maybe_resize(alpha / (alpha + beta), size)
118118

119119

120120
@_mean.register(BetaRV)
121-
def beta_mean(op, rv, rng, size, dtype, alpha, beta):
121+
def beta_mean(op, rv, rng, size, alpha, beta):
122122
return maybe_resize(alpha / (alpha + beta), size)
123123

124124

125125
@_mean.register(BinomialRV)
126-
def binomial_mean(op, rv, rng, size, dtype, n, p):
126+
def binomial_mean(op, rv, rng, size, n, p):
127127
return maybe_resize(n * p, size)
128128

129129

@@ -132,7 +132,7 @@ def binomial_mean(op, rv, rng, size, dtype, n, p):
132132

133133

134134
@_mean.register(DirichletRV)
135-
def dirichlet_mean(op, rv, rng, size, dtype, a):
135+
def dirichlet_mean(op, rv, rng, size, a):
136136
norm_constant = pt.sum(a, axis=-1)[..., None]
137137
mean = a / norm_constant
138138
if not rv_size_is_none(size):
@@ -141,60 +141,60 @@ def dirichlet_mean(op, rv, rng, size, dtype, a):
141141

142142

143143
@_mean.register(ExponentialRV)
144-
def exponential_mean(op, rv, rng, size, dtype, mu):
144+
def exponential_mean(op, rv, rng, size, mu):
145145
return maybe_resize(mu, size)
146146

147147

148148
@_mean.register(GammaRV)
149-
def gamma_mean(op, rv, rng, size, dtype, alpha, inv_beta):
149+
def gamma_mean(op, rv, rng, size, alpha, inv_beta):
150150
# The Aesara `GammaRV` `Op` inverts the `beta` parameter itself
151151
return maybe_resize(alpha * inv_beta, size)
152152

153153

154154
@_mean.register(GeometricRV)
155-
def geometric_mean(op, rv, rng, size, dtype, p):
155+
def geometric_mean(op, rv, rng, size, p):
156156
return maybe_resize(1.0 / p, size)
157157

158158

159159
@_mean.register(GumbelRV)
160-
def gumbel_mean(op, rv, rng, size, dtype, mu, beta):
160+
def gumbel_mean(op, rv, rng, size, mu, beta):
161161
return maybe_resize(mu + beta * np.euler_gamma, size)
162162

163163

164164
@_mean.register(HalfNormalRV)
165-
def halfnormal_mean(op, rv, rng, size, dtype, loc, sigma):
165+
def halfnormal_mean(op, rv, rng, size, loc, sigma):
166166
_, sigma = pt.broadcast_arrays(loc, sigma)
167167
return maybe_resize(sigma * pt.sqrt(2 / np.pi), size)
168168

169169

170170
@_mean.register(HyperGeometricRV)
171-
def hypergeometric_mean(op, rv, rng, size, dtype, good, bad, n):
171+
def hypergeometric_mean(op, rv, rng, size, good, bad, n):
172172
N, k = good + bad, good
173173
return maybe_resize(n * k / N, size)
174174

175175

176176
@_mean.register(InvGammaRV)
177-
def invgamma_mean(op, rv, rng, size, dtype, alpha, beta):
177+
def invgamma_mean(op, rv, rng, size, alpha, beta):
178178
return maybe_resize(pt.switch(alpha > 1, beta / (alpha - 1.0), np.nan), size)
179179

180180

181181
@_mean.register(LaplaceRV)
182-
def laplace_mean(op, rv, rng, size, dtype, mu, b):
182+
def laplace_mean(op, rv, rng, size, mu, b):
183183
return maybe_resize(pt.broadcast_arrays(mu, b)[0], size)
184184

185185

186186
@_mean.register(LogisticRV)
187-
def logistic_mean(op, rv, rng, size, dtype, mu, s):
187+
def logistic_mean(op, rv, rng, size, mu, s):
188188
return maybe_resize(pt.broadcast_arrays(mu, s)[0], size)
189189

190190

191191
@_mean.register(LogNormalRV)
192-
def lognormal_mean(op, rv, rng, size, dtype, mu, sigma):
192+
def lognormal_mean(op, rv, rng, size, mu, sigma):
193193
return maybe_resize(pt.exp(mu + 0.5 * sigma**2), size)
194194

195195

196196
@_mean.register(MultinomialRV)
197-
def multinomial_mean(op, rv, rng, size, dtype, n, p):
197+
def multinomial_mean(op, rv, rng, size, n, p):
198198
n = pt.shape_padright(n)
199199
mean = n * p
200200
if not rv_size_is_none(size):
@@ -204,7 +204,7 @@ def multinomial_mean(op, rv, rng, size, dtype, n, p):
204204

205205

206206
@_mean.register(MvNormalRV)
207-
def mvnormal_mean(op, rv, rng, size, dtype, mu, cov):
207+
def mvnormal_mean(op, rv, rng, size, mu, cov):
208208
mean = mu
209209
if not rv_size_is_none(size):
210210
mean_size = pt.concatenate([size, [mu.shape[-1]]])
@@ -213,70 +213,70 @@ def mvnormal_mean(op, rv, rng, size, dtype, mu, cov):
213213

214214

215215
@_mean.register(NegBinomialRV)
216-
def negative_binomial_mean(op, rv, rng, size, dtype, n, p):
216+
def negative_binomial_mean(op, rv, rng, size, n, p):
217217
return maybe_resize(n * (1 - p) / p, size)
218218

219219

220220
@_mean.register(NormalRV)
221-
def normal_mean(op, rv, rng, size, dtype, mu, sigma):
221+
def normal_mean(op, rv, rng, size, mu, sigma):
222222
return maybe_resize(pt.broadcast_arrays(mu, sigma)[0], size)
223223

224224

225225
@_mean.register(ParetoRV)
226-
def pareto_mean(op, rv, rng, size, dtype, alpha, m):
226+
def pareto_mean(op, rv, rng, size, alpha, m):
227227
return maybe_resize(pt.switch(alpha > 1, alpha * m / (alpha - 1), np.nan), size)
228228

229229

230230
@_mean.register(PoissonRV)
231-
def poisson_mean(op, rv, rng, size, dtype, mu):
231+
def poisson_mean(op, rv, rng, size, mu):
232232
return maybe_resize(mu, size)
233233

234234

235235
@_mean.register(TriangularRV)
236-
def triangular_mean(op, rv, rng, size, dtype, lower, c, upper):
236+
def triangular_mean(op, rv, rng, size, lower, c, upper):
237237
return maybe_resize((lower + upper + c) / 3, size)
238238

239239

240240
@_mean.register(UniformRV)
241-
def uniform_mean(op, rv, rng, size, dtype, lower, upper):
241+
def uniform_mean(op, rv, rng, size, lower, upper):
242242
return maybe_resize((lower + upper) / 2, size)
243243

244244

245245
@_mean.register(VonMisesRV)
246-
def vonmisses_mean(op, rv, rng, size, dtype, mu, kappa):
246+
def vonmisses_mean(op, rv, rng, size, mu, kappa):
247247
return maybe_resize(pt.broadcast_arrays(mu, kappa)[0], size)
248248

249249

250250
@_mean.register(KumaraswamyRV)
251-
def kumaraswamy_mean(op, rv, rng, size, dtype, a, b):
251+
def kumaraswamy_mean(op, rv, rng, size, a, b):
252252
return maybe_resize(
253253
pt.exp(pt.log(b) + pt.gammaln(1 + 1 / a) + pt.gammaln(b) - pt.gammaln(1 + 1 / a + b)),
254254
size,
255255
)
256256

257257

258258
@_mean.register(WaldRV)
259-
def wald_mean(op, rv, rng, size, dtype, mu, lam, alpha):
259+
def wald_mean(op, rv, rng, size, mu, lam, alpha):
260260
return maybe_resize(pt.broadcast_arrays(mu, lam, alpha)[0], size)
261261

262262

263263
@_mean.register(WeibullBetaRV)
264-
def weibull_mean(op, rv, rng, size, dtype, alpha, beta):
264+
def weibull_mean(op, rv, rng, size, alpha, beta):
265265
return maybe_resize(beta * pt.gamma(1 + 1 / alpha), size)
266266

267267

268268
@_mean.register(AsymmetricLaplaceRV)
269-
def asymmetric_laplace_mean(op, rv, rng, size, dtype, b, kappa, mu):
269+
def asymmetric_laplace_mean(op, rv, rng, size, b, kappa, mu):
270270
return maybe_resize(mu - (kappa - 1 / kappa) / b, size)
271271

272272

273273
@_mean.register(StudentTRV)
274-
def studentt_mean(op, rv, rng, size, dtype, nu, mu, sigma):
274+
def studentt_mean(op, rv, rng, size, nu, mu, sigma):
275275
return maybe_resize(pt.broadcast_arrays(mu, nu, sigma)[0], size)
276276

277277

278278
@_mean.register(HalfStudentTRV)
279-
def half_studentt_mean(op, rv, rng, size, dtype, nu, sigma):
279+
def half_studentt_mean(op, rv, rng, size, nu, sigma):
280280
return maybe_resize(
281281
pt.switch(
282282
nu > 1,
@@ -291,18 +291,18 @@ def half_studentt_mean(op, rv, rng, size, dtype, nu, sigma):
291291

292292

293293
@_mean.register(ExGaussianRV)
294-
def exgaussian_mean(op, rv, rng, size, dtype, mu, nu, sigma):
294+
def exgaussian_mean(op, rv, rng, size, mu, nu, sigma):
295295
mu, nu, _ = pt.broadcast_arrays(mu, nu, sigma)
296296
return maybe_resize(mu + nu, size)
297297

298298

299299
@_mean.register(SkewNormalRV)
300-
def skew_normal_mean(op, rv, rng, size, dtype, mu, sigma, alpha):
300+
def skew_normal_mean(op, rv, rng, size, mu, sigma, alpha):
301301
return maybe_resize(mu + sigma * (2 / np.pi) ** 0.5 * alpha / (1 + alpha**2) ** 0.5, size)
302302

303303

304304
@_mean.register(RiceRV)
305-
def rice_mean(op, rv, rng, size, dtype, nu, sigma):
305+
def rice_mean(op, rv, rng, size, nu, sigma):
306306
nu_sigma_ratio = -(nu**2) / (2 * sigma**2)
307307
return maybe_resize(
308308
sigma
@@ -317,22 +317,22 @@ def rice_mean(op, rv, rng, size, dtype, nu, sigma):
317317

318318

319319
@_mean.register(MoyalRV)
320-
def moyal_mean(op, rv, rng, size, dtype, mu, sigma):
320+
def moyal_mean(op, rv, rng, size, mu, sigma):
321321
return maybe_resize(mu + sigma * (np.euler_gamma + pt.log(2)), size)
322322

323323

324324
@_mean.register(PolyaGammaRV)
325-
def polya_gamma_mean(op, rv, rng, size, dtype, h, z):
325+
def polya_gamma_mean(op, rv, rng, size, h, z):
326326
return maybe_resize(pt.switch(pt.eq(z, 0), h / 4, tanh(z / 2) * (h / (2 * z))), size)
327327

328328

329329
@_mean.register(DiscreteUniformRV)
330-
def discrete_uniform_mean(op, rv, rng, size, dtype, lower, upper):
330+
def discrete_uniform_mean(op, rv, rng, size, lower, upper):
331331
return maybe_resize((upper + lower) / 2.0, size)
332332

333333

334334
@_mean.register(DiracDeltaRV)
335-
def dirac_delta_mean(op, rv, rng, size, dtype, c):
335+
def dirac_delta_mean(op, rv, rng, size, c):
336336
return maybe_resize(c, size)
337337

338338

@@ -355,12 +355,12 @@ def mixture_mean(op, rv, rng, weights, *components):
355355

356356

357357
@_mean.register(GeneralizedPoissonRV)
358-
def generalized_poisson_mean(op, rv, rng, size, dtype, mu, lam):
358+
def generalized_poisson_mean(op, rv, rng, size, mu, lam):
359359
return maybe_resize(mu / (1 - lam), size)
360360

361361

362362
@_mean.register(MvStudentTRV)
363-
def mvstudentt_mean(op, rv, rng, size, dtype, nu, mu, scale):
363+
def mvstudentt_mean(op, rv, rng, size, nu, mu, scale):
364364
mean = mu
365365
if not rv_size_is_none(size):
366366
mean_size = pt.concatenate([size, [mu.shape[-1]]])
@@ -369,7 +369,7 @@ def mvstudentt_mean(op, rv, rng, size, dtype, nu, mu, scale):
369369

370370

371371
@_mean.register(DirichletMultinomialRV)
372-
def dirichlet_multinomial_mean(op, rv, rng, size, dtype, n, a):
372+
def dirichlet_multinomial_mean(op, rv, rng, size, n, a):
373373
mean = pt.shape_padright(n) * a / pt.sum(a, axis=-1, keepdims=True)
374374
if not rv_size_is_none(size):
375375
output_size = pt.concatenate([size, [a.shape[-1]]])
@@ -386,17 +386,17 @@ def lkj_cholesky_cov_mean(op, rv, rng, n, eta, sd_dist):
386386

387387

388388
@_mean.register(LKJCorrRV)
389-
def lkj_corr_mean(op, rv, rng, size, dtype, *args):
389+
def lkj_corr_mean(op, rv, rng, size, *args):
390390
return pt.full_like(rv, pt.eye(rv.shape[-1]))
391391

392392

393393
@_mean.register(MatrixNormalRV)
394-
def matrix_normal_mean(op, rv, rng, size, dtype, mu, rowchol, colchol):
394+
def matrix_normal_mean(op, rv, rng, size, mu, rowchol, colchol):
395395
return pt.full_like(rv, mu)
396396

397397

398398
@_mean.register(KroneckerNormalRV)
399-
def kronecker_normal_mean(op, rv, rng, size, dtype, mu, covs, chols, evds):
399+
def kronecker_normal_mean(op, rv, rng, size, mu, covs, chols, evds):
400400
mean = mu
401401
if not rv_size_is_none(size):
402402
mean_size = pt.concatenate([size, mu.shape])
@@ -405,12 +405,12 @@ def kronecker_normal_mean(op, rv, rng, size, dtype, mu, covs, chols, evds):
405405

406406

407407
@_mean.register(CARRV)
408-
def car_mean(op, rv, rng, size, dtype, mu, W, alpha, tau):
408+
def car_mean(op, rv, rng, size, mu, W, alpha, tau):
409409
return pt.full_like(rv, mu)
410410

411411

412412
@_mean.register(StickBreakingWeightsRV)
413-
def stick_breaking_mean(op, rv, rng, size, dtype, alpha, K):
413+
def stick_breaking_mean(op, rv, rng, size, alpha, K):
414414
alpha = alpha[..., np.newaxis]
415415
mean = (alpha / (1 + alpha)) ** pt.arange(K)
416416
mean *= 1 / (1 + alpha)

0 commit comments

Comments
 (0)