Skip to content

Commit 2062afc

Browse files
committed
improved rewrite
1 parent bc64775 commit 2062afc

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

src/methods/rule_based/rule2.jl

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ function check_expr_r(data::SymsType, rule::Expr, matches::MatchDict)
8282
end
8383

8484
# check expression of all arguments
85+
# TODO add types
8586
function ceoaa(arg_data, arg_rule, matches::MatchDict)
86-
println(typeof(arg_data), typeof(arg_rule))
8787
for (a, b) in zip(arg_data, arg_rule)
8888
matches = check_expr_r(a, b, matches)
8989
matches===FAIL_DICT && return FAIL_DICT::MatchDict
@@ -105,39 +105,41 @@ function check_expr_r(data::SymsType, rule::Real, matches::MatchDict)
105105
end
106106

107107
"""
108-
matches is the dictionary
109-
rhs is the expression to be rewritten into
110-
111-
TODO investigate foo in rhs not working
108+
recursively traverse the rhs, and if it finds a expression like:
109+
Expr
110+
head: Symbol call
111+
args: Array{Any}((2,))
112+
1: Symbol ~
113+
2: Symbol m
114+
substitute it with the value found in matches dictionary.
112115
"""
113-
function rewrite(matches::MatchDict, rhs::Expr)::SymsType
114-
if rhs.head != :call
115-
error("It happened") #it should never happen
116-
end
117-
# rhs is a slot or defslot
116+
function rewrite(matches::MatchDict, rhs::Expr)::Union{Expr, SymsType}
117+
# println("called rewrite with rhs ", rhs)
118+
# if a expression of a slot, change it with the matches
118119
if rhs.head == :call && rhs.args[1] == :(~)
119120
var_name = rhs.args[2]
120121
if haskey(matches, var_name)
121-
return matches[var_name]
122+
return matches[var_name]::SymsType
122123
else
123124
error("No match found for variable $(var_name)") #it should never happen
124125
end
125126
end
126-
# rhs is a call, reconstruct it
127-
op = eval(rhs.args[1])
128-
args = SymsType[]
129-
for a in rhs.args[2:end]
130-
push!(args, rewrite(matches, a))
131-
end
132-
return op(args...)
127+
# otherwise call recursively on arguments and then reconstruct expression
128+
args = [rewrite(matches, a) for a in rhs.args]
129+
return Expr(rhs.head, args...)::Expr
133130
end
134131

135-
function rewrite(matches::MatchDict, rhs::Real)::SymsType
136-
return rhs
137-
end
132+
# called every time in the rhs::Expr there is a symbol like
133+
# - custom function names (contains_var, ...)
134+
# - normal functions names (+, ^, ...)
135+
# - nothing
136+
rewrite(matches::MatchDict, rhs::Symbol) = rhs::Symbol
137+
# called each time in the rhs there is a real (like +1 or -2)
138+
rewrite(matches::MatchDict, rhs::Real) = rhs::Real
138139

139140
function rule2(rule::Pair{Expr, Expr}, exp::SymsType)::Union{SymsType, Nothing}
140141
m = check_expr_r(exp, rule.first, NO_MATCHES)
141-
m===FAIL_DICT && return nothing
142-
return rewrite(m, rule.second)
143-
end
142+
m===FAIL_DICT && return nothing::Nothing
143+
r = rewrite(m, rule.second)
144+
return eval(r)::SymsType
145+
end

0 commit comments

Comments
 (0)