Skip to content

Commit a1b02d8

Browse files
authored
Fix pairwise for type-unstable corner case function (#772)
`promote_type` is not a completely correct way of computing an upper bound for the return eltype. Use the same strategy as `map` and `broadcast` in Base instead, but with `promote_eltype` rather than `promote_typejoin`. We can still use `typejoin_union_tuple` since promotion does not happen inside tuple types. On Julia versions before 1.6 we would have to copy the full definition of `typejoin_union_tuple`, which is quite complex, so instead fall back to inferring eltype `Any`.
1 parent f9cfd12 commit a1b02d8

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

src/pairwise.jl

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,33 @@ function _pairwise!(f, dest::AbstractMatrix, x, y;
122122
return _pairwise!(Val(skipmissing), f, dest, x′, y′, symmetric)
123123
end
124124

125+
if VERSION >= v"1.6.0-DEV"
126+
# Function has moved in Julia 1.7
127+
if isdefined(Base, :typejoin_union_tuple)
128+
using Base: typejoin_union_tuple
129+
else
130+
using Base.Broadcast: typejoin_union_tuple
131+
end
132+
else
133+
typejoin_union_tuple(::Type) = Any
134+
end
135+
136+
# Identical to `Base.promote_typejoin` except that it uses `promote_type`
137+
# instead of `typejoin` to combine members of `Union` types
138+
function promote_type_union(::Type{T}) where T
139+
if T === Union{}
140+
return Union{}
141+
elseif T isa UnionAll
142+
return Any # TODO: compute more precise bounds
143+
elseif T isa Union
144+
return promote_type(promote_type_union(T.a), promote_type_union(T.b))
145+
elseif T <: Tuple
146+
return typejoin_union_tuple(T)
147+
else
148+
return T
149+
end
150+
end
151+
125152
function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmissing}
126153
x′ = x isa Union{AbstractArray, Tuple, NamedTuple} ? x : collect(x)
127154
y′ = y isa Union{AbstractArray, Tuple, NamedTuple} ? y : collect(y)
@@ -148,10 +175,11 @@ function _pairwise(::Val{skipmissing}, f, x, y, symmetric::Bool) where {skipmiss
148175
if isconcretetype(eltype(dest))
149176
return dest
150177
else
151-
# Final eltype depends on actual contents (consistent with map and broadcast)
178+
# Final eltype depends on actual contents (consistent with `map` and `broadcast`
179+
# but using `promote_type` rather than `promote_typejoin`)
152180
U = mapreduce(typeof, promote_type, dest)
153181
# V is inferred (contrary to U), but it only gives an upper bound for U
154-
V = promote_type(T, Tsm)
182+
V = promote_type_union(Union{T, Tsm})
155183
return convert(Matrix{U}, dest)::Matrix{<:V}
156184
end
157185
end

test/pairwise.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,34 @@ arbitrary_fun(x, y) = cor(x, y)
258258
end
259259
end
260260
end
261+
262+
@testset "promote_type_union" begin
263+
@test StatsBase.promote_type_union(Int) === Int
264+
@test StatsBase.promote_type_union(Real) === Real
265+
@test StatsBase.promote_type_union(Union{Int, Float64}) === Float64
266+
@test StatsBase.promote_type_union(Union{Int, Missing}) === Union{Int, Missing}
267+
@test StatsBase.promote_type_union(Union{Int, String}) === Any
268+
@test StatsBase.promote_type_union(Vector) === Any
269+
@test StatsBase.promote_type_union(Union{}) === Union{}
270+
if VERSION >= v"1.6.0-DEV"
271+
@test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) ===
272+
Tuple{Real}
273+
else
274+
@test StatsBase.promote_type_union(Tuple{Union{Int, Float64}}) ===
275+
Any
276+
end
277+
end
278+
279+
@testset "type-unstable corner case (#771)" begin
280+
v = [rand(5) for _=1:10]
281+
function f(v)
282+
pairwise(v) do x, y
283+
(x[1] < 0 ? nothing :
284+
x[1] > y[1] ? 1 : 1.5,
285+
0)
286+
end
287+
end
288+
res = f(v)
289+
@test res isa Matrix{Tuple{Real, Int}}
290+
end
261291
end

0 commit comments

Comments
 (0)