@@ -103,27 +103,27 @@ def maybe_resize(a: TensorVariable, size) -> TensorVariable:
103103
104104
105105@_mean .register (BernoulliRV )
106- def bernoulli_mean (op , rv , rng , size , dtype , p ):
106+ def bernoulli_mean (op , rv , rng , size , p ):
107107 return maybe_resize (p , size )
108108
109109
110110@_mean .register (BetaBinomialRV )
111- def betabinomial_mean (op , rv , rng , size , dtype , n , alpha , beta ):
111+ def betabinomial_mean (op , rv , rng , size , n , alpha , beta ):
112112 return maybe_resize ((n * alpha ) / (alpha + beta ), size )
113113
114114
115115@_mean .register (BetaClippedRV )
116- def beta_clipped_mean (op , rv , rng , size , dtype , alpha , beta ):
116+ def beta_clipped_mean (op , rv , rng , size , alpha , beta ):
117117 return maybe_resize (alpha / (alpha + beta ), size )
118118
119119
120120@_mean .register (BetaRV )
121- def beta_mean (op , rv , rng , size , dtype , alpha , beta ):
121+ def beta_mean (op , rv , rng , size , alpha , beta ):
122122 return maybe_resize (alpha / (alpha + beta ), size )
123123
124124
125125@_mean .register (BinomialRV )
126- def binomial_mean (op , rv , rng , size , dtype , n , p ):
126+ def binomial_mean (op , rv , rng , size , n , p ):
127127 return maybe_resize (n * p , size )
128128
129129
@@ -132,7 +132,7 @@ def binomial_mean(op, rv, rng, size, dtype, n, p):
132132
133133
134134@_mean .register (DirichletRV )
135- def dirichlet_mean (op , rv , rng , size , dtype , a ):
135+ def dirichlet_mean (op , rv , rng , size , a ):
136136 norm_constant = pt .sum (a , axis = - 1 )[..., None ]
137137 mean = a / norm_constant
138138 if not rv_size_is_none (size ):
@@ -141,60 +141,60 @@ def dirichlet_mean(op, rv, rng, size, dtype, a):
141141
142142
143143@_mean .register (ExponentialRV )
144- def exponential_mean (op , rv , rng , size , dtype , mu ):
144+ def exponential_mean (op , rv , rng , size , mu ):
145145 return maybe_resize (mu , size )
146146
147147
148148@_mean .register (GammaRV )
149- def gamma_mean (op , rv , rng , size , dtype , alpha , inv_beta ):
149+ def gamma_mean (op , rv , rng , size , alpha , inv_beta ):
150150 # The Aesara `GammaRV` `Op` inverts the `beta` parameter itself
151151 return maybe_resize (alpha * inv_beta , size )
152152
153153
154154@_mean .register (GeometricRV )
155- def geometric_mean (op , rv , rng , size , dtype , p ):
155+ def geometric_mean (op , rv , rng , size , p ):
156156 return maybe_resize (1.0 / p , size )
157157
158158
159159@_mean .register (GumbelRV )
160- def gumbel_mean (op , rv , rng , size , dtype , mu , beta ):
160+ def gumbel_mean (op , rv , rng , size , mu , beta ):
161161 return maybe_resize (mu + beta * np .euler_gamma , size )
162162
163163
164164@_mean .register (HalfNormalRV )
165- def halfnormal_mean (op , rv , rng , size , dtype , loc , sigma ):
165+ def halfnormal_mean (op , rv , rng , size , loc , sigma ):
166166 _ , sigma = pt .broadcast_arrays (loc , sigma )
167167 return maybe_resize (sigma * pt .sqrt (2 / np .pi ), size )
168168
169169
170170@_mean .register (HyperGeometricRV )
171- def hypergeometric_mean (op , rv , rng , size , dtype , good , bad , n ):
171+ def hypergeometric_mean (op , rv , rng , size , good , bad , n ):
172172 N , k = good + bad , good
173173 return maybe_resize (n * k / N , size )
174174
175175
176176@_mean .register (InvGammaRV )
177- def invgamma_mean (op , rv , rng , size , dtype , alpha , beta ):
177+ def invgamma_mean (op , rv , rng , size , alpha , beta ):
178178 return maybe_resize (pt .switch (alpha > 1 , beta / (alpha - 1.0 ), np .nan ), size )
179179
180180
181181@_mean .register (LaplaceRV )
182- def laplace_mean (op , rv , rng , size , dtype , mu , b ):
182+ def laplace_mean (op , rv , rng , size , mu , b ):
183183 return maybe_resize (pt .broadcast_arrays (mu , b )[0 ], size )
184184
185185
186186@_mean .register (LogisticRV )
187- def logistic_mean (op , rv , rng , size , dtype , mu , s ):
187+ def logistic_mean (op , rv , rng , size , mu , s ):
188188 return maybe_resize (pt .broadcast_arrays (mu , s )[0 ], size )
189189
190190
191191@_mean .register (LogNormalRV )
192- def lognormal_mean (op , rv , rng , size , dtype , mu , sigma ):
192+ def lognormal_mean (op , rv , rng , size , mu , sigma ):
193193 return maybe_resize (pt .exp (mu + 0.5 * sigma ** 2 ), size )
194194
195195
196196@_mean .register (MultinomialRV )
197- def multinomial_mean (op , rv , rng , size , dtype , n , p ):
197+ def multinomial_mean (op , rv , rng , size , n , p ):
198198 n = pt .shape_padright (n )
199199 mean = n * p
200200 if not rv_size_is_none (size ):
@@ -204,7 +204,7 @@ def multinomial_mean(op, rv, rng, size, dtype, n, p):
204204
205205
206206@_mean .register (MvNormalRV )
207- def mvnormal_mean (op , rv , rng , size , dtype , mu , cov ):
207+ def mvnormal_mean (op , rv , rng , size , mu , cov ):
208208 mean = mu
209209 if not rv_size_is_none (size ):
210210 mean_size = pt .concatenate ([size , [mu .shape [- 1 ]]])
@@ -213,70 +213,70 @@ def mvnormal_mean(op, rv, rng, size, dtype, mu, cov):
213213
214214
215215@_mean .register (NegBinomialRV )
216- def negative_binomial_mean (op , rv , rng , size , dtype , n , p ):
216+ def negative_binomial_mean (op , rv , rng , size , n , p ):
217217 return maybe_resize (n * (1 - p ) / p , size )
218218
219219
220220@_mean .register (NormalRV )
221- def normal_mean (op , rv , rng , size , dtype , mu , sigma ):
221+ def normal_mean (op , rv , rng , size , mu , sigma ):
222222 return maybe_resize (pt .broadcast_arrays (mu , sigma )[0 ], size )
223223
224224
225225@_mean .register (ParetoRV )
226- def pareto_mean (op , rv , rng , size , dtype , alpha , m ):
226+ def pareto_mean (op , rv , rng , size , alpha , m ):
227227 return maybe_resize (pt .switch (alpha > 1 , alpha * m / (alpha - 1 ), np .nan ), size )
228228
229229
230230@_mean .register (PoissonRV )
231- def poisson_mean (op , rv , rng , size , dtype , mu ):
231+ def poisson_mean (op , rv , rng , size , mu ):
232232 return maybe_resize (mu , size )
233233
234234
235235@_mean .register (TriangularRV )
236- def triangular_mean (op , rv , rng , size , dtype , lower , c , upper ):
236+ def triangular_mean (op , rv , rng , size , lower , c , upper ):
237237 return maybe_resize ((lower + upper + c ) / 3 , size )
238238
239239
240240@_mean .register (UniformRV )
241- def uniform_mean (op , rv , rng , size , dtype , lower , upper ):
241+ def uniform_mean (op , rv , rng , size , lower , upper ):
242242 return maybe_resize ((lower + upper ) / 2 , size )
243243
244244
245245@_mean .register (VonMisesRV )
246- def vonmisses_mean (op , rv , rng , size , dtype , mu , kappa ):
246+ def vonmisses_mean (op , rv , rng , size , mu , kappa ):
247247 return maybe_resize (pt .broadcast_arrays (mu , kappa )[0 ], size )
248248
249249
250250@_mean .register (KumaraswamyRV )
251- def kumaraswamy_mean (op , rv , rng , size , dtype , a , b ):
251+ def kumaraswamy_mean (op , rv , rng , size , a , b ):
252252 return maybe_resize (
253253 pt .exp (pt .log (b ) + pt .gammaln (1 + 1 / a ) + pt .gammaln (b ) - pt .gammaln (1 + 1 / a + b )),
254254 size ,
255255 )
256256
257257
258258@_mean .register (WaldRV )
259- def wald_mean (op , rv , rng , size , dtype , mu , lam , alpha ):
259+ def wald_mean (op , rv , rng , size , mu , lam , alpha ):
260260 return maybe_resize (pt .broadcast_arrays (mu , lam , alpha )[0 ], size )
261261
262262
263263@_mean .register (WeibullBetaRV )
264- def weibull_mean (op , rv , rng , size , dtype , alpha , beta ):
264+ def weibull_mean (op , rv , rng , size , alpha , beta ):
265265 return maybe_resize (beta * pt .gamma (1 + 1 / alpha ), size )
266266
267267
268268@_mean .register (AsymmetricLaplaceRV )
269- def asymmetric_laplace_mean (op , rv , rng , size , dtype , b , kappa , mu ):
269+ def asymmetric_laplace_mean (op , rv , rng , size , b , kappa , mu ):
270270 return maybe_resize (mu - (kappa - 1 / kappa ) / b , size )
271271
272272
273273@_mean .register (StudentTRV )
274- def studentt_mean (op , rv , rng , size , dtype , nu , mu , sigma ):
274+ def studentt_mean (op , rv , rng , size , nu , mu , sigma ):
275275 return maybe_resize (pt .broadcast_arrays (mu , nu , sigma )[0 ], size )
276276
277277
278278@_mean .register (HalfStudentTRV )
279- def half_studentt_mean (op , rv , rng , size , dtype , nu , sigma ):
279+ def half_studentt_mean (op , rv , rng , size , nu , sigma ):
280280 return maybe_resize (
281281 pt .switch (
282282 nu > 1 ,
@@ -291,18 +291,18 @@ def half_studentt_mean(op, rv, rng, size, dtype, nu, sigma):
291291
292292
293293@_mean .register (ExGaussianRV )
294- def exgaussian_mean (op , rv , rng , size , dtype , mu , nu , sigma ):
294+ def exgaussian_mean (op , rv , rng , size , mu , nu , sigma ):
295295 mu , nu , _ = pt .broadcast_arrays (mu , nu , sigma )
296296 return maybe_resize (mu + nu , size )
297297
298298
299299@_mean .register (SkewNormalRV )
300- def skew_normal_mean (op , rv , rng , size , dtype , mu , sigma , alpha ):
300+ def skew_normal_mean (op , rv , rng , size , mu , sigma , alpha ):
301301 return maybe_resize (mu + sigma * (2 / np .pi ) ** 0.5 * alpha / (1 + alpha ** 2 ) ** 0.5 , size )
302302
303303
304304@_mean .register (RiceRV )
305- def rice_mean (op , rv , rng , size , dtype , nu , sigma ):
305+ def rice_mean (op , rv , rng , size , nu , sigma ):
306306 nu_sigma_ratio = - (nu ** 2 ) / (2 * sigma ** 2 )
307307 return maybe_resize (
308308 sigma
@@ -317,22 +317,22 @@ def rice_mean(op, rv, rng, size, dtype, nu, sigma):
317317
318318
319319@_mean .register (MoyalRV )
320- def moyal_mean (op , rv , rng , size , dtype , mu , sigma ):
320+ def moyal_mean (op , rv , rng , size , mu , sigma ):
321321 return maybe_resize (mu + sigma * (np .euler_gamma + pt .log (2 )), size )
322322
323323
324324@_mean .register (PolyaGammaRV )
325- def polya_gamma_mean (op , rv , rng , size , dtype , h , z ):
325+ def polya_gamma_mean (op , rv , rng , size , h , z ):
326326 return maybe_resize (pt .switch (pt .eq (z , 0 ), h / 4 , tanh (z / 2 ) * (h / (2 * z ))), size )
327327
328328
329329@_mean .register (DiscreteUniformRV )
330- def discrete_uniform_mean (op , rv , rng , size , dtype , lower , upper ):
330+ def discrete_uniform_mean (op , rv , rng , size , lower , upper ):
331331 return maybe_resize ((upper + lower ) / 2.0 , size )
332332
333333
334334@_mean .register (DiracDeltaRV )
335- def dirac_delta_mean (op , rv , rng , size , dtype , c ):
335+ def dirac_delta_mean (op , rv , rng , size , c ):
336336 return maybe_resize (c , size )
337337
338338
@@ -355,12 +355,12 @@ def mixture_mean(op, rv, rng, weights, *components):
355355
356356
357357@_mean .register (GeneralizedPoissonRV )
358- def generalized_poisson_mean (op , rv , rng , size , dtype , mu , lam ):
358+ def generalized_poisson_mean (op , rv , rng , size , mu , lam ):
359359 return maybe_resize (mu / (1 - lam ), size )
360360
361361
362362@_mean .register (MvStudentTRV )
363- def mvstudentt_mean (op , rv , rng , size , dtype , nu , mu , scale ):
363+ def mvstudentt_mean (op , rv , rng , size , nu , mu , scale ):
364364 mean = mu
365365 if not rv_size_is_none (size ):
366366 mean_size = pt .concatenate ([size , [mu .shape [- 1 ]]])
@@ -369,7 +369,7 @@ def mvstudentt_mean(op, rv, rng, size, dtype, nu, mu, scale):
369369
370370
371371@_mean .register (DirichletMultinomialRV )
372- def dirichlet_multinomial_mean (op , rv , rng , size , dtype , n , a ):
372+ def dirichlet_multinomial_mean (op , rv , rng , size , n , a ):
373373 mean = pt .shape_padright (n ) * a / pt .sum (a , axis = - 1 , keepdims = True )
374374 if not rv_size_is_none (size ):
375375 output_size = pt .concatenate ([size , [a .shape [- 1 ]]])
@@ -386,17 +386,17 @@ def lkj_cholesky_cov_mean(op, rv, rng, n, eta, sd_dist):
386386
387387
388388@_mean .register (LKJCorrRV )
389- def lkj_corr_mean (op , rv , rng , size , dtype , * args ):
389+ def lkj_corr_mean (op , rv , rng , size , * args ):
390390 return pt .full_like (rv , pt .eye (rv .shape [- 1 ]))
391391
392392
393393@_mean .register (MatrixNormalRV )
394- def matrix_normal_mean (op , rv , rng , size , dtype , mu , rowchol , colchol ):
394+ def matrix_normal_mean (op , rv , rng , size , mu , rowchol , colchol ):
395395 return pt .full_like (rv , mu )
396396
397397
398398@_mean .register (KroneckerNormalRV )
399- def kronecker_normal_mean (op , rv , rng , size , dtype , mu , covs , chols , evds ):
399+ def kronecker_normal_mean (op , rv , rng , size , mu , covs , chols , evds ):
400400 mean = mu
401401 if not rv_size_is_none (size ):
402402 mean_size = pt .concatenate ([size , mu .shape ])
@@ -405,12 +405,12 @@ def kronecker_normal_mean(op, rv, rng, size, dtype, mu, covs, chols, evds):
405405
406406
407407@_mean .register (CARRV )
408- def car_mean (op , rv , rng , size , dtype , mu , W , alpha , tau ):
408+ def car_mean (op , rv , rng , size , mu , W , alpha , tau ):
409409 return pt .full_like (rv , mu )
410410
411411
412412@_mean .register (StickBreakingWeightsRV )
413- def stick_breaking_mean (op , rv , rng , size , dtype , alpha , K ):
413+ def stick_breaking_mean (op , rv , rng , size , alpha , K ):
414414 alpha = alpha [..., np .newaxis ]
415415 mean = (alpha / (1 + alpha )) ** pt .arange (K )
416416 mean *= 1 / (1 + alpha )
0 commit comments