Skip to content

Commit beb10a4

Browse files
skoghoernbvdmitri
andauthored
GammaShapeScale - changed the dispatch version of default_parametrization for Gamma (#526)
* Updated default_parametrization function for Gamma and GammaShapeScale without kwargs * Update initialization_plugin_tests.jl - added tests for gamma pos args * Update src/model/graphppl.jl --------- Co-authored-by: Bagaev Dmitry <bvdmitri@gmail.com>
1 parent 95283b8 commit beb10a4

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

src/model/graphppl.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,10 @@ GraphPPL.interface_aliases(::ReactiveMPGraphPPLBackend, ::Type{MvNormal}) = Grap
262262

263263
GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :θ)}) = ExponentialFamily.GammaShapeScale
264264
GraphPPL.factor_alias(::ReactiveMPGraphPPLBackend, ::Type{Gamma}, ::GraphPPL.StaticInterfaces{(:α, :β)}) = ExponentialFamily.GammaShapeRate
265-
GraphPPL.default_parametrization(::ReactiveMPGraphPPLBackend, ::GraphPPL.Atomic, ::Type{Gamma}, rhs) = error(
266-
"`Gamma` cannot be constructed without keyword arguments. Use `Gamma(shape = ..., rate = ...)` or `Gamma(shape = ..., scale = ...)`."
267-
)
265+
GraphPPL.default_parametrization(backend::ReactiveMPGraphPPLBackend, nodetype::GraphPPL.Atomic, factor::Type{Gamma}, rhs) = begin
266+
@warn "'Gamma' and 'GammaShapeScale' without keywords are constructed with parameters (Shape, Scale)." maxlog=1
267+
return GraphPPL.default_parametrization(backend, nodetype, ReactiveMP.inputinterfaces(factor), factor, rhs)
268+
end
268269

269270
GraphPPL.interface_aliases(::ReactiveMPGraphPPLBackend, ::Type{Gamma}) = GraphPPL.StaticInterfaceAliases((
270271
(:a, ), (:shape, ), (:β⁻¹, ), (:scale, ), (:θ⁻¹, ), (:rate, )

test/model/initialization_plugin_tests.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,3 +716,33 @@ end
716716
@test occursin("q(x)", repr(init))
717717
@test occursin("NormalMeanVariance", repr(init))
718718
end
719+
720+
@testitem "gamma pos args warn and construct shape-scale" begin
721+
using RxInfer
722+
import GraphPPL: create_model, with_plugins
723+
724+
@model function gamma_only()
725+
s ~ Gamma(shape = 2.0, rate = 1.0)
726+
end
727+
728+
# Gamma without kwargs
729+
init_gamma = @initialization begin
730+
q(s) = Gamma(1, 1)
731+
end
732+
model = create_model(with_plugins(gamma_only(), GraphPPL.PluginsCollection(RxInfer.InitializationPlugin(init_gamma))))
733+
node = GraphPPL.getcontext(model)[:s]
734+
@test GraphPPL.hasextra(model[node], RxInfer.InitMarExtraKey)
735+
@test GraphPPL.getextra(model[node], RxInfer.InitMarExtraKey) == Distributions.Gamma(1.0, 1.0)
736+
@test occursin("Gamma{Float64}(α=1.0, θ=1.0)", repr(GraphPPL.getextra(model[node], RxInfer.InitMarExtraKey)))
737+
738+
# GammaShapeScale without kwargs
739+
init_gss = @initialization begin
740+
q(s) = GammaShapeScale(1, 1)
741+
end
742+
743+
model = create_model(with_plugins(gamma_only(), GraphPPL.PluginsCollection(RxInfer.InitializationPlugin(init_gss))))
744+
node = GraphPPL.getcontext(model)[:s]
745+
@test GraphPPL.hasextra(model[node], RxInfer.InitMarExtraKey)
746+
@test GraphPPL.getextra(model[node], RxInfer.InitMarExtraKey) == Distributions.Gamma(1.0, 1.0)
747+
@test occursin("Gamma{Float64}(α=1.0, θ=1.0)", repr(GraphPPL.getextra(model[node], RxInfer.InitMarExtraKey)))
748+
end

0 commit comments

Comments
 (0)