Skip to content

Commit 6254348

Browse files
committed
bugfixes
1 parent 2221715 commit 6254348

File tree

8 files changed

+32
-35
lines changed

8 files changed

+32
-35
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.1.0"
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
88
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10+
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
1011
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1112
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/combinators/for.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ depending on `CartesianIndices(base)`.
2121
2222
```
2323
julia> For(3) do λ Exponential(λ) end |> marginals
24-
3-element mappedarray(MeasureTheory.var"#17#18"{var"#15#16"}(var"#15#16"()), ::CartesianIndices{1, Tuple{Base.OneTo{Int64}}}) with eltype Exponential{(:λ,), Tuple{Int64}}:
24+
3-element mappedarray(MeasureBase.var"#17#18"{var"#15#16"}(var"#15#16"()), ::CartesianIndices{1, Tuple{Base.OneTo{Int64}}}) with eltype Exponential{(:λ,), Tuple{Int64}}:
2525
Exponential(λ = 1,)
2626
Exponential(λ = 2,)
2727
Exponential(λ = 3,)
2828
```
2929
3030
```
3131
julia> For(4,3) do μ,σ Normal(μ,σ) end |> marginals
32-
4×3 mappedarray(MeasureTheory.var"#17#18"{var"#11#12"}(var"#11#12"()), ::CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) with eltype Normal{(:μ, :σ), Tuple{Int64, Int64}}:
32+
4×3 mappedarray(MeasureBase.var"#17#18"{var"#11#12"}(var"#11#12"()), ::CartesianIndices{2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}) with eltype Normal{(:μ, :σ), Tuple{Int64, Int64}}:
3333
Normal(μ = 1, σ = 1) Normal(μ = 1, σ = 2) Normal(μ = 1, σ = 3)
3434
Normal(μ = 2, σ = 1) Normal(μ = 2, σ = 2) Normal(μ = 2, σ = 3)
3535
Normal(μ = 3, σ = 1) Normal(μ = 3, σ = 2) Normal(μ = 3, σ = 3)

src/combinators/likelihood.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ Finally, let's return to the expression for Bayes's Law,
8484
``P(θ|x) ∝ P(θ) P(x|θ)``
8585
8686
The product on the right side is computed pointwise. To work with this in
87-
MeasureTheory, we have a "pointwise product" `⊙`, which takes a measure and a
87+
MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a
8888
likelihood, and returns a new measure, that is, the unnormalized posterior that has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior.
8989
9090
For example, say we have

src/combinators/product.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,35 +135,35 @@ end
135135

136136

137137

138-
# @propagate_inbounds function MeasureTheory.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,1}}
138+
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,1}}
139139
# data = marginals(d)
140140
# @boundscheck size(data) == size(x) || throw(BoundsError)
141141
# @tullio s = logdensity(data[i], x[i])
142142
# s
143143
# end
144144

145-
# @propagate_inbounds function MeasureTheory.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,2}}
145+
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,2}}
146146
# data = marginals(d)
147147
# @boundscheck size(data) == size(x) || throw(BoundsError)
148148
# @tullio s = @inbounds logdensity(data[i,j], x[i,j])
149149
# s
150150
# end
151151

152-
# @propagate_inbounds function MeasureTheory.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,3}}
152+
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,3}}
153153
# data = marginals(d)
154154
# @boundscheck size(data) == size(x) || throw(BoundsError)
155155
# @tullio s = @inbounds logdensity(data[i,j,k], x[i,j,k])
156156
# s
157157
# end
158158

159-
# @propagate_inbounds function MeasureTheory.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,4}}
159+
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,4}}
160160
# data = marginals(d)
161161
# @boundscheck size(data) == size(x) || throw(BoundsError)
162162
# @tullio s = @inbounds logdensity(data[i,j,k,l], x[i,j,k,l])
163163
# s
164164
# end
165165

166-
# @propagate_inbounds function MeasureTheory.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,5}}
166+
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,5}}
167167
# data = marginals(d)
168168
# @boundscheck size(data) == size(x) || throw(BoundsError)
169169
# @tullio s = @inbounds logdensity(data[i,j,k,l,m], x[i,j,k,l,m])

src/combinators/weighted.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export WeightedMeasure
1+
export WeightedMeasure, AbstractWeightedMeasure
22

33
"""
44
struct WeightedMeasure{R,M} <: AbstractMeasure

src/domains.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
export IntegerRange
2+
13
abstract type AbstractDomain end
24

35
"""

src/macros.jl

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11

22
using MLStyle
33
using Random: AbstractRNG
4-
export @parameterized
4+
using KeywordCalls
5+
using ConstructionBase
6+
export @parameterized, @μσ_methods, @σ_methods, @half
57

68
# A fold over ASTs. Example usage in `replace`
79
function foldast(leaf, branch; kwargs...)
@@ -49,12 +51,12 @@ function _parameterized(__module__, expr)
4951

5052
# @gensym basename
5153
q = quote
52-
struct $μ{N,T} <: MeasureTheory.ParameterizedMeasure{N}
54+
struct $μ{N,T} <: MeasureBase.ParameterizedMeasure{N}
5355
par :: NamedTuple{N,T}
5456
end
5557

5658
const $μbase = $base
57-
MeasureTheory.basemeasure(::$μ) = $μbase
59+
MeasureBase.basemeasure(::$μ) = $μbase
5860
end
5961

6062
if !isempty(p)
@@ -70,7 +72,7 @@ function _parameterized(__module__, expr)
7072
μ = esc(μ)
7173

7274
q = quote
73-
struct $μ{N,T} <: MeasureTheory.ParameterizedMeasure{N}
75+
struct $μ{N,T} <: MeasureBase.ParameterizedMeasure{N}
7476
par :: NamedTuple{N,T}
7577
end
7678
end
@@ -88,7 +90,7 @@ function _parameterized(__module__, expr)
8890
μ = esc(μ)
8991

9092
q = quote
91-
struct $μ{N,T} <: MeasureTheory.ParameterizedMeasure{N}
93+
struct $μ{N,T} <: MeasureBase.ParameterizedMeasure{N}
9294
par :: NamedTuple{N,T}
9395
end
9496
end
@@ -141,15 +143,16 @@ end
141143
function _μσ_methods(__module__, ex)
142144
@match ex begin
143145
:($dist($(args...))) => begin
146+
144147
argnames = QuoteNode.(args)
145148

146149
d_args = (:(d.$arg) for arg in args)
147150

148151
method_μσ = KeywordCalls._kwstruct(__module__, :($dist($(args...), μ, σ)))
149152
method_μ = KeywordCalls._kwstruct(__module__, :($dist($(args...), μ)))
150153
method_σ = KeywordCalls._kwstruct(__module__, :($dist($(args...), σ)))
151-
152154
C = constructorof(getproperty(__module__, dist))
155+
dist = esc(dist)
153156
q = quote
154157

155158
$method_μσ
@@ -160,7 +163,7 @@ function _μσ_methods(__module__, ex)
160163
d.σ * rand(rng, T, $dist($(d_args...))) + d.μ
161164
end
162165

163-
function MeasureTheory.logdensity(d::$dist{($(argnames...), :μ, :σ)}, x)
166+
function MeasureBase.logdensity(d::$dist{($(argnames...), :μ, :σ)}, x)
164167
z = (x - d.μ) / d.σ
165168
return logdensity($dist($(d_args...)), z) - log(d.σ)
166169
end
@@ -169,7 +172,7 @@ function _μσ_methods(__module__, ex)
169172
d.σ * rand(rng, T, $dist($(d_args...)))
170173
end
171174

172-
function MeasureTheory.logdensity(d::$dist{($(argnames...), :σ)}, x)
175+
function MeasureBase.logdensity(d::$dist{($(argnames...), :σ)}, x)
173176
z = x / d.σ
174177
return logdensity($dist($(d_args...)), z) - log(d.σ)
175178
end
@@ -178,15 +181,10 @@ function _μσ_methods(__module__, ex)
178181
rand(rng, T, $dist($(d_args...))) + d.μ
179182
end
180183

181-
function MeasureTheory.logdensity(d::$dist{($(argnames...), :μ)}, x)
184+
function MeasureBase.logdensity(d::$dist{($(argnames...), :μ)}, x)
182185
z = x - d.μ
183186
return logdensity($dist($(d_args...)), z)
184187
end
185-
186-
187-
MeasureTheory.asparams(::Type{<: $C}, ::Val{:μ}) = asℝ
188-
MeasureTheory.asparams(::Type{<: $C}, ::Val{:σ}) = asℝ₊
189-
MeasureTheory.asparams(::Type{<: $C}, ::Val{:logσ}) = asℝ
190188
end
191189

192190
return q
@@ -208,14 +206,15 @@ function _σ_methods(__module__, ex)
208206

209207
method_σ = KeywordCalls._kwstruct(__module__, :($dist($(args...), σ)))
210208

209+
dist = esc(dist)
211210
q = quote
212211
$method_σ
213212

214213
function Base.rand(rng::AbstractRNG, T::Type, d::$dist{($(argnames...), :σ)})
215214
d.σ * rand(rng, T, $dist($(d_args...)))
216215
end
217216

218-
function MeasureTheory.logdensity(d::$dist{($(argnames...), :σ)}, x)
217+
function MeasureBase.logdensity(d::$dist{($(argnames...), :σ)}, x)
219218
z = x / d.σ
220219
return logdensity($dist($(d_args...)), z) - log(d.σ)
221220
end
@@ -241,34 +240,29 @@ creates `HalfNormal()`, and
241240
creates `HalfStudentT(ν)`.
242241
"""
243242
macro half(ex)
244-
esc(_half(__module__, ex))
243+
_half(__module__, ex)
245244
end
246245

247246
function _half(__module__, ex)
248247
@match ex begin
249248
:($dist($(args...))) => begin
250-
halfdist = Symbol(:Half, dist)
251-
252-
TV = TransformVariables
253-
249+
halfdist = esc(Symbol(:Half, dist))
250+
254251
quote
255-
256-
export $halfdist
257-
258252
struct $halfdist{N,T} <: ParameterizedMeasure{N}
259253
par :: NamedTuple{N,T}
260254
end
261255

262256
unhalf::$halfdist) = $dist(getfield(μ, :par))
263257

264-
function MeasureTheory.basemeasure::$halfdist)
258+
function MeasureBase.basemeasure::$halfdist)
265259
b = basemeasure(unhalf(μ))
266260
@assert basemeasure(b) == Lebesgue(ℝ)
267261
lw = b.logweight
268262
return WeightedMeasure(logtwo + lw, Lebesgue(ℝ₊))
269263
end
270264

271-
function MeasureTheory.logdensity::$halfdist, x)
265+
function MeasureBase.logdensity::$halfdist, x)
272266
return logdensity(unhalf(μ), x)
273267
end
274268

src/resettablerng.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Random
2-
2+
export ResettableRNG
33
struct ResettableRNG{R,S} <: Random.AbstractRNG
44
rng::R
55
seed::S

0 commit comments

Comments
 (0)