@@ -8,7 +8,7 @@ using BangBang
88using Accessors
99using .. DynamicPPL: _compose_no_identity
1010
11- export VarNamedTuple
11+ export VarNamedTuple, vnt_size
1212
1313# We define our own getindex, setindex!!, and haskey functions, which we use to
1414# get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be
@@ -81,6 +81,17 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize
8181_unwrap_concretized_slice(cs:: AbstractPPL.ConcretizedSlice ) = cs. range
8282_unwrap_concretized_slice(x:: Union{Integer,AbstractUnitRange,Colon} ) = x
8383
84+ """
85+ vnt_size(x)
86+
87+ Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`.
88+
89+ By default, this falls back onto `Base.size`, but can be overloaded for custom types.
90+ This notion of type is used to determine whether a value can be set into a `PartialArray`
91+ as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details.
92+ """
93+ vnt_size(x) = size(x)
94+
8495"""
8596 ArrayLikeBlock{T,I}
8697
@@ -156,11 +167,12 @@ Like `Base.Array`s, `PartialArray`s have a well-defined, compile-time-known elem
156167
157168One can set values in a `PartialArray` either element-by-element, or with ranges like
158169`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set
159- must either be an `AbstractArray` or otherwise something for which `size(value)` is defined,
160- and the size mathces the range. If the value is an `AbstractArray`, the elements are copied
161- individually, but if it is not, the value is stored as a block, that takes up the whole
162- range, e.g. `[1:3,2]`, but is only a single object. Getting such a block-value must be done
163- with the exact same range of indices, otherwise an error is thrown.
170+ must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or
171+ `Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the
172+ range. If the value is an `AbstractArray`, the elements are copied individually, but if it
173+ is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but
174+ is only a single object. Getting such a block-value must be done with the exact same range
175+ of indices, otherwise an error is thrown.
164176
165177If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check
166178if, after the new value has been set, the element type can be made more concrete. If so,
@@ -594,7 +606,7 @@ The value only depends on the types of the arguments, and should be constant pro
594606function _needs_arraylikeblock(value, inds:: Vararg{INDEX_TYPES} )
595607 return _is_multiindex(inds) &&
596608 ! isa(value, AbstractArray) &&
597- hasmethod(size , Tuple{typeof(value)})
609+ hasmethod(vnt_size , Tuple{typeof(value)})
598610end
599611
600612function _setindex!!(pa:: PartialArray , value, inds:: Vararg{INDEX_TYPES} )
@@ -610,11 +622,11 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES})
610622 new_data = pa. data
611623 if _needs_arraylikeblock(value, inds... )
612624 inds_size = reduce((x, y) -> tuple(x... , y... ), map(size, inds))
613- if size (value) != inds_size
625+ if vnt_size (value) != inds_size
614626 throw(
615627 DimensionMismatch(
616- " Assigned value has size $(size (value)) , which does not match the " *
617- " size implied by the indices $(map(x -> _length_needed(x), inds)) ." ,
628+ " Assigned value has size $(vnt_size (value)) , which does not match " *
629+ " the size implied by the indices $(map(x -> _length_needed(x), inds)) ." ,
618630 ),
619631 )
620632 end
0 commit comments