Skip to content

Commit 34924cb

Browse files
committed
transform urandn in input
rather than in body of sample_ζresid_norm
1 parent a76b5eb commit 34924cb

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/elbo.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@ function sample_ζresid_norm(rng::Random.AbstractRNG, ζP::AbstractVector, ζMs:
393393
# get_concrete(ComponentArrayInterpreter(
394394
# P = (n_MC, n_θP), Ms = (n_MC, n_θM, n_site_batch)))
395395
# end
396-
urandn = _create_randn(rng, CA.getdata(ζP), n_MC, n_θP + n_θMs)
396+
#urandn = _create_randn(rng, CA.getdata(ζP), n_MC, n_θP + n_θMs)
397+
urandn = _create_randn(rng, CA.getdata(ζP), n_θP + n_θMs, n_MC)
397398
sample_ζresid_norm(urandn, CA.getdata(ζP), CA.getdata(ζMs), args...;
398399
cor_ends, int_unc=get_concrete(int_unc))
399400
end
@@ -405,7 +406,6 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
405406
) where {T,TP<:AbstractVector{T},TM<:AbstractMatrix{T}}
406407
ϕuncc = int_unc(CA.getdata(ϕunc))
407408
n_θP, n_θMs, (n_θM, n_batch) = length(ζP), length(ζMs), size(ζMs)
408-
n_MC = size(urandn, 1) # TODO transform urandn
409409
# do not create a UpperTriangular Matrix of an AbstractGÜUArray in transformU_cholesky1
410410
ρsP = isempty(ϕuncc.ρsP) ? similar(ϕuncc.ρsP) : ϕuncc.ρsP # required by zygote
411411
UP = transformU_block_cholesky1(ρsP, cor_ends.P)
@@ -421,8 +421,9 @@ function sample_ζresid_norm(urandn::AbstractMatrix, ζP::TP, ζMs::TM,
421421
# need to construct full matrix for CUDA
422422
= _create_blockdiag(UP, UM, σP, σMs, n_batch)
423423
σ = diag(Uσ) # elements of the diagonal: standard deviations
424-
ζ_resids_parfirst =' * urandn' # n_par x n_MC
425-
#ζ_resids_parfirst = urandn * Uσ # n_MC x n_par
424+
n_MC = size(urandn, 2) # TODO transform urandn
425+
ζ_resids_parfirst =' * urandn # n_par x n_MC
426+
#ζ_resids_parfirst = urandn' *# n_MC x n_par
426427
ζP_resids = ζ_resids_parfirst[1:n_θP, :]
427428
ζMs_parfirst_resids = reshape(ζ_resids_parfirst[(n_θP+1):end, :], n_θM, n_batch, n_MC)
428429
ζP_resids, ζMs_parfirst_resids, σ

0 commit comments

Comments
 (0)