@@ -67,6 +67,25 @@ Base.isequal(::Symbolic, ::Symbolic) = false
6767
6868# ## End of interface
6969
70+ # ## Metatheory.jl e-graph rewriting integration
71+
72+ """
73+ SymtypeAnalysis
74+
75+ This abstract type is used to identify the EGraph analysis
76+ that keeps track of symtype through an EGraph. This must
77+ be added to every EGraph that is used in SymbolicUtils.
78+ """
79+ abstract type SymtypeAnalysis <: AbstractAnalysis end
80+ _getsymtype (T:: Type{<:Symbolic{X}} ) where X = X
81+ _getsymtype (T:: Type{X} ) where {X} = X
82+ EGraphs. make (an:: Type{SymtypeAnalysis} , g:: EGraph , n:: ENodeLiteral ) = symtype (n. value)
83+ EGraphs. make (an:: Type{SymtypeAnalysis} , g:: EGraph , n:: ENodeTerm{T} ) where {T} = _getsymtype (T)
84+ EGraphs. join (an:: Type{SymtypeAnalysis} , A, B) = Union{A, B}
85+
86+ # TODO JOIN egraph analysis
87+ TermInterface. symtype (ec:: EClass ) = getdata (ec, SymtypeAnalysis, Any)
88+
7089function to_symbolic (x)
7190 Base. depwarn (" `to_symbolic(x)` is deprecated, define the interface for your " *
7291 " symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " *
@@ -348,6 +367,24 @@ function term(f, args...; type = nothing)
348367 Term {T} (f, [args... ])
349368end
350369
370+ """
371+ unflatten(t::Symbolic{T})
372+ Binarizes `Term`s with n-ary operations
373+ """
374+ function unflatten (t:: Symbolic{T} ) where {T}
375+ if istree (t)
376+ f = operation (t)
377+ if f == (+ ) || f == (* ) # TODO check out for other n-ary --> binary ops
378+ a = arguments (t)
379+ return foldl ((x,y) -> Term {T} (f, [x, y]), a)
380+ end
381+ end
382+ return t
383+ end
384+
385+ unflatten (t) = t
386+
387+
351388"""
352389 similarterm(t, f, args, symtype; metadata=nothing)
353390
@@ -366,10 +403,17 @@ different type than `t`, because `f` also influences the result.
366403"""
367404TermInterface. similarterm (t:: Type{<:Symbolic} , f, args; metadata= nothing , exprhead= :call ) =
368405 similarterm (t, f, args, _promote_symtype (f, args); metadata= metadata, exprhead= exprhead)
369-
406+
407+ TermInterface. similarterm (t:: Type{<:Symbolic} , f:: Symbol , args; metadata= nothing , exprhead= :call ) =
408+ TermInterface. similarterm (t, eval (f), args; metadata= metadata, exprhead= exprhead)
409+
370410TermInterface. similarterm (t:: Type{<:Term} , f, args, symtype; metadata= nothing , exprhead= :call ) =
371411 Term {_promote_symtype(f, args)} (f, args; metadata= metadata)
372412
413+ TermInterface. similarterm (t:: Type{<:Term} , f:: Symbol , args, symtype; metadata= nothing , exprhead= :call ) =
414+ Term {_promote_symtype(eval(f), args)} (eval (f), args; metadata= metadata)
415+
416+
373417# --------------------
374418# --------------------
375419# ### Pretty printing
@@ -549,6 +593,11 @@ showraw(t) = showraw(stdout, t)
549593sdict (kv... ) = Dict {Any, Number} (kv... )
550594
551595const SN = Symbolic{<: Number }
596+ # TODO Reviewme this is necessary for Metatheory.jl egraph rewriting
597+ # integration. Constructors of `Add, Mul, Pow...` from Base (+, *, ^, ...)
598+ # Should now accepts EClasses as arguments.
599+ const SN_EC = Union{SN, EClass}
600+
552601"""
553602 Add(T, coeff, dict::Dict)
554603
583632
584633TermInterface. symtype (a:: Add{X} ) where {X} = X
585634
586-
587635TermInterface. istree (a:: Type{Add} ) = true
588636
589637TermInterface. operation (a:: Add ) = +
@@ -603,6 +651,17 @@ Base.isequal(a::Add, b::Add) = a.coeff == b.coeff && isequal(a.dict, b.dict)
603651
604652Base. show (io:: IO , a:: Add ) = show_term (io, a)
605653
654+ function toterm (t:: Add{T} ) where T
655+ args = []
656+ for (k, coeff) in t. dict
657+ push! (args, coeff == 1 ? k : Term {T} (* , [coeff, k]))
658+ end
659+ Term {T} (+ , args)
660+ end
661+
662+ toterm (t) = t
663+
664+
606665"""
607666 makeadd(sign, coeff::Number, xs...)
608667
@@ -641,7 +700,7 @@ add_t(a,b) = promote_symtype(+, symtype(a), symtype(b))
641700sub_t (a,b) = promote_symtype (- , symtype (a), symtype (b))
642701sub_t (a) = promote_symtype (- , symtype (a))
643702
644- function + (a:: SN , b:: SN )
703+ function + (a:: SN_EC , b:: SN_EC )
645704 if a isa Add
646705 coeff, dict = makeadd (1 , 0 , b)
647706 T = promote_symtype (+ , symtype (a), symtype (b))
@@ -652,11 +711,11 @@ function +(a::SN, b::SN)
652711 Add (add_t (a,b), makeadd (1 , 0 , a, b)... )
653712end
654713
655- + (a:: Number , b:: SN ) = Add (add_t (a,b), makeadd (1 , a, b)... )
714+ + (a:: Number , b:: SN_EC ) = Add (add_t (a,b), makeadd (1 , a, b)... )
656715
657- + (a:: SN , b:: Number ) = Add (add_t (a,b), makeadd (1 , b, a)... )
716+ + (a:: SN_EC , b:: Number ) = Add (add_t (a,b), makeadd (1 , b, a)... )
658717
659- + (a:: SN ) = a
718+ + (a:: SN_EC ) = a
660719
661720+ (a:: Add , b:: Add ) = Add (add_t (a,b),
662721 a. coeff + b. coeff,
@@ -668,17 +727,17 @@ end
668727
669728- (a:: Add ) = Add (sub_t (a), - a. coeff, mapvalues ((_,v) -> - v, a. dict))
670729
671- - (a:: SN ) = Add (sub_t (a), makeadd (- 1 , 0 , a)... )
730+ - (a:: SN_EC ) = Add (sub_t (a), makeadd (- 1 , 0 , a)... )
672731
673732- (a:: Add , b:: Add ) = Add (sub_t (a,b),
674733 a. coeff - b. coeff,
675734 _merge (- , a. dict, b. dict, filter= _iszero))
676735
677- - (a:: SN , b:: SN ) = a + (- b)
736+ - (a:: SN_EC , b:: SN_EC ) = a + (- b)
678737
679- - (a:: Number , b:: SN ) = a + (- b)
738+ - (a:: Number , b:: SN_EC ) = a + (- b)
680739
681- - (a:: SN , b:: Number ) = a + (- b)
740+ - (a:: SN_EC , b:: Number ) = a + (- b)
682741
683742"""
684743 Mul(T, coeff, dict)
@@ -753,6 +812,16 @@ Base.isequal(a::Mul, b::Mul) = a.coeff == b.coeff && isequal(a.dict, b.dict)
753812
754813Base. show (io:: IO , a:: Mul ) = show_term (io, a)
755814
815+ function toterm (t:: Mul{T} ) where T
816+ args = []
817+ push! (args, t. coeff)
818+ for (k, deg) in t. dict
819+ push! (args, deg == 1 ? k : Term {T} (^ , [k, deg]))
820+ end
821+ Term {T} (* , args)
822+ end
823+
824+
756825function makemul (coeff, xs... ; d= sdict ())
757826 for x in xs
758827 if x isa Pow && x. exp isa Number
777846mul_t (a,b) = promote_symtype (* , symtype (a), symtype (b))
778847mul_t (a) = promote_symtype (* , symtype (a))
779848
780- * (a:: SN ) = a
849+ * (a:: SN_EC ) = a
781850
782- function * (a:: SN , b:: SN )
851+ function * (a:: SN_EC , b:: SN_EC )
783852 # Always make sure Div wraps Mul
784853 if a isa Div && b isa Div
785854 Div (a. num * b. num, a. den * b. den)
796865 a. coeff * b. coeff,
797866 _merge (+ , a. dict, b. dict, filter= _iszero))
798867
799- function * (a:: Number , b:: SN )
868+ function * (a:: Number , b:: SN_EC )
800869 if iszero (a)
801870 a
802871 elseif isone (a)
@@ -812,17 +881,17 @@ function *(a::Number, b::SN)
812881 end
813882end
814883
815- * (a:: SN , b:: Number ) = b * a
884+ * (a:: SN_EC , b:: Number ) = b * a
816885
817- \ (a:: SN , b:: Union{Number, SN } ) = b / a
886+ \ (a:: SN_EC , b:: Union{Number, SN_EC } ) = b / a
818887
819- \ (a:: Number , b:: SN ) = b / a
888+ \ (a:: Number , b:: SN_EC ) = b / a
820889
821- / (a:: SN , b:: Number ) = (b isa Integer ? 1 // b : inv (b)) * a
890+ / (a:: SN_EC , b:: Number ) = (b isa Integer ? 1 // b : inv (b)) * a
822891
823- // (a:: Union{SN , Number} , b:: SN ) = a / b
892+ // (a:: Union{SN_EC , Number} , b:: SN_EC ) = a / b
824893
825- // (a:: SN , b:: T ) where {T <: Number } = (one (T) // b) * a
894+ // (a:: SN_EC , b:: T ) where {T <: Number } = (one (T) // b) * a
826895
827896"""
828897 Div(numerator_factors, denominator_factors, simplified=false)
901970
902971Base. show (io:: IO , d:: Div ) = show_term (io, d)
903972
904- / (a:: Union{SN,Number} , b:: SN ) = Div (a,b)
973+ function toterm (t:: Div{T} ) where T
974+ Term {T} (/ , [t. num, t. den])
975+ end
976+
977+ / (a:: Union{SN_EC,Number} , b:: SN_EC ) = Div (a,b)
905978
906979"""
907980 Pow(base, exp)
@@ -944,6 +1017,10 @@ Base.isequal(p::Pow, b::Pow) = isequal(p.base, b.base) && isequal(p.exp, b.exp)
9441017
9451018Base. show (io:: IO , p:: Pow ) = show_term (io, p)
9461019
1020+ function toterm (t:: Pow{T} ) where T
1021+ Term {T} (^ , [t. base, t. exp])
1022+ end
1023+
9471024function makepow (a, b)
9481025 base = a
9491026 exp = b
@@ -954,11 +1031,11 @@ function makepow(a, b)
9541031 return (base, exp)
9551032end
9561033
957- ^ (a:: SN , b) = Pow (a, b)
1034+ ^ (a:: SN_EC , b) = Pow (a, b)
9581035
959- ^ (a:: SN , b:: SN ) = Pow (a, b)
1036+ ^ (a:: SN_EC , b:: SN_EC ) = Pow (a, b)
9601037
961- ^ (a:: Number , b:: SN ) = Pow (a, b)
1038+ ^ (a:: Number , b:: SN_EC ) = Pow (a, b)
9621039
9631040function ^ (a:: Mul , b:: Number )
9641041 coeff = unstable_pow (a. coeff, b)
0 commit comments