Skip to content

Commit 2ee1d7a

Browse files
authored
Merge pull request #510 from ztangent/20230627-ztangent-fix_load_generated_functions
Re-implement get_schema using type params.
2 parents 3287746 + 6604c09 commit 2ee1d7a

File tree

2 files changed

+26
-20
lines changed

2 files changed

+26
-20
lines changed

src/static_ir/trace.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ struct StaticIRTraceAssmt{T} <: ChoiceMap
66
trace::T
77
end
88

9-
function get_schema end
10-
119
@inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T)
1210

1311
@inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace)
@@ -35,7 +33,11 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap()
3533
# trace type generation #
3634
#########################
3735

38-
abstract type StaticIRTrace <: Trace end
36+
abstract type StaticIRTrace{T} <: Trace end
37+
38+
function get_schema(::Type{<:StaticIRTrace{T}}) where {T}
39+
StaticAddressSchema(Set{Symbol}(T))
40+
end
3941

4042
@inline function static_get_subtrace(trace::StaticIRTrace, addr)
4143
error("Not implemented")
@@ -124,7 +126,11 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options:
124126
mutable = false
125127
fields = get_trace_fields(ir, options)
126128
field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields)
127-
Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)),
129+
choice_addrs = [node.addr for node in ir.choice_nodes]
130+
call_addrs = [node.addr for node in ir.call_nodes]
131+
addrs = Tuple(vcat(choice_addrs, call_addrs))
132+
parent_type = Expr(:curly, QuoteNode(StaticIRTrace), addrs)
133+
Expr(:struct, mutable, Expr(:(<:), trace_struct_name, parent_type),
128134
Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any))))
129135
end
130136

@@ -271,17 +277,6 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol)
271277
methods
272278
end
273279

274-
function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol)
275-
choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes]
276-
call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes]
277-
addrs = vcat(choice_addrs, call_addrs)
278-
Expr(:function,
279-
Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})),
280-
Expr(:block,
281-
:($(QuoteNode(StaticAddressSchema))(
282-
Set{Symbol}([$(addrs...)])))))
283-
end
284-
285280
function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions)
286281
trace_struct_name = gensym("StaticIRTrace_$name")
287282
trace_struct_expr = generate_trace_struct(ir, trace_struct_name, options)
@@ -290,7 +285,6 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St
290285
get_args_expr = generate_get_args(ir, trace_struct_name)
291286
get_retval_expr = generate_get_retval(ir, trace_struct_name)
292287
get_choices_expr = generate_get_choices(trace_struct_name)
293-
get_schema_expr = generate_get_schema(ir, trace_struct_name)
294288
get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name)
295289
get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name)
296290
static_get_value_exprs = generate_static_get_value(ir, trace_struct_name)
@@ -299,10 +293,10 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St
299293
getindex_exprs = generate_getindex(ir, trace_struct_name)
300294

301295
exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr,
302-
get_args_expr, get_retval_expr,
303-
get_choices_expr, get_schema_expr, get_values_shallow_expr,
304-
get_submaps_shallow_expr, static_get_value_exprs...,
305-
static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...)
296+
get_args_expr, get_retval_expr, get_choices_expr,
297+
get_values_shallow_expr, get_submaps_shallow_expr,
298+
static_get_value_exprs..., static_has_value_exprs...,
299+
static_get_submap_exprs..., getindex_exprs...)
306300
(exprs, trace_struct_name)
307301
end
308302

test/dsl/static_dsl.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,18 @@ ch = get_choices(tr)
609609
@test length(get_values_shallow(ch)) == 1
610610
@test length(get_submaps_shallow(ch)) == 1
611611

612+
@gen (static) function baz(trace)
613+
x ~ normal(trace[:x], 0.1)
614+
return x
615+
end
616+
617+
ch, w, rval = propose(baz, (tr,))
618+
@test has_value(ch, :x)
619+
@test ch[:x] == rval
620+
621+
new_tr, _ = generate(bar1, (), ch)
622+
@test new_tr[:x] == ch[:x]
623+
612624
end
613625

614626
@testset "returning a SML function from macro" begin

0 commit comments

Comments
 (0)