Skip to content

Commit ebec467

Browse files
committed
minor
1 parent 87ced72 commit ebec467

File tree

3 files changed

+16
-19
lines changed

3 files changed

+16
-19
lines changed

src/probprog/Modeling.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,9 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
243243

244244
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
245245

246+
trace.rng = rng
246247
trace.fn = f
247248
trace.args = args
248-
trace.rng = rng
249249

250250
return trace, trace.weight
251251
end
@@ -355,7 +355,7 @@ function generate(
355355

356356
constraint_ptr = ConcreteRNumber(reinterpret(UInt64, pointer_from_objref(constraint)))
357357

358-
constrained_addresses = _extract_addresses(constraint)
358+
constrained_addresses = extract_addresses(constraint)
359359

360360
function wrapper_fn(rng, constraint_ptr, args...)
361361
return generate_internal(rng, f, args...; constraint_ptr, constrained_addresses)
@@ -372,19 +372,13 @@ function generate(
372372

373373
trace = unsafe_pointer_to_objref(Ptr{Any}(Array(trace)[1]))
374374

375+
trace.rng = rng
375376
trace.fn = f
376377
trace.args = args
377-
trace.rng = rng
378378

379379
return trace, trace.weight
380380
end
381381

382-
function _extract_addresses(constraint::Constraint)
383-
addresses = Set(keys(constraint))
384-
385-
return addresses
386-
end
387-
388382
function generate_internal(
389383
rng::AbstractRNG,
390384
f::Function,

src/probprog/Types.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
using Base: ReentrantLock
22

33
mutable struct ProbProgTrace
4-
fn::Union{Nothing,Function}
5-
args::Union{Nothing,Tuple}
64
choices::Dict{Symbol,Any}
75
retval::Any
86
weight::Any
97
subtraces::Dict{Symbol,Any}
108
rng::Union{Nothing,AbstractRNG}
11-
12-
function ProbProgTrace(fn::Function, args::Tuple)
13-
return new(
14-
fn, args, Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing
15-
)
16-
end
9+
fn::Union{Nothing,Function}
10+
args::Union{Nothing,Tuple}
1711

1812
function ProbProgTrace()
1913
return new(
20-
nothing, (), Dict{Symbol,Any}(), nothing, nothing, Dict{Symbol,Any}(), nothing
14+
Dict{Symbol,Any}(),
15+
nothing,
16+
nothing,
17+
Dict{Symbol,Any}(),
18+
nothing,
19+
nothing,
20+
nothing,
2121
)
2222
end
2323
end
@@ -27,6 +27,7 @@ struct Address
2727

2828
Address(path::Vector{Symbol}) = new(path)
2929
end
30+
3031
Address(sym::Symbol) = Address([sym])
3132
Address(syms::Symbol...) = Address([syms...])
3233

@@ -66,6 +67,8 @@ Base.isempty(c::Constraint) = isempty(c.dict)
6667
Base.haskey(c::Constraint, k::Address) = haskey(c.dict, k)
6768
Base.get(c::Constraint, k::Address, default) = get(c.dict, k, default)
6869

70+
extract_addresses(constraint::Constraint) = Set(keys(constraint))
71+
6972
const Selection = Set{Symbol}
7073
const CompiledFnCache = Dict{Tuple{Type,Set{Symbol}},Any}
7174

test/probprog/generate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ end
105105

106106
constraint1 = ProbProg.Constraint(:s => (fill(0.1, shape),))
107107

108-
constrained_addresses = ProbProg._extract_addresses(constraint1)
108+
constrained_addresses = ProbProg.extract_addresses(constraint1)
109109

110110
constraint_ptr1 = Reactant.ConcreteRNumber(
111111
reinterpret(UInt64, pointer_from_objref(constraint1))

0 commit comments

Comments
 (0)