|
54 | 54 | DistSpec(Poisson, (0.5,), 1), |
55 | 55 | DistSpec(Poisson, (0.5,), [1, 1]), |
56 | 56 |
|
57 | | - DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)), |
58 | | - DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)), |
| 57 | + DistSpec(Skellam, (1.0, 2.0), -2), |
| 58 | + DistSpec(Skellam, (1.0, 2.0), [-2, -2]), |
59 | 59 |
|
60 | 60 | DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), |
61 | 61 |
|
|
162 | 162 |
|
163 | 163 | DistSpec(NormalCanon, (1.0, 2.0), 0.5), |
164 | 164 |
|
165 | | - DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)), |
| 165 | + DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5), |
166 | 166 |
|
| 167 | + DistSpec(Pareto, (), 1.5), |
167 | 168 | DistSpec(Pareto, (1.0,), 1.5), |
168 | 169 | DistSpec(Pareto, (1.0, 1.0), 1.5), |
169 | 170 |
|
|
213 | 214 | # Stackoverflow caused by SpecialFunctions.besselix |
214 | 215 | DistSpec(VonMises, (1.0,), 1.0), |
215 | 216 | DistSpec(VonMises, (1, 1), 1), |
216 | | - |
217 | | - # Only some Zygote tests are broken and therefore this can not be checked |
218 | | - DistSpec(Pareto, (), 1.5; broken=(:Zygote,)), |
219 | | - |
| 217 | + |
220 | 218 | # Some tests are broken on some Julia versions, therefore it can't be checked reliably |
221 | 219 | DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)), |
222 | 220 | ] |
|
395 | 393 | # Skellam only fails in these tests with ReverseDiff |
396 | 394 | # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 |
397 | 395 | # PoissonBinomial fails with Zygote |
| 396 | + # Matrix case does not work with Skellam: |
| 397 | + # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 |
398 | 398 | filldist_broken = if d.f(d.θ...) isa Skellam |
399 | | - (d.broken..., :ReverseDiff) |
| 399 | + ((d.broken..., :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff)) |
400 | 400 | elseif d.f(d.θ...) isa PoissonBinomial |
401 | | - (d.broken..., :Zygote) |
| 401 | + ((d.broken..., :Zygote), (d.broken..., :Zygote)) |
402 | 402 | else |
403 | | - d.broken |
| 403 | + (d.broken, d.broken) |
404 | 404 | end |
405 | | - arraydist_broken = if d.f(d.θ...) isa PoissonBinomial |
406 | | - (d.broken..., :Zygote) |
| 405 | + arraydist_broken = if d.f(d.θ...) isa Skellam |
| 406 | + (d.broken, (d.broken..., :Zygote)) |
| 407 | + elseif d.f(d.θ...) isa PoissonBinomial |
| 408 | + ((d.broken..., :Zygote), (d.broken..., :Zygote)) |
407 | 409 | else |
408 | | - d.broken |
| 410 | + (d.broken, d.broken) |
409 | 411 | end |
410 | 412 |
|
411 | 413 | # Create `filldist` distribution |
|
416 | 418 | f_arraydist = (θ...,) -> arraydist([d.f(θ...) for _ in 1:n]) |
417 | 419 | d_arraydist = f_arraydist(d.θ...) |
418 | 420 |
|
419 | | - for sz in ((n,), (n, 2)) |
| 421 | + for (i, sz) in enumerate(((n,), (n, 2))) |
420 | 422 | # Matrix case doesn't work for continuous distributions for some reason |
421 | 423 | # now but not too important (?!) |
422 | 424 | if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous |
|
434 | 436 | d.θ, |
435 | 437 | x, |
436 | 438 | d.xtrans; |
437 | | - broken=filldist_broken, |
| 439 | + broken=filldist_broken[i], |
438 | 440 | ) |
439 | 441 | ) |
440 | 442 | test_ad( |
|
444 | 446 | d.θ, |
445 | 447 | x, |
446 | 448 | d.xtrans; |
447 | | - broken=arraydist_broken, |
| 449 | + broken=arraydist_broken[i], |
448 | 450 | ) |
449 | 451 | ) |
450 | 452 | end |
|
0 commit comments