Skip to content

Commit 01315ec

Browse files
authored
Merge pull request #55 from Herb-AI/fix/enforce-constraint-domain
Fix/enforce constraint domain
2 parents f739b2e + 904db89 commit 01315ec

File tree

5 files changed

+49
-2
lines changed

5 files changed

+49
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HerbCore"
22
uuid = "2b23ba43-8213-43cb-b5ea-38c12b45bd45"
33
authors = ["Jaap de Jong <jaapdejong15@gmail.com>", "Nicolae Filat <N.Filat@student.tudelft.nl>", "Tilman Hinnerichs <t.r.hinnerichs@tudelft.nl>", "Sebastijan Dumancic <s.dumancic@tudelft.nl>"]
4-
version = "0.3.10"
4+
version = "0.3.11"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

src/HerbCore.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ export
3535
have_same_shape, AbstractConstraint,
3636
AbstractGrammar,
3737
print_tree,
38-
update_rule_indices!
38+
update_rule_indices!,
39+
is_domain_valid
3940

4041
end # module HerbCore

src/indexing.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,15 @@
44
Updates the rule indices of the given rule node, hole or grammar constraint when the grammar changes.
55
"""
66
function update_rule_indices! end
7+
8+
"""
9+
is_domain_valid(x, n_rules::Integer)
10+
is_domain_valid(x, grammar::AbstractGrammar)
11+
12+
Check if the domain for the given object `x` (ex: [`RuleNode`](@ref),
13+
[`Hole`](@ref) or [`AbstractConstraint`](@ref)) is valid given the provided
14+
grammar or number of rules.
15+
16+
If [`isfilled`](@ref)`(x)` and `x` has children, it checks if all children are valid.
17+
"""
18+
function is_domain_valid end

src/rulenode.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ function update_rule_indices!(
9090
end
9191
end
9292

93+
"""
94+
is_domain_valid(node::RuleNode, n_rules::Integer)
95+
96+
Check whether the `node`'s rule index exceeds the number of rules `n_rules.`
97+
"""
98+
function is_domain_valid(node::RuleNode, n_rules::Integer)
99+
if get_rule(node) > n_rules
100+
return false
101+
end
102+
all(child -> is_domain_valid(child, n_rules), get_children(node))
103+
end
104+
93105
"""
94106
AbstractHole <: AbstractRuleNode
95107
@@ -120,6 +132,18 @@ end
120132

121133
UniformHole(domain) = UniformHole(domain, AbstractRuleNode[])
122134

135+
"""
136+
is_domain_valid(hole::AbstractHole, n_rules::Integer)
137+
138+
Check if `hole`'s domain length matches `n_rules`.
139+
"""
140+
function is_domain_valid(hole::AbstractHole, n_rules::Integer)
141+
if length(hole.domain) != n_rules
142+
return false
143+
end
144+
all(child -> is_domain_valid(child, n_rules), get_children(hole))
145+
end
146+
123147
"""
124148
update_rule_indices!(node::AbstractHole, n_rules::Integer)
125149

test/test_rulenode.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,4 +503,14 @@
503503
end
504504
end
505505
end
506+
@testset "is_domain_valid" begin
507+
node = @rulenode 1{2, 3, 4{5{7}, 6{9{10}}}}
508+
n_rules = 10
509+
@test is_domain_valid(node, n_rules) == true
510+
@test is_domain_valid(node, 9) == false
511+
512+
hole = UniformHole(BitVector((1, 1, 0, 0)), [RuleNode(3), RuleNode(4)])
513+
@test is_domain_valid(hole, 9) == false
514+
@test is_domain_valid(hole, 4) == true
515+
end
506516
end

0 commit comments

Comments
 (0)