Skip to content

Commit 2e940aa

Browse files
authored
Fix type instability when setindex!! (#549)
* Add `BangBang.possible` for general arrays * Remove redundant BangBang.possible * add comments rt JuliaFold2 PR * copy tor's fix to JuliaFolds here * add BangBang prefix to the functions * remove tests of possible, no longer needed * Add Tor's tests * Import `setindex!!`
1 parent 04b03cd commit 2e940aa

File tree

2 files changed

+145
-26
lines changed

2 files changed

+145
-26
lines changed

src/utils.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -538,36 +538,40 @@ function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) whe
538538
end
539539

540540
# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
541-
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
542-
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
543-
function BangBang.possible(
544-
::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer
545-
) where {C<:AbstractMatrix,T<:AbstractVector}
546-
return BangBang.implements(setindex!, C) &&
547-
promote_type(eltype(C), eltype(T)) <: eltype(C)
548-
end
549-
function BangBang.possible(
550-
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer
551-
) where {C<:AbstractMatrix,T<:AbstractVector}
552-
return BangBang.implements(setindex!, C) &&
553-
promote_type(eltype(C), eltype(T)) <: eltype(C)
554-
end
555-
# HACK: Makes it possible to use ranges, etc. for setting a vector.
556-
# For example, without this hack, BangBang.jl will consider
541+
# and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16.
542+
# This avoids type-instability in `dot_assume` for `SimpleVarInfo`.
543+
# The following code a copy from https://github.com/JuliaFolds2/BangBang.jl/pull/16 authored by torfjelde
544+
# Default implementation for `_setindex!` with `AbstractArray`.
545+
# But this will return `false` even in cases such as
546+
#
547+
# setindex!!([1, 2, 3], [4, 5, 6], :)
548+
#
549+
# because `promote_type(eltype(C), T) <: eltype(C)` is `false`.
550+
# To address this, we specialize on the case where `T<:AbstractArray`.
551+
# In addition, we need to support a wide range of indexing behaviors:
557552
#
558-
# x[1:2] = [1, 2]
553+
# We also need to ensure that the dimensionality of the index is
554+
# valid, i.e. that we're not returning `true` in cases such as
559555
#
560-
# as NOT supported. This results is calling the immutable
561-
# `BangBang.setindex` instead, which also ends up expanding the
562-
# type of the containing array (`x` in the above scenario) to
563-
# have element type `Any`.
564-
# The below code just, correctly, marks this as possible and
565-
# thus we hit the mutable `setindex!` instead.
556+
# setindex!!([1, 2, 3], [4, 5], 1)
557+
#
558+
# which should return `false`.
559+
_index_dimension(::Any) = 0
560+
_index_dimension(::Colon) = 1
561+
_index_dimension(::AbstractVector) = 1
562+
_index_dimension(indices::Tuple) = sum(map(_index_dimension, indices))
563+
566564
function BangBang.possible(
567-
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer}
568-
) where {C<:AbstractVector,T<:AbstractVector}
565+
::typeof(BangBang._setindex!), ::C, ::T, indices::Vararg
566+
) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}}
569567
return BangBang.implements(setindex!, C) &&
570-
promote_type(eltype(C), eltype(T)) <: eltype(C)
568+
promote_type(eltype(C), eltype(T)) <: eltype(C) &&
569+
# This will still return `false` for scenarios such as
570+
#
571+
# setindex!!([1, 2, 3], [4, 5, 6], :, 1)
572+
#
573+
# which are in fact valid. However, this cases are rare.
574+
(_index_dimension(indices) == M || _index_dimension(indices) == 1)
571575
end
572576

573577
# HACK(torfjelde): This makes it so it works on iterators, etc. by default.

test/utils.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,119 @@
4848
x = rand(dist)
4949
@test vectorize(dist, x) == vec(x.UL)
5050
end
51+
52+
@testset "BangBang.possible" begin
53+
using DynamicPPL.BangBang: setindex!!
54+
55+
# Some utility methods for testing `setindex!`.
56+
test_linear_index_only(::Tuple, ::AbstractArray) = false
57+
test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true
58+
test_linear_index_only(inds::NTuple{1}, ::AbstractVector) = false
59+
60+
function replace_colon_with_axis(inds::Tuple, x)
61+
ntuple(length(inds)) do i
62+
inds[i] isa Colon ? axes(x, i) : inds[i]
63+
end
64+
end
65+
function replace_colon_with_vector(inds::Tuple, x)
66+
ntuple(length(inds)) do i
67+
inds[i] isa Colon ? collect(axes(x, i)) : inds[i]
68+
end
69+
end
70+
function replace_colon_with_range(inds::Tuple, x)
71+
ntuple(length(inds)) do i
72+
inds[i] isa Colon ? (1:size(x, i)) : inds[i]
73+
end
74+
end
75+
function replace_colon_with_booleans(inds::Tuple, x)
76+
ntuple(length(inds)) do i
77+
inds[i] isa Colon ? trues(size(x, i)) : inds[i]
78+
end
79+
end
80+
81+
function replace_colon_with_range_linear(inds::NTuple{1}, x::AbstractArray)
82+
return inds[1] isa Colon ? (1:length(x),) : inds
83+
end
84+
85+
@testset begin
86+
@test setindex!!((1, 2, 3), :two, 2) === (1, :two, 3)
87+
@test setindex!!((a=1, b=2, c=3), :two, :b) === (a=1, b=:two, c=3)
88+
@test setindex!!([1, 2, 3], :two, 2) == [1, :two, 3]
89+
@test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 10, :a) ==
90+
Dict(:a => 10, :b => 2)
91+
@test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 3, "c") ==
92+
Dict(:a => 1, :b => 2, "c" => 3)
93+
end
94+
95+
@testset "mutation" begin
96+
@testset "without type expansion" begin
97+
for args in [([1, 2, 3], 20, 2), (Dict(:a => 1, :b => 2), 10, :a)]
98+
@test setindex!!(args...) === args[1]
99+
end
100+
end
101+
102+
@testset "with type expansion" begin
103+
@test setindex!!([1, 2, 3], [4, 5], 1) == [[4, 5], 2, 3]
104+
@test setindex!!([1, 2, 3], [4, 5, 6], :, 1) == [4, 5, 6]
105+
end
106+
end
107+
108+
@testset "slices" begin
109+
@testset "$(typeof(x)) with $(src_idx)" for (x, src_idx) in [
110+
# Vector.
111+
(randn(2), (:,)),
112+
(randn(2), (1:2,)),
113+
# Matrix.
114+
(randn(2, 3), (:,)),
115+
(randn(2, 3), (:, 1)),
116+
(randn(2, 3), (:, 1:3)),
117+
# 3D array.
118+
(randn(2, 3, 4), (:, 1, :)),
119+
(randn(2, 3, 4), (:, 1:3, :)),
120+
(randn(2, 3, 4), (1, 1:3, :)),
121+
]
122+
# Base case.
123+
@test @inferred(setindex!!(x, x[src_idx...], src_idx...)) === x
124+
125+
# If we have `Colon` in the index, we replace this with other equivalent indices.
126+
if any(Base.Fix2(isa, Colon), src_idx)
127+
if test_linear_index_only(src_idx, x)
128+
# With range instead of `Colon`.
129+
@test @inferred(
130+
setindex!!(
131+
x,
132+
x[src_idx...],
133+
replace_colon_with_range_linear(src_idx, x)...,
134+
)
135+
) === x
136+
else
137+
# With axis instead of `Colon`.
138+
@test @inferred(
139+
setindex!!(
140+
x, x[src_idx...], replace_colon_with_axis(src_idx, x)...
141+
)
142+
) === x
143+
# With range instead of `Colon`.
144+
@test @inferred(
145+
setindex!!(
146+
x, x[src_idx...], replace_colon_with_range(src_idx, x)...
147+
)
148+
) === x
149+
# With vectors instead of `Colon`.
150+
@test @inferred(
151+
setindex!!(
152+
x, x[src_idx...], replace_colon_with_vector(src_idx, x)...
153+
)
154+
) === x
155+
# With boolean index instead of `Colon`.
156+
@test @inferred(
157+
setindex!!(
158+
x, x[src_idx...], replace_colon_with_booleans(src_idx, x)...
159+
)
160+
) === x
161+
end
162+
end
163+
end
164+
end
165+
end
51166
end

0 commit comments

Comments
 (0)