Skip to content

Commit f7e2c25

Browse files
committed
Merge branch 'master' of github.com:marius311/CMBLensing.jl
2 parents 395883f + ca8936d commit f7e2c25

File tree

5 files changed

+15
-7
lines changed

5 files changed

+15
-7
lines changed

src/field_tuples.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ getindex(D::DiagOp{<:FieldTuple}, i::Int, j::Int) = (i==j) ? D.diag[:][i] : diag
4242

4343
### array interface
4444
size(f::FieldTuple) = (sum(map(length, f.fs)),)
45-
copyto!(dest::FT, src::FT) where {FT<:FieldTuple} = (map(copyto!,dest.fs,src.fs); dest)
45+
copyto!(dest::FieldTuple{B}, src::FieldTuple{B}) where {B} = (map(copyto!,dest.fs,src.fs); dest)
4646
iterate(ft::FieldTuple, args...) = iterate(ft.fs, args...)
4747
getindex(f::FieldTuple, i::Union{Int,UnitRange}) = getindex(f.fs, i)
4848
fill!(ft::FieldTuple, x) = (map(f->fill!(f,x), ft.fs); ft)

src/flat_s0.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function similar(f::F,::Type{T},dims::Dims) where {P,F<:FlatS0{P},T<:Number}
5858
@assert size(f)==dims "Tried to make a field similar to $F but dims should have been $(size(f)), not $dims."
5959
basetype(F){P}(similar(firstfield(f),T))
6060
end
61+
copyto!(dst::Field{B,S0,P}, src::Field{B,S0,P}) where {B,P} = (copyto!(firstfield(dst),firstfield(src)); dst)
6162

6263

6364
### broadcasting

src/gpu.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,6 @@ BroadcastStyle(::FlatS0Style{F,Array}, ::FlatS0Style{F,CuArray}) where {P,F<:Fla
3838
# the generic versions of these trigger scalar indexing of CUDA, so provide
3939
# specialized versions:
4040

41-
function copyto!(dst::F, src::F) where {F<:CuFlatS0}
42-
copyto!(firstfield(dst),firstfield(src))
43-
dst
44-
end
4541
pinv(D::Diagonal{T,<:CuFlatS0}) where {T} = Diagonal(@. ifelse(isfinite(inv(D.diag)), inv(D.diag), $zero(T)))
4642
inv(D::Diagonal{T,<:CuFlatS0}) where {T} = any(Array((D.diag.==0)[:])) ? throw(SingularException(-1)) : Diagonal(inv.(D.diag))
4743
fill!(f::CuFlatS0, x) = (fill!(firstfield(f),x); f)

src/posterior.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function δlnP_δϕ(
134134
end
135135

136136
# gQD for the real data
137-
gQD = get_gQD(Lϕ, ds, f_wf_guess)
137+
gQD_future = @spawnat :any get_gQD(Lϕ, ds, f_wf_guess)
138138

139139
# gQD for several simulated datasets
140140
if use_previous_MF
@@ -149,6 +149,7 @@ function δlnP_δϕ(
149149
end
150150

151151
# final total posterior gradient, including gradient of the prior
152+
gQD = fetch(gQD_future)
152153
g = gQD.g --\ϕ
153154

154155
if return_state

src/specialops.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,17 @@ function (L::ParamDependentOp)(θ::NamedTuple)
218218
# filtering out non-dependent parameters disabled until I can find a fix to:
219219
# https://discourse.julialang.org/t/can-zygote-do-derivatives-w-r-t-keyword-arguments-which-get-captured-in-kwargs/34553/8
220220
# dependent_θ = filter(((k,_),)->k in L.parameters, pairs(θ))
221-
L.recompute_function(;θ...)
221+
= L.recompute_function(;θ...)
222+
ifisa typeof(L.op)
223+
224+
else
225+
# if L got adapt'ed to CuArray since this op was created,
226+
# L.op will be GPU-backed, but depending on how
227+
# recompute_function is written, recompute_function may
228+
# still return something CPU-backed. in that case, copy it
229+
# to GPU here
230+
copyto!(similar(L.op), Lθ)
231+
end
222232
else
223233
L.op
224234
end

0 commit comments

Comments
 (0)