Skip to content

Commit 47801aa

Browse files
committed
add promote_symtype for _map and _mapreduce.
1 parent a3fa921 commit 47801aa

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

src/array-lib.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,48 @@ function _map(f, x, xs...)
300300
Term{Any}(map, [f, x, xs...]))
301301
end
302302

303+
function SymbolicUtils.promote_symtype(::typeof(_map), F, XS...)
304+
# like `propagate_atype` but without filtering out non-symbolic
305+
# arrays:
306+
As = [atype(symtype(T)) for T in XS]
307+
Atype = if length(As) <= 1
308+
_propagate_atype(As...)
309+
else
310+
foldl(_propagate_atype, As)
311+
end
312+
313+
T = if Base.issingletontype(F)
314+
mapreduce(Base.Fix1(promote_symtype, F.instance), promote_type, eltype.(XS))
315+
else
316+
promote_type(Real,mapreduce(eltype, promote_type, XS))
317+
end
318+
return Atype{T}
319+
# TODO: check consistency with result from calling `map`,
320+
# i.e., return result should correspond to type-parameter
321+
# of ArrayOp.
322+
# Difficulty: We only have the type `F` of the mapped function
323+
# and can not easily call it or pass it to `promote_symtype`
324+
# as the first argument.
325+
# See also comments in `promote_symtype(::typeof(_mapreduce))`.
326+
end
327+
303328
@inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...)
304329

330+
function SymbolicUtils.promote_symtype(
331+
::typeof(_mapreduce), F, OP, X, D, K
332+
)
333+
A = promote_symtype(_map, F, X)
334+
if Base.issingletontype(OP)
335+
return promote_symtype(OP.instance, eltype(A), eltype(A))
336+
else
337+
return promote_type(Real, eltype(A))
338+
end
339+
# NOTE it would be easier and more precise to define
340+
# `_promote_symtype` with the actual arguments instead of
341+
# their types. Alternatively, it would be convient to be able
342+
# to call `promote_symtype` with the operator type `OP`.
343+
end
344+
305345
function scalarize_op(::typeof(_mapreduce), t)
306346
f,g,x,dims,kw = arguments(t)
307347
# we wrap and unwrap to make things work smoothly.

0 commit comments

Comments
 (0)