Skip to content

Commit 2de604d

Browse files
committed
Restructure synthesis experiments
1 parent c47c333 commit 2de604d

File tree

13 files changed

+240
-136
lines changed

13 files changed

+240
-136
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
DynamicalSystems = "61744808-ddfa-5f27-97ff-6e42cc95d634"
1010
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
11+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
12+
HerbConstraints = "1fa96474-3206-4513-b4fa-23913f296dfc"
1113
HerbCore = "2b23ba43-8213-43cb-b5ea-38c12b45bd45"
1214
HerbGrammar = "4ef9e186-2fe5-4b24-8de7-9f7291f24af7"
1315
HerbSearch = "3008d8e8-f9aa-438a-92ed-26e9c7b4829f"
@@ -21,6 +23,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2123
DocStringExtensions = "0.9.3"
2224
DynamicalSystems = "3"
2325
FileIO = "1"
26+
HerbConstraints = "0.2.4"
2427
HerbSearch = "0.4.1"
2528
MLStyle = "0.4.17"
2629
MetaGraphsNext = "0.7"

experiments/Synth/Manifest.toml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
julia_version = "1.10.2"
44
manifest_format = "2.0"
5-
project_hash = "19cda9b05e0584d6a0dc95f7351f7b83848951bf"
5+
project_hash = "db74a993eef408d9b910e97f9a72ec5fee351591"
66

77
[[deps.ADTypes]]
88
git-tree-sha1 = "72af59f5b8f09faee36b4ec48e014a79210f2f4f"
@@ -936,10 +936,8 @@ uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
936936
version = "2.82.4+0"
937937

938938
[[deps.GraphDynamicalSystems]]
939-
deps = ["AbstractTrees", "DocStringExtensions", "DynamicalSystems", "FileIO", "HerbCore", "HerbGrammar", "HerbSearch", "MLStyle", "MetaGraphsNext", "Random", "SoleLogics", "Statistics"]
940-
git-tree-sha1 = "8ca1dbca624e42094220e1d15da56da999e82b1b"
941-
repo-rev = "feat/qn"
942-
repo-url = "../.."
939+
deps = ["AbstractTrees", "DocStringExtensions", "DynamicalSystems", "FileIO", "Graphs", "HerbConstraints", "HerbCore", "HerbGrammar", "HerbSearch", "MLStyle", "MetaGraphsNext", "Random", "SoleLogics", "Statistics"]
940+
path = "../.."
943941
uuid = "13529e2e-ed53-56b1-bd6f-420b01fca819"
944942
version = "0.2.0"
945943

@@ -1004,9 +1002,7 @@ version = "0.1.6"
10041002

10051003
[[deps.HerbSearch]]
10061004
deps = ["DataStructures", "HerbConstraints", "HerbCore", "HerbGrammar", "HerbInterpret", "HerbSpecification", "MLStyle", "Random", "StatsBase"]
1007-
git-tree-sha1 = "913c9d8549076d4e13dd5255656177b6893ce6c1"
1008-
repo-rev = "compathelper/new_version/2024-12-13-01-52-34-511-01072321669"
1009-
repo-url = "https://github.com/Herb-AI/HerbSearch.jl.git"
1005+
git-tree-sha1 = "95a5c1e87cd61b14cf9785f293e5633b39a69fc5"
10101006
uuid = "3008d8e8-f9aa-438a-92ed-26e9c7b4829f"
10111007
version = "0.4.1"
10121008

experiments/Synth/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
77
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
88
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
99
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
10+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1011
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
1112
DynamicalSystems = "61744808-ddfa-5f27-97ff-6e42cc95d634"
1213
Git = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2"
@@ -39,6 +40,7 @@ SoleLogics = "b002da8f-3cb3-4d91-bbe3-2953433912b5"
3940
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
4041
Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
4142
TidierData = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
43+
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
4244
XML = "72c71f33-b9b6-44de-8c94-c961784809e2"
4345

4446
[compat]

experiments/Synth/scripts/synth_biodivine.jl

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,73 @@ end
2121
@everywhere quickactivate(pwd())
2222
@everywhere using Synth
2323

24-
@everywhere using ProgressMeter, DataFrames, HerbSearch, GraphDynamicalSystems
24+
@everywhere using ProgressMeter, DataFrames, HerbSearch, GraphDynamicalSystems, Random
25+
using MetaGraphsNext: labels
2526

26-
res = collect_results(datadir("sims", "biodivine_split"))
27-
res.ID = ((x -> x[end-1]["id"]) parse_savename).(res.path)
28-
rename!(res, :path => "Trajectory Path")
29-
mg_df = collect_results(datadir("src_parsed", "biodivine_benchmark_as_metagraphs");)
30-
mg_df.ID = ((x -> parse(Int, x)) (x -> x[1]) splitext basename).(mg_df.path)
31-
rename!(mg_df, :path => "Model Path")
27+
traj_df = collect_results(datadir("sims", "biodivine_split"))
28+
path2id = path -> parse_savename(path)[end-1]["id"]
29+
traj_df.ID = path2id.(traj_df.path)
3230

33-
res = innerjoin(res, mg_df, on = :ID)
31+
model_df = collect_results(datadir("src_parsed", "biodivine_benchmark_as_metagraphs");)
32+
path2id = path -> parse(Int, splitext(basename(path))[1])
33+
model_df.ID = path2id.(model_df.path)
34+
model_df.vertex = collect.(labels.(model_df.metagraph_model))
35+
# add a copy so that after flattening we have all of the vertices of a model in each row of df
36+
model_df.vertices = model_df.vertex
37+
38+
# Filter only smaller models
39+
# model_df = model_df[length.(model_df.vertices).<15, :]
40+
41+
per_vertex_df = flatten(model_df, :vertex)
42+
43+
grammars_df = model_df[!, [:ID, :vertices]]
44+
grammars_df.dnf_grammar = build_dnf_grammar.(grammars_df.vertices)
45+
grammars_df.qn_grammar = build_qn_grammar.(grammars_df.vertices)
46+
47+
get_evaluator = g -> Dict([:DNF => evaluate_bn, :QN => evaluate_qn])[g]
48+
49+
function get_grammar(unique, grammar_type)
50+
s = :unknown
51+
if grammar_type == :DNF
52+
s = :dnf_grammar
53+
elseif grammar_type == :QN
54+
s = :qn_grammar
55+
end
56+
57+
return only(grammars_df[grammars_df.ID.==unique.ID, s])
58+
end
59+
60+
function select_trajectories(df, N, id, vertex, seed)
61+
Random.seed!(seed)
62+
selected_trajectories = rand(only(df[df.ID.==id, :split_traj]), N)
63+
filtered_on_vertex =
64+
reduce(union, map(x -> get(x, vertex, Set()), selected_trajectories))
65+
66+
return filtered_on_vertex
67+
end
3468

3569
synth_params = Dict(
3670
"seed" => 42,
3771
"max_depth" => 6,
38-
"id" => res.ID,
72+
"unique" => collect(eachrow(per_vertex_df[!, [:ID, :vertex, :vertices]])),
73+
"id" => Derived("unique", x -> x.ID),
74+
"vertex_names" => Derived("unique", x -> getfield.(x.vertices, :value)),
75+
"index_of_vertex" => Derived("unique", x -> findfirst(==(x.vertex), x.vertices)),
76+
"vertex" => Derived("unique", x -> string(x.vertex.value)),
3977
"n_trajectories" => collect(10:45:110),
78+
"selected_trajectories" => Derived(
79+
["n_trajectories", "unique", "index_of_vertex", "seed"],
80+
(N, unique, index_of_vertex, seed) ->
81+
select_trajectories(traj_df, N, unique.ID, index_of_vertex, seed),
82+
),
4083
"iterator_type" => [BFSIterator],
41-
"grammar_builder" => [build_dnf_grammar, build_qn_grammar],
84+
"iter_name" => Derived("iterator_type", string),
85+
"grammar_type" => [:DNF, :QN],
86+
"grammar" => Derived(["unique", "grammar_type"], get_grammar),
87+
"evaluator" => Derived("grammar_type", get_evaluator),
4288
"max_iterations" => 1_000_000,
4389
)
4490

4591
@showprogress pmap(dict_list(synth_params)) do params
46-
synth_one_biodivine(params, res)
92+
@produce_or_load(synth_one_vertex, params, datadir("exp_raw", "biodivine_search"))
4793
end

experiments/Synth/src/Synth.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,20 @@ include("gather_bn_data.jl")
2727

2828
export gather_bn_data, split_state_space, get_split_state_space
2929

30-
include("create_problem.jl")
30+
include("synth_process.jl")
3131

32-
export examples_to_problem
32+
export synth, synth_biodivine
33+
34+
include("undirected_specification.jl")
35+
36+
export UndirectedExample, UndirectedProblem
3337

3438
include("evaluator.jl")
3539

36-
export evaluate_bn, interpret
40+
export evaluate_bn, evaluate_qn, interpret
3741

38-
include("synth_process.jl")
42+
include("create_problem.jl")
3943

40-
export synth, synth_biodivine
44+
export examples_to_problem
4145

4246
end

experiments/Synth/src/biodivine_benchmark.jl

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -146,49 +146,17 @@ function convert_aeon_models_to_metagraphs(excluded_files = Regex[])
146146
end
147147
end
148148

149-
function synth_one_biodivine(outer_params::AbstractDict{String,Any}, res::DataFrame)
150-
params = deepcopy(outer_params)
151-
@unpack seed = params
152-
Random.seed!(seed)
153-
154-
@unpack n_trajectories, id = params
155-
selected_trajs = rand(only(res[res.ID.==id, :split_traj]), n_trajectories)
156-
157-
merged_selected_trajs = reduce(mergewith(union), selected_trajs)
158-
159-
@unpack id, grammar_builder = params
160-
@info "Synthsizing for model $id with $n_trajectories traj."
161-
162-
model = only(res[res.ID.==id, :metagraph_model])
163-
grammar = grammar_builder(nv(model))
164-
165-
@showprogress map(collect(merged_selected_trajs)) do (vertex, examples)
166-
@info "Synthesizing model $id, node $vertex, $n_trajectories traj."
167-
save_data = deepcopy(params)
168-
delete!(save_data, "specifications")
169-
save_data["grammar"] = grammar
170-
save_data["vertex"] = vertex
171-
save_data["examples"] = examples
172-
file_name = savename(save_data)
173-
@produce_or_load(
174-
synth_one_vertex,
175-
save_data,
176-
datadir("exp_raw", "biodivine_search");
177-
filename = file_name
178-
)
179-
@info "Completed synthesis for model $id, node $vertex, $n_trajectories traj."
180-
end
181-
end
182-
183-
function synth_one_vertex(save_data)
184-
@unpack vertex, examples = save_data
185-
problem = examples_to_problem(vertex, examples)
149+
function synth_one_vertex(params::AbstractDict{String,Any})
150+
@unpack vertex, selected_trajectories = params
151+
problem = examples_to_problem(vertex, selected_trajectories)
186152

187-
@unpack max_depth, iterator_type, max_iterations, grammar = save_data
153+
@unpack max_depth, iterator_type, max_iterations, grammar, evaluator, vertex_names =
154+
params
188155
iterator = iterator_type(grammar, :Start, max_depth = max_depth)
189-
exprs_and_scores = synth_biodivine(problem, iterator, grammar, max_iterations)
156+
exprs_and_scores =
157+
synth_biodivine(problem, iterator, grammar, max_iterations, evaluator, vertex_names)
190158

191159
# Save output
192-
save_data["exprs_and_scores"] = exprs_and_scores
193-
return save_data
160+
params["exprs_and_scores"] = exprs_and_scores
161+
return params
194162
end
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
function examples_to_problem(node, examples)
2-
io_examples =
3-
map(((in, out),) -> IOExample(Dict([:state => in]), out), collect(examples))
4-
problem = Problem("$node", io_examples)
1+
function _e2p(n, e)
2+
_pair_to_undirected(p) = UndirectedExample(Dict(:state => p[1]), Dict(:state => p[2]))
53

6-
return problem
4+
return UndirectedProblem(n, _pair_to_undirected.(e))
75
end
6+
7+
examples_to_problem(node::Integer, examples::AbstractSet) = _e2p(string(node), examples)
8+
examples_to_problem(node::Atom, examples) = _e2p(node.value, examples)
9+
examples_to_problem(node::AbstractString, examples) = _e2p(node, examples)

experiments/Synth/src/evaluator.jl

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,73 @@ function interpret(φ::Expr, i::SoleLogics.AbstractInterpretation, args...; kwar
2929
return interpret(syntax_branch, i, args...; kwargs...)
3030
end
3131

32-
function evaluate_bn(problem, expr)
33-
sat_examples = 0
32+
function interpret(
33+
e::Union{AbstractString,Integer,Expr},
34+
qn_state::AbstractVector{<:Integer},
35+
vertex_names::AbstractVector{<:AbstractString},
36+
)
37+
state_map = Dict(zip(vertex_names, deepcopy(qn_state)))
3438

35-
for example problem.spec
36-
truth = TruthDict(Dict(enumerate(example.in[:state])))
39+
_int(e) = @match e begin
40+
::AbstractString => state_map[e]
41+
::Integer => e
42+
:($v1 + $v2) => _int(v1) + _int(v2)
43+
:($v1 - $v2) => _int(v1) - _int(v2)
44+
:($v1 / $v2) => _int(v1) / _int(v2)
45+
:($v1 * $v2) => _int(v1) * _int(v2)
46+
:(Min($v1, $v2)) => min(_int(v1), _int(v2))
47+
:(Max($v1, $v2)) => max(_int(v1), _int(v2))
48+
:(Ceil($v)) => ceil(_int(v))
49+
:(Floor($v)) => floor(_int(v))
50+
_ => error("Unhandled Expr in `interpret`: $e, $(typeof(e))")
51+
end
52+
53+
return _int(e)
54+
end
55+
56+
function evaluate_bn(problem::UndirectedProblem, expr, vertex_names)
57+
sat_examples = BitVector[]
58+
59+
function _eval_1_dir(in, out)
60+
truth = TruthDict(Dict(zip(vertex_names, in[:state])))
3761
res = interpret(expr, truth)
38-
sat_examples += res.flag == example.out
62+
expected = BooleanTruth(out[:state][findfirst(==(problem.name), vertex_names)])
63+
success = expected == res
64+
65+
return success
66+
end
67+
68+
for example problem.examples
69+
success_direction1 = _eval_1_dir(example.data1, example.data2)
70+
success_direction2 = _eval_1_dir(example.data2, example.data1)
71+
72+
success = BitVector([success_direction1, success_direction2])
73+
74+
push!(sat_examples, success)
75+
end
76+
77+
return sat_examples
78+
end
79+
80+
function evaluate_qn(problem::UndirectedProblem, expr, vertex_names)
81+
sat_examples = BitVector[]
82+
83+
function _eval_1_dir(in, out)
84+
res = interpret(expr, in[:state], vertex_names)
85+
expected = out[:state][findfirst(==(problem.name), vertex_names)]
86+
success = expected == res
87+
88+
return success
89+
end
90+
91+
for example problem.examples
92+
success_direction1 = _eval_1_dir(example.data1, example.data2)
93+
success_direction2 = _eval_1_dir(example.data2, example.data1)
94+
95+
success = BitVector([success_direction1, success_direction2])
96+
97+
push!(sat_examples, success)
3998
end
4099

41-
return sat_examples / length(problem.spec)
100+
return sat_examples
42101
end

experiments/Synth/src/gather_bn_data.jl

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,24 @@ end
77

88
function split_state_space(trajectory::StateSpaceSet)
99
# split into pairs of input (all values) and output (changed value)
10-
input_output_pairs_per_node = Dict{Int,Set{Tuple{Vector{Int},Int}}}()
10+
#
11+
input_output_pairs_per_node = Dict{Int,Set{Tuple{Vector{Int},Vector{Int}}}}()
1112
for i = 1:length(trajectory)-1
1213
changed = findfirst(trajectory[i+1] .!= trajectory[i])
1314
# only proceed if there was a change
1415
if !isnothing(changed)
1516

1617
# in real data we don't know the direction of the transition
1718
# was it from i -> i+1 or i+1 -> i, we only know that two
18-
# states are adjacent, so for gathering data, we add both
19-
# directions as IO pairs
20-
# 1. state `i` and the new value of the single variable in
21-
# the state that changed
22-
# 2. state `i+1` and the previous value of the single variable
23-
# in the state that changed
19+
# states are adjacent, so for synthesis, we want to add the pair
20+
# of i and i+1, and test possible programs on both.
2421

25-
new_value = trajectory[i+1][changed]
26-
old_value = trajectory[i][changed]
27-
existing_pairs =
28-
get(input_output_pairs_per_node, changed, Set{Tuple{Vector{Int},Int}}())
29-
push!(
30-
existing_pairs,
31-
(trajectory[i], new_value), # 1
32-
(trajectory[i+1], old_value), # 2
22+
existing_pairs = get(
23+
input_output_pairs_per_node,
24+
changed,
25+
Set{Tuple{Vector{Int},Vector{Int}}}(),
3326
)
27+
push!(existing_pairs, (trajectory[i], trajectory[i+1]))
3428
input_output_pairs_per_node[changed] = existing_pairs
3529
end
3630
end

0 commit comments

Comments
 (0)