Skip to content

Commit 563a463

Browse files
penelopeysmclaude
andauthored
Fix MustNotOverwrite when setting slices/indices of slices (#1325)
Closes #1321. Claude did most of this, I tweaked the tests a fair bit from what it wrote though. If anyone is interested, this is the prompt I fed Claude <details> # Issue In VarNamedTuple you can set different indices of the same *slice* of the same array (i.e. x[1:2][1] and x[1:2][2]): ```julia julia> using DynamicPPL julia> vnt = VarNamedTuple() VarNamedTuple() julia> x = zeros(2) 2-element Vector{Float64}: 0.0 0.0 julia> vnt = DynamicPPL.templated_setindex!!(vnt, 1.0, @varname(x[1:2][1]), x) VarNamedTuple └─ x => PartialArray size=(2,) data::Vector{Float64} └─ (1,) => 1.0 julia> vnt = DynamicPPL.templated_setindex!!(vnt, 2.0, @varname(x[1:2][2]), x) VarNamedTuple └─ x => PartialArray size=(2,) data::Vector{Float64} ├─ (1,) => 1.0 └─ (2,) => 2.0 ``` This required some careful thought to make sure that the second set didn't overwrite the first one. See e.g. lines 217 to 234 of src/varnamedtuple/getset.jl. However, I forgot to handle this for the case where we check for MustNotOverwrite: ```julia julia> using AbstractPPL: @opticof julia> using DynamicPPL.VarNamedTuples: _setindex_optic!!, SkipTemplate, MustNotOverwrite julia> vnt = VarNamedTuple() VarNamedTuple() julia> vnt = _setindex_optic!!(vnt, 1.0, @opticof(_.x[1:2][1]), SkipTemplate{1}(x), MustNotOverwrite(@varname(x[1:2][1]))) VarNamedTuple └─ x => PartialArray size=(2,) data::Vector{Float64} └─ (1,) => 1.0 julia> vnt = _setindex_optic!!(vnt, 2.0, @opticof(_.x[1:2][2]), SkipTemplate{1}(x), MustNotOverwrite(@varname(x[1:2][2]))) ERROR: MustNotOverwriteError: Attempted to set a value for x[1:2][2], but a value already existed. This indicates that a value is being set twice (e.g. if the same variable occurs in a model twice). Stacktrace: [1] _setindex_optic!!(pa::DynamicPPL.VarNamedTuples.PartialArray{…}, value::Float64, optic::AbstractPPL.Index{…}, template::Vector{…}, permissions::MustNotOverwrite{…}) @ DynamicPPL.VarNamedTuples ~/ppl/dppl/src/varnamedtuple/getset.jl:173 [2] _setindex_optic!!(vnt::VarNamedTuple{…}, value::Float64, optic::AbstractPPL.Property{…}, template::SkipTemplate{…}, permissions::MustNotOverwrite{…}) @ DynamicPPL.VarNamedTuples ~/ppl/dppl/src/varnamedtuple/getset.jl:276 [3] top-level scope @ REPL[20]:1 Some type information was truncated. Use `show(err)` to see complete types. ``` Two requests: (1) could you help me fix that? And after we've done that (but don't launch into this first; just do the fix first) (2) could you reason about whether there's a better way to handle the ordinary case (the first code block)? Right now we do a merge, but I think it might be possible for us to create a separate method for setting a (slice of a) PartialArray into a PartialArray, where we make sure to only copy the indices where the mask is set to true. </details> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 297e2e0 commit 563a463

File tree

5 files changed

+85
-2
lines changed

5 files changed

+85
-2
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.40.14
2+
3+
Fixed `check_model()` erroneously failing for models such as `x[1:2] .~ univariate_dist`.`
4+
15
# 0.40.13
26

37
Fixed `densify!!` not recursing into `VarNamedTuple`s or `ArrayLikeBlock`s inside `PartialArray`s.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.40.13"
3+
version = "0.40.14"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/varnamedtuple/getset.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ function _setindex_optic!!(
168168

169169
is_multiindex = _is_multiindex(template, coptic.ix...; coptic.kw...)
170170

171-
if permissions isa MustNotOverwrite
171+
if permissions isa MustNotOverwrite && optic.child isa AbstractPPL.Iden
172172
if any(view(pa.mask, coptic.ix...; coptic.kw...))
173173
throw(MustNotOverwriteError(permissions))
174174
end
@@ -248,6 +248,16 @@ function _setindex_optic!!(
248248
sub_value
249249
end
250250

251+
# In the merge path, some indices in the slice already have data but not all of them
252+
# (haskey returned false but any(mask) was true). If MustNotOverwrite is set, check
253+
# that the new sub-value doesn't overlap with existing data at the specific sub-indices.
254+
if need_merge && permissions isa MustNotOverwrite && grown_sub_value isa PartialArray
255+
existing_mask = view(pa.mask, coptic.ix...; coptic.kw...)
256+
if any(existing_mask .& grown_sub_value.mask)
257+
throw(MustNotOverwriteError(permissions))
258+
end
259+
end
260+
251261
return if need_merge
252262
new_pa = BangBang.setindex!!(copy(pa), grown_sub_value, coptic.ix...; coptic.kw...)
253263
_merge(pa, new_pa, Val(false))

test/debug_utils.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@ end
3636
test_model_can_run_but_fails_check(buggy_demo_model())
3737
end
3838

39+
@testset "different sub-indices of the same slice" begin
40+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1321
41+
@model function demo_slice_subindices()
42+
x = Vector{Float64}(undef, 2)
43+
x[1:2] .~ Normal()
44+
return x
45+
end
46+
@test check_model(demo_slice_subindices(); error_on_failure=true)
47+
48+
# Same sub-index twice should still fail
49+
@model function buggy_slice_subindices()
50+
x = Vector{Float64}(undef, 2)
51+
x[1:2][1] ~ Normal()
52+
x[1:2][1] ~ Normal()
53+
return x
54+
end
55+
test_model_can_run_but_fails_check(buggy_slice_subindices())
56+
57+
# Slices of slices
58+
@model function buggy_slice_subindices2()
59+
x = Vector{Float64}(undef, 3)
60+
x[1:3][1:2] ~ MvNormal(zeros(2), I)
61+
x[1:3][2:3] ~ MvNormal(zeros(2), I)
62+
return x
63+
end
64+
test_model_can_run_but_fails_check(buggy_slice_subindices2())
65+
end
66+
3967
@testset "submodel" begin
4068
@model ModelInner() = x ~ Normal()
4169
@model function ModelOuterBroken()

test/varnamedtuple.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,8 @@ Base.size(st::SizedThing) = st.size
11011101

11021102
@testset "_setindex_optic!! with MustNotOverwrite" begin
11031103
function test_must_not_overwrite(vnt, value, vn, template)
1104+
# Avoid mutating
1105+
vnt = deepcopy(vnt)
11041106
# Check that calling `_setindex_optic!!` with `MustNotOverwrite` errors if the
11051107
# variable already exists, and that it works if it doesn't.
11061108
@test_throws MustNotOverwriteError templated_setindex_no_overwrite!!(
@@ -1174,6 +1176,45 @@ Base.size(st::SizedThing) = st.size
11741176
)
11751177
test_must_not_overwrite(vnt, 2.0, @varname(x.a), NoTemplate())
11761178
end
1179+
1180+
@testset "different sub-indices of the same slice" begin
1181+
# Setting different sub-indices of the same slice should NOT error.
1182+
# This is the bug from issue #1321.
1183+
x = zeros(2)
1184+
vnt = templated_setindex_no_overwrite!!(
1185+
VarNamedTuple(), 1.0, @varname(x[1:2][1]), x
1186+
)
1187+
vnt = templated_setindex_no_overwrite!!(vnt, 2.0, @varname(x[1:2][2]), x)
1188+
@test vnt[@varname(x)] == [1.0, 2.0]
1189+
1190+
# But setting the SAME sub-index twice should still error.
1191+
vnt2 = templated_setindex_no_overwrite!!(
1192+
VarNamedTuple(), 1.0, @varname(x[1:2][1]), x
1193+
)
1194+
test_must_not_overwrite(vnt2, 2.0, @varname(x[1:2][1]), x)
1195+
1196+
# Also test with a larger array and different slices.
1197+
y = zeros(4)
1198+
vnt3 = templated_setindex_no_overwrite!!(
1199+
VarNamedTuple(), 1.0, @varname(y[1:3][1]), y
1200+
)
1201+
vnt3 = templated_setindex_no_overwrite!!(vnt3, 2.0, @varname(y[1:3][2]), y)
1202+
test_must_not_overwrite(vnt3, [3.0, 4.0], @varname(y[1:3][2:3]), y)
1203+
vnt3 = templated_setindex_no_overwrite!!(vnt3, 3.0, @varname(y[1:3][3]), y)
1204+
@test vnt3[@varname(y[1:3])] == [1.0, 2.0, 3.0]
1205+
# Setting an already-set sub-index, or slice, should error.
1206+
test_must_not_overwrite(vnt3, 4.0, @varname(y[1:3][2]), y)
1207+
test_must_not_overwrite(vnt3, [4.0, 5.0], @varname(y[2:3][1:2]), y)
1208+
1209+
# Also try with indices of different slices
1210+
z = zeros(3)
1211+
vnt4 = templated_setindex_no_overwrite!!(
1212+
VarNamedTuple(), [1.0, 2.0], @varname(z[1:2]), z
1213+
)
1214+
test_must_not_overwrite(vnt4, 3.0, @varname(z[2:3][1]), z)
1215+
# but you can set it if it's not overlapping
1216+
templated_setindex_no_overwrite!!(vnt4, 3.0, @varname(z[2:3][2]), z)
1217+
end
11771218
end
11781219

11791220
@testset "subset" begin

0 commit comments

Comments
 (0)