Skip to content

Commit 3d78a70

Browse files
authored
Merge pull request #123 from JuliaSymbolics/s/array-impl
WIP: Supporting functions for array symbolics
2 parents 9bf4e62 + e982660 commit 3d78a70

File tree

6 files changed

+52
-17
lines changed

6 files changed

+52
-17
lines changed

src/abstractalgebra.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,16 @@ let
5959
@rule(zero(~x) => 0)
6060
@rule(one(~x) => 1)]
6161

62+
simterm(x, f, args;metadata=nothing) = similarterm(x,f,args, symtype(x); metadata=metadata)
6263
mpoly_rules = [@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
6364
@rule(-(~x) => -1 * ~x)
6465
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
6566
@rule(+(~x) => ~x)
6667
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
6768
@rule(*(~x) => ~x)
6869
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))]
69-
global const MPOLY_CLEANUP = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_preprocess))))
70-
MPOLY_MAKER = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_rules))))
70+
global const MPOLY_CLEANUP = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_preprocess)), similarterm=simterm))
71+
MPOLY_MAKER = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_rules)), similarterm=simterm))
7172

7273
global to_mpoly
7374
function to_mpoly(t, dicts=_dicts())

src/api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function substitute(expr, dict; fold=true)
6161
else
6262
args = map(x->substitute(x, dict), arguments(expr))
6363
end
64-
similarterm(expr, operation(expr), args, metadata=metadata(expr))
64+
similarterm(expr, operation(expr), args, symtype(expr), metadata=metadata(expr))
6565
else
6666
expr
6767
end

src/methods.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,8 @@ promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T
150150
# Specially handle inv and literal pow
151151
Base.inv(x::Symbolic{<:Number}) = Base.:^(x, -1)
152152
Base.literal_pow(::typeof(^), x::Symbolic{<:Number}, ::Val{p}) where {p} = Base.:^(x, p)
153+
154+
# Array-like operations
155+
Base.size(x::Symbolic{<:Number}) = ()
156+
Base.length(x::Symbolic{<:Number}) = 1
157+
Base.ndims(x::Symbolic{<:Number}) = 0

src/ordering.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,22 @@ end
4646

4747
<(a::Sym, b::Sym) = a.name < b.name
4848

49+
<(a::Function, b::Function) = nameof(a) <nameof(b)
50+
4951
function cmp_term_term(a, b)
5052
la = arglength(a)
5153
lb = arglength(b)
5254

5355
if la == 0 && lb == 0
54-
return nameof(operation(a)) <nameof(operation(b))
56+
return operation(a) <operation(b)
5557
elseif la === 0
5658
return operation(a) <ₑ b
5759
elseif lb === 0
5860
return a <operation(b)
5961
end
6062

61-
na = nameof(operation(a))
62-
nb = nameof(operation(b))
63+
na = operation(a)
64+
nb = operation(b)
6365

6466
if 0 < arglength(a) <= 2 && 0 < arglength(b) <= 2
6567
# e.g. a < sin(a) < b ^ 2 < b

src/rule.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ function makepattern(expr, keys)
6969
else
7070
:(term($(map(x->makepattern(x, keys), expr.args)...); type=Any))
7171
end
72+
elseif expr.head === :ref
73+
:(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any))
7274
elseif expr.head === :$
7375
return esc(expr.args[1])
7476
else
75-
error("Unsupported Expr of type $(expr.head) found in pattern")
77+
Expr(expr.head, makepattern.(expr.args, (keys,))...)
7678
end
7779
else
7880
# treat as a literal
@@ -327,7 +329,7 @@ function (acr::ACRule)(term)
327329
if !isnothing(result)
328330
# Assumption: inds are unique
329331
length(args) == length(inds) && return result
330-
return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i inds)...])
332+
return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i inds)...], symtype(term))
331333
end
332334
end
333335
end

src/types.jl

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ Base.show(io::IO, v::Sym) = Base.show_unquoted(io, v.name)
196196
# Maybe don't even need a new type, can just use Sym{FnType}
197197
struct FnType{X<:Tuple,Y} end
198198

199-
(f::Sym{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, [args...])
199+
(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, [args...])
200200

201-
function (f::Sym)(args...)
201+
function (f::Symbolic)(args...)
202202
error("Sym $f is not callable. " *
203203
"Use @syms $f(var1, var2,...) to create it as a callable. " *
204204
"See ?@fun for more options")
@@ -210,7 +210,7 @@ end
210210
The output symtype of applying variable `f` to arugments of symtype `arg_symtypes...`.
211211
if the arguments are of the wrong type then this function will error.
212212
"""
213-
function promote_symtype(f::Sym{FnType{X,Y}}, args...) where {X, Y}
213+
function promote_symtype(f::Symbolic{FnType{X,Y}}, args...) where {X, Y}
214214
if X === Tuple
215215
return Y
216216
end
@@ -250,8 +250,10 @@ macro syms(xs...)
250250
defs = map(xs) do x
251251
n, t = _name_type(x)
252252
:($(esc(n)) = Sym{$(esc(t))}($(Expr(:quote, n))))
253+
nt = _name_type(x)
254+
n, t = nt.name, nt.type
255+
:($(esc(n)) = Sym{$(esc(t))}($(Expr(:quote, n))))
253256
end
254-
255257
Expr(:block, defs...,
256258
:(tuple($(map(x->esc(_name_type(x).name), xs)...))))
257259
end
@@ -275,14 +277,20 @@ function _name_type(x)
275277
else
276278
return (name=lhs, type=rhs)
277279
end
280+
elseif x isa Expr && x.head === :ref
281+
ntype = _name_type(x.args[1]) # a::Number
282+
N = length(x.args)-1
283+
return (name=ntype.name,
284+
type=:(Array{$(ntype.type), $N}),
285+
array_metadata=:(Base.Slice.(($(x.args[2:end]...),))))
278286
elseif x isa Expr && x.head === :call
279287
return _name_type(:($x::Number))
280288
else
281289
syms_syntax_error()
282290
end
283291
end
284292

285-
function Base.show(io::IO, f::Sym{<:FnType{X,Y}}) where {X,Y}
293+
function Base.show(io::IO, f::Symbolic{<:FnType{X,Y}}) where {X,Y}
286294
print(io, f.name)
287295
# Use `Base.unwrap_unionall` to handle `Tuple{T} where T`. This is not the
288296
# best printing, but it's better than erroring.
@@ -433,7 +441,7 @@ setargs(t, args) = Term{symtype(t)}(operation(t), args)
433441
cdrargs(args) = setargs(t, cdr(args))
434442

435443
print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")")
436-
isbinop(f) = istree(f) && Base.isbinaryoperator(nameof(operation(f)))
444+
isbinop(f) = istree(f) && !istree(operation(f)) && Base.isbinaryoperator(nameof(operation(f)))
437445
function print_arg(io, x; paren=false)
438446
if paren && isbinop(x)
439447
print(io, "(", x, ")")
@@ -506,8 +514,23 @@ function show_mul(io, args)
506514
end
507515
end
508516

517+
function show_ref(io, f, args)
518+
x = args[1]
519+
idx = args[2:end]
520+
521+
istree(x) && print(io, "(")
522+
print(io, x)
523+
istree(x) && print(io, ")")
524+
print(io, "[")
525+
for i=1:length(idx)
526+
print_arg(io, idx[i])
527+
i != length(idx) && print(io, ", ")
528+
end
529+
print(io, "]")
530+
end
531+
509532
function show_call(io, f, args)
510-
fname = nameof(f)
533+
fname = istree(f) ? Symbol(repr(f)) : nameof(f)
511534
binary = Base.isbinaryoperator(fname)
512535
if binary
513536
for (i, t) in enumerate(args)
@@ -543,6 +566,8 @@ function show_term(io::IO, t)
543566
show_mul(io, args)
544567
elseif f === (^)
545568
show_pow(io, args)
569+
elseif f === (getindex)
570+
show_ref(io, f, args)
546571
else
547572
show_call(io, f, args)
548573
end
@@ -573,7 +598,7 @@ where `coeff` and the vals are `<:Number` and keys are symbolic.
573598
- `arguments(::Add)` -- returns a totally ordered vector of arguments. i.e.
574599
`[coeff, keyM*valM, keyN*valN...]`
575600
"""
576-
struct Add{X, T<:Number, D, M} <: Symbolic{X}
601+
struct Add{X<:Number, T<:Number, D, M} <: Symbolic{X}
577602
coeff::T
578603
dict::D
579604
sorted_args_cache::Ref{Any}
@@ -699,7 +724,7 @@ where `coeff` and the vals are `<:Number` and keys are symbolic.
699724
- `arguments(::Mul)` -- returns a totally ordered vector of arguments. i.e.
700725
`[coeff, keyM^valM, keyN^valN...]`
701726
"""
702-
struct Mul{X, T<:Number, D, M} <: Symbolic{X}
727+
struct Mul{X<:Number, T<:Number, D, M} <: Symbolic{X}
703728
coeff::T
704729
dict::D
705730
sorted_args_cache::Ref{Any}

0 commit comments

Comments
 (0)