@@ -261,42 +261,6 @@ invlink_transform(dist) = inverse(link_transform(dist))
261261# Helper functions for vectorize/reconstruct values #
262262# ####################################################
263263
264- """
265- UnwrapSingletonTransform(input_size::InSize)
266-
267- A transformation that unwraps a singleton array, returning a scalar.
268-
269- The `input_size` field is the expected size of the input. In practice this only determines
270- the number of indices, since all dimensions must be 1 for a singleton. `input_size` is used
271- to check the validity of the input, but also to determine the correct inverse operation.
272-
273- By default `input_size` is `(1,)`, in which case `tovec` is the inverse.
274- """
275- struct UnwrapSingletonTransform{InSize} <: Bijectors.Bijector
276- input_size:: InSize
277- end
278-
279- UnwrapSingletonTransform () = UnwrapSingletonTransform ((1 ,))
280-
281- function (f:: UnwrapSingletonTransform )(x)
282- if size (x) != f. input_size
283- throw (DimensionMismatch (" Expected input of size $(f. input_size) , got $(size (x)) " ))
284- end
285- return only (x)
286- end
287-
288- function Bijectors. with_logabsdet_jacobian (f:: UnwrapSingletonTransform , x)
289- return f (x), zero (LogProbType)
290- end
291-
292- function Bijectors. with_logabsdet_jacobian (
293- inv_f:: Bijectors.Inverse{<:UnwrapSingletonTransform} , x
294- )
295- f = inv_f. orig
296- result = reshape ([x], f. input_size)
297- return result, zero (LogProbType)
298- end
299-
300264"""
301265 ReshapeTransform(input_size::InSize, output_size::OutSize)
302266
@@ -370,14 +334,26 @@ function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
370334 )
371335end
372336
337+ struct Only end
338+ struct NotOnly end
339+ (:: Only )(x) = x[]
340+ (:: NotOnly )(y) = [y]
341+ function Bijectors. with_logabsdet_jacobian (:: Only , x:: AbstractVector{T} ) where {T<: Real }
342+ return (x[], zero (T))
343+ end
344+ Bijectors. with_logabsdet_jacobian (:: Only , x:: AbstractVector ) = (x[], zero (LogProbType))
345+ Bijectors. inverse (:: Only ) = NotOnly ()
346+ Bijectors. with_logabsdet_jacobian (:: NotOnly , y:: T ) where {T<: Real } = ([y], zero (T))
347+ Bijectors. with_logabsdet_jacobian (:: NotOnly , y) = ([y], zero (LogProbType))
348+
373349"""
374350 from_vec_transform(x)
375351
376352Return the transformation from the vector representation of `x` to original representation.
377353"""
378354from_vec_transform (x:: AbstractArray ) = from_vec_transform_for_size (size (x))
379355from_vec_transform (C:: Cholesky ) = ToChol (C. uplo) ∘ ReshapeTransform (size (C. UL))
380- from_vec_transform (:: Real ) = UnwrapSingletonTransform ()
356+ from_vec_transform (:: Real ) = Only ()
381357
382358"""
383359 from_vec_transform_for_size(sz::Tuple)
@@ -395,7 +371,7 @@ Return the transformation from the vector representation of a realization from
395371distribution `dist` to the original representation compatible with `dist`.
396372"""
397373from_vec_transform (dist:: Distribution ) = from_vec_transform_for_size (size (dist))
398- from_vec_transform (:: UnivariateDistribution ) = UnwrapSingletonTransform ()
374+ from_vec_transform (:: UnivariateDistribution ) = Only ()
399375from_vec_transform (dist:: LKJCholesky ) = ToChol (dist. uplo) ∘ ReshapeTransform (size (dist))
400376
401377struct ProductNamedTupleUnvecTransform{names,T<: NamedTuple{names} }
441417# This function returns the length of the vector that the function from_vec_transform
442418# expects. This helps us determine which segment of a concatenated vector belongs to which
443419# variable.
444- _input_length (from_vec_trfm :: UnwrapSingletonTransform ) = 1
420+ _input_length (:: Only ) = 1
445421_input_length (from_vec_trfm:: ReshapeTransform ) = prod (from_vec_trfm. output_size)
446422function _input_length (trfm:: ProductNamedTupleUnvecTransform )
447423 return sum (_input_length ∘ from_vec_transform, values (trfm. dists))
@@ -477,18 +453,9 @@ function from_linked_vec_transform(dist::Distribution)
477453 f_vec = from_vec_transform (inverse (f_invlink), size (dist))
478454 return f_invlink ∘ f_vec
479455end
480-
481- # UnivariateDistributions need to be handled as a special case, because size(dist) is (),
482- # which makes the usual machinery think we are dealing with a 0-dim array, whereas in
483- # actuality we are dealing with a scalar.
484- # TODO (mhauru) Hopefully all this can go once the old Gibbs sampler is removed and
485- # VarNamedVector takes over from Metadata.
486456function from_linked_vec_transform (dist:: UnivariateDistribution )
487- f_invlink = invlink_transform (dist)
488- f_vec = from_vec_transform (inverse (f_invlink), size (dist))
489- f_combined = f_invlink ∘ f_vec
490- sz = Bijectors. output_size (f_combined, size (dist))
491- return UnwrapSingletonTransform (sz) ∘ f_combined
457+ # This is a performance optimisation
458+ return Only () ∘ invlink_transform (dist)
492459end
493460function from_linked_vec_transform (dist:: Distributions.ProductNamedTupleDistribution )
494461 return invlink_transform (dist)
0 commit comments