From 47801aae7f7f96c9a2d43bbf538b8099b38fc439 Mon Sep 17 00:00:00 2001 From: Manuel Berkemeier Date: Fri, 6 Jan 2023 17:02:57 +0100 Subject: [PATCH 1/2] add `promote_symtype` for `_map` and `_mapreduce`. --- src/array-lib.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/array-lib.jl b/src/array-lib.jl index d9748115f..47583dcaa 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -300,8 +300,48 @@ function _map(f, x, xs...) Term{Any}(map, [f, x, xs...])) end +function SymbolicUtils.promote_symtype(::typeof(_map), F, XS...) + # like `propagate_atype` but without filtering out non-symbolic + # arrays: + As = [atype(symtype(T)) for T in XS] + Atype = if length(As) <= 1 + _propagate_atype(As...) + else + foldl(_propagate_atype, As) + end + + T = if Base.issingletontype(F) + mapreduce(Base.Fix1(promote_symtype, F.instance), promote_type, eltype.(XS)) + else + promote_type(Real,mapreduce(eltype, promote_type, XS)) + end + return Atype{T} + # TODO: check consistency with result from calling `map`, + # i.e., return result should correspond to type-parameter + # of ArrayOp. + # Difficulty: We only have the type `F` of the mapped function + # and can not easily call it or pass it to `promote_symtype` + # as the first argument. + # See also comments in `promote_symtype(::typeof(_mapreduce))`. +end + @inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...) +function SymbolicUtils.promote_symtype( + ::typeof(_mapreduce), F, OP, X, D, K +) + A = promote_symtype(_map, F, X) + if Base.issingletontype(OP) + return promote_symtype(OP.instance, eltype(A), eltype(A)) + else + return promote_type(Real, eltype(A)) + end + # NOTE it would be easier and more precise to define + # `_promote_symtype` with the actual arguments instead of + # their types. Alternatively, it would be convient to be able + # to call `promote_symtype` with the operator type `OP`. +end + function scalarize_op(::typeof(_mapreduce), t) f,g,x,dims,kw = arguments(t) # we wrap and unwrap to make things work smoothly. From f4d61210ad5e6b03289e7bf02f1ecc13bad19e4c Mon Sep 17 00:00:00 2001 From: Manuel Berkemeier Date: Mon, 16 Jan 2023 12:45:21 +0100 Subject: [PATCH 2/2] 1) remove `promote_symtype` for `_map` and `_mapreduce` 2) introduce `_promote_symtype` instead, logic taken from the respective methods 3) adapt method definitions to avoid redundancy 4) some tests for the problem described in #814 --- src/array-lib.jl | 91 ++++++++++++++++++++---------------------------- test/arrays.jl | 10 +++++- 2 files changed, 46 insertions(+), 55 deletions(-) diff --git a/src/array-lib.jl b/src/array-lib.jl index 47583dcaa..8419d879d 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -287,59 +287,43 @@ end @wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...) function _map(f, x, xs...) + return ArrayOp( + SymbolicUtils._promote_symtype(_map, (x,xs)), + (idx...,), + expr, + +, + Term{Any}(map, [f, x, xs...]) + ) +end + +function SymbolicUtils._promote_symtype(::typeof(_map), args) + f, x, xs... = args + N = ndims(x) idx = makesubscripts(N) - + expr = f(map(a->a[idx...], [x, xs...])...) Atype = propagate_atype(map, f, x, xs...) - ArrayOp(Atype{symtype(expr), N}, - (idx...,), - expr, - +, - Term{Any}(map, [f, x, xs...])) -end -function SymbolicUtils.promote_symtype(::typeof(_map), F, XS...) - # like `propagate_atype` but without filtering out non-symbolic - # arrays: - As = [atype(symtype(T)) for T in XS] - Atype = if length(As) <= 1 - _propagate_atype(As...) - else - foldl(_propagate_atype, As) - end - - T = if Base.issingletontype(F) - mapreduce(Base.Fix1(promote_symtype, F.instance), promote_type, eltype.(XS)) - else - promote_type(Real,mapreduce(eltype, promote_type, XS)) - end - return Atype{T} - # TODO: check consistency with result from calling `map`, - # i.e., return result should correspond to type-parameter - # of ArrayOp. - # Difficulty: We only have the type `F` of the mapped function - # and can not easily call it or pass it to `promote_symtype` - # as the first argument. - # See also comments in `promote_symtype(::typeof(_mapreduce))`. + return Atype{symtype(expr), N} end @inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...) -function SymbolicUtils.promote_symtype( - ::typeof(_mapreduce), F, OP, X, D, K -) - A = promote_symtype(_map, F, X) - if Base.issingletontype(OP) - return promote_symtype(OP.instance, eltype(A), eltype(A)) - else - return promote_type(Real, eltype(A)) +function SymbolicUtils._promote_symtype(::typeof(_mapreduce), args) + @assert length(args) == 5 + f, op, x, dims, kw = args + + N = ndims(x) + idx = makesubscripts(N) + expr = f(x[idx...]) + T = symtype(op(expr, expr)) + if dims === (:) + return T end - # NOTE it would be easier and more precise to define - # `_promote_symtype` with the actual arguments instead of - # their types. Alternatively, it would be convient to be able - # to call `promote_symtype` with the operator type `OP`. + Atype = propagate_atype(_mapreduce, f, op, x, dims, (kw...,)) + return Atype{T, N} end function scalarize_op(::typeof(_mapreduce), t) @@ -350,20 +334,19 @@ function scalarize_op(::typeof(_mapreduce), t) end @wrapped function Base.mapreduce(f, g, x::AbstractArray; dims=:, kw...) - idx = makesubscripts(ndims(x)) - out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] - expr = f(x[idx...]) - T = symtype(g(expr, expr)) + Stype = SymbolicUtils._promote_symtype(_mapreduce, (f,g,x,dims,kw)) if dims === (:) - return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)]) + return Term{Stype}(_mapreduce, [f, g, x, dims, (kw...,)]) end - - Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,)) - ArrayOp(Atype{T, ndims(x)}, - (out_idx...,), - expr, - g, - Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)])) + idx = makesubscripts(ndims(x)) + out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] + return ArrayOp( + Stype, + (out_idx...,), + expr, + g, + Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)]) + ) end for (ff, opts) in [sum => (identity, +, false), diff --git a/test/arrays.jl b/test/arrays.jl index b90f05a61..83a110f9e 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -1,6 +1,6 @@ using Symbolics using SymbolicUtils, Test -using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, arrterm, jacobian, @variables, value, get_variables, @arrayop, getname +using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, arrterm, jacobian, @variables, value, get_variables, @arrayop, getname, simplify using Base: Slice using SymbolicUtils: Sym, term, operation @@ -79,6 +79,14 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) # #417 @test isequal(Symbolics.scalarize(x', (1,1)), x[1]) + # #814 + @test isa(simplify(sum(b .* 1) ), Num) + @test isa(simplify(prod(x .+ 2.0) ), Num) + @test isa(simplify(mapreduce(x -> (x+1)/(0.1 + abs(x)), ^, u)), Num) # exponent(s) must not be Int until #455 is fixed + @test Symbolics.symtype(simplify(sum(b .* 1))) <: Real + @test Symbolics.symtype(simplify(prod(x .+ 2.0))) <: Real + @test Symbolics.symtype(simplify(mapreduce(x -> (x+1)/(0.1 + abs(x)), ^, u))) <: Real + # #483 # examples by @gronniger @variables A[1:2, 1:2]