Skip to content

Commit 71a6f4a

Browse files
committed
fix: deepcopy grammar constraint domains
1 parent cac9500 commit 71a6f4a

File tree

4 files changed

+94
-29
lines changed

4 files changed

+94
-29
lines changed

LocalPreferences.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[GraphDynamicalSystems]
2+
dispatch_doctor_mode = "debug"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GraphDynamicalSystems"
22
uuid = "13529e2e-ed53-56b1-bd6f-420b01fca819"
33
authors = ["Reuben Gardos Reid <[email protected]>"]
4-
version = "0.0.1"
4+
version = "0.0.2"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/qualitative_networks.jl

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import SciMLBase
33

44
using AbstractTrees: Leaves
55
using DynamicalSystemsBase: ArbitrarySteppable, current_parameters, initial_state
6-
using HerbConstraints: DomainRuleNode, Forbidden, Ordered, VarNode, addconstraint!
6+
using HerbConstraints: DomainRuleNode, Forbidden, Ordered, Unique, VarNode, addconstraint!
77
using HerbCore: AbstractGrammar, RuleNode, get_rule
88
using HerbGrammar: @csgrammar, add_rule!, rulenode2expr
99
using HerbSearch: rand
@@ -22,24 +22,34 @@ const base_qn_grammar = @csgrammar begin
2222
Val = Floor(Val)
2323
end
2424

25-
const default_qn_constants = [2]
25+
const default_qn_constants = [0, 1, 2]
2626

2727
"""
2828
$(TYPEDSIGNATURES)
2929
3030
Builds a grammar based on the base QN grammar adding `entity_names` and `constants`
3131
to the grammar.
3232
33-
Four constraints are currently included
33+
The following constraints are currently included
3434
3535
1. removing symmetry due to commutativity of `+`/`*`/`min`/`max`
3636
2. forbidding same arguments of two argument functions
37-
3. forbidding trivial inputs (consts and entity values) to `floor`/`ceil`
38-
4. forbidding `ceil(floor(_))` and `floor(ceil(_))`
37+
3. forbidding constant arguments to 2-argument functions
38+
4. forbidding constant arguments to 1-argument functions
39+
5. using each of the entities only once per function
40+
6. forbidding adding or subtracting zero
41+
7. forbidding multiplication and division by 1 or 0
42+
8. forcing the first operator inside `ceil` and `floor` to be `÷`
43+
9. forbidding `max(□, X)` and `min(□, X)` where X is either the max or min
44+
constant in the grammar.
3945
4046
"""
41-
function build_qn_grammar(entity_names, constants = default_qn_constants)
42-
g = deepcopy(GraphDynamicalSystems.base_qn_grammar)
47+
function build_qn_grammar(
48+
entity_names,
49+
constants = default_qn_constants;
50+
unique_constr = true,
51+
)
52+
g = deepcopy(base_qn_grammar)
4353

4454
for e in entity_names
4555
add_rule!(g, :(Val = $e))
@@ -57,40 +67,98 @@ function build_qn_grammar(entity_names, constants = default_qn_constants)
5767
template_tree = DomainRuleNode(domain, [VarNode(:a), VarNode(:b)])
5868
order = [:a, :b]
5969

60-
addconstraint!(g, Ordered(template_tree, order))
70+
addconstraint!(g, Ordered(deepcopy(template_tree), order))
6171

6272
# Forbid same arguments for 2-argument functions
6373
domain = BitVector(zeros(length(g.rules)))
6474
@. domain[length(g.childtypes)==2] = true
6575
template_tree = DomainRuleNode(domain, [VarNode(:a), VarNode(:a)])
6676

67-
addconstraint!(g, Forbidden(template_tree))
77+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
6878

69-
# Forbid Ceil and Floor from including an entity or constant directly
70-
domain = BitVector(zeros(length(g.rules)))
71-
n_original_rules = length(GraphDynamicalSystems.base_qn_grammar.rules)
72-
domain[[n_original_rules+1:length(g.rules)...]] .= true
79+
# Forbid constant arguments for 2-argument functions
80+
domain = falses(length(g.rules))
81+
@. domain[length(g.childtypes)==2] = true
82+
consts_domain = falses(length(g.rules))
83+
consts_domain[findall(x -> x isa Int, g.rules)] .= true
84+
consts_domain_rn = DomainRuleNode(consts_domain)
85+
template_tree = DomainRuleNode(domain, [consts_domain_rn, consts_domain_rn])
7386

74-
entities_consts = DomainRuleNode(domain)
87+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
7588

76-
domain = BitVector(zeros(length(g.rules)))
77-
domain[[7, 8]] .= true
89+
# Forbid constant arguments for 1-argument functions
90+
domain = falses(length(g.rules))
91+
@. domain[[7, 8]] = true
92+
consts_domain = falses(length(g.rules))
93+
consts_domain[findall(x -> x isa Int, g.rules)] .= true
94+
consts_domain_rn = DomainRuleNode(consts_domain)
95+
template_tree = DomainRuleNode(domain, [consts_domain_rn])
96+
97+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
98+
99+
n_original_rules = length(base_qn_grammar.rules)
100+
101+
# Only use each of the entities once per function
102+
n_consts = length(constants)
103+
entities = n_original_rules+1:length(g.rules)-n_consts
104+
105+
if unique_constr
106+
addconstraint!.((g,), Unique.(entities))
107+
end
108+
109+
# Forbid □ + 0, □ - 0
110+
plus_or_minus = falses(length(g.rules))
111+
plus_or_minus[[1, 2]] .= true
112+
zero_rule = findfirst(==(0), g.rules)
113+
if !isnothing(zero_rule)
114+
template_tree = DomainRuleNode(plus_or_minus, [VarNode(:a), RuleNode(zero_rule)])
78115

79-
template_tree = DomainRuleNode(domain, [entities_consts])
116+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
80117

81-
addconstraint!(g, Forbidden(template_tree))
118+
# Both orderings, but only for plus. Allow 0 - □
119+
plus_or_minus[2] = false
120+
template_tree = DomainRuleNode(plus_or_minus, [RuleNode(zero_rule), VarNode(:a)])
121+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
122+
end
123+
124+
# Forbid □ * 1, □ / 1, □ * 0, □ / 0
125+
mult_or_div = falses(length(g.rules))
126+
mult_or_div[[3, 4]] .= true
127+
one_zero_domain = falses(length(g.rules))
128+
one_zero_domain[findfirst(==(1), g.rules)] = true
129+
if !isnothing(findfirst(==(0), g.rules))
130+
one_zero_domain[findfirst(==(0), g.rules)] = true
131+
end
132+
133+
template_tree =
134+
DomainRuleNode(mult_or_div, [VarNode(:a), DomainRuleNode(one_zero_domain)])
82135

83-
# Forbid ceil(floor(x)) and vice-versa
136+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
137+
138+
# Forbid ceil(X) and floor(X) unless X = □ ÷ □
84139
ceil_or_floor = BitVector(zeros(length(g.rules)))
85140
ceil_or_floor[[7, 8]] .= true
86-
template_tree =
87-
DomainRuleNode(ceil_or_floor, [DomainRuleNode(ceil_or_floor, [VarNode(:a)])])
141+
all_except_div = trues(length(g.rules))
142+
all_except_div[3] = false
143+
template_tree = DomainRuleNode(ceil_or_floor, [DomainRuleNode(all_except_div)])
144+
145+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
88146

89-
addconstraint!(g, Forbidden(template_tree))
147+
# Forbid max(□, X) and min(□, X) where X is either the largest or smallest constant in the grammar
148+
min_max_rules = falses(length(g.rules))
149+
min_max_rules[[5, 6]] .= true
150+
(min_const, max_const) = extrema(filter(x -> isa(x, Int), g.rules))
151+
extrema_domain = falses(length(g.rules))
152+
extrema_domain[findall(x -> x == min_const || x == max_const, g.rules)] .= true
153+
rule_extrema_consts = DomainRuleNode(extrema_domain)
154+
template_tree = DomainRuleNode(min_max_rules, [VarNode(:a), rule_extrema_consts])
155+
156+
addconstraint!(g, Forbidden(deepcopy(template_tree)))
90157

91158
return g
92159
end
93160

161+
94162
struct Entity{I}
95163
target_function::Any
96164
# _f::Any

test/qn_test.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,11 @@ end
1515

1616
@testitem "QN Grammar Creation" begin
1717
entities = [:a, :b, :c]
18-
constants = [i for i = 1:10]
18+
constants = [i for i = 0:10]
1919
g = build_qn_grammar(entities, constants)
2020

2121
@test issubset(Set(entities), Set(g.rules))
2222
@test issubset(Set(constants), Set(g.rules))
23-
24-
g2 = build_qn_grammar(Symbol[], Integer[])
25-
26-
@test isempty(intersect(Set(g2.rules), Set(entities)))
27-
@test isempty(intersect(Set(g2.rules), Set(constants)))
2823
end
2924

3025
@testitem "QN Sampling" setup = [RandomSetup, ExampleQN] begin

0 commit comments

Comments
 (0)