Skip to content

Commit 9df6548

Browse files
committed
allow more than 31 bandpowers on GPU
1 parent 2c67610 commit 9df6548

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

src/CMBLensing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module CMBLensing
44
using Adapt
55
using Base.Broadcast: AbstractArrayStyle, ArrayStyle, Broadcasted,
66
DefaultArrayStyle, preprocess_args, Style, result_style, Unknown
7-
using Base.Iterators: flatten, product, repeated, cycle, countfrom, peel
7+
using Base.Iterators: flatten, product, repeated, cycle, countfrom, peel, partition
88
using Base.Threads
99
using Base: @kwdef, @propagate_inbounds, Bottom, OneTo, showarg, show_datatype,
1010
show_default, show_vector, typed_vcat, typename

src/flat_fields.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert{T}, (Cℓ, ℓedges, θname)::
404404
@eval Main let ℓedges=$((T.(ℓedges))...,), C₀=$C₀
405405
ParamDependentOp(function (;$θname=ones($T,length(ℓedges)-1),_...)
406406
_A = $preprocess.(Ref((nothing,C₀.metadata)), $T.($ensure1d($θname)))
407-
Diagonal(FlatFourier($bandpower_rescale(ℓedges, C₀.ℓmag, C₀.arr, _A...), C₀.metadata))
407+
Diagonal(FlatFourier($bandpower_rescale!(ℓedges, C₀.ℓmag, copy(C₀.arr), _A...), C₀.metadata))
408408
end)
409409
end
410410
end
@@ -414,24 +414,35 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert{T}, (CℓEE, ℓedges, θname)
414414
ParamDependentOp(function (;$θname=ones($T,length(ℓedges)-1),_...)
415415
_E = $preprocess.(Ref((nothing,C₀.metadata)), $T.($ensure1d($θname)))
416416
_B = $preprocess.(Ref((nothing,C₀.metadata)), one.($T.($ensure1d($θname))))
417-
Diagonal(FlatEBFourier($bandpower_rescale(ℓedges, C₀.ℓmag, C₀.El, _E...), C₀.Bl .* _B[1], C₀.metadata))
417+
Diagonal(FlatEBFourier($bandpower_rescale!(ℓedges, C₀.ℓmag, copy(C₀.El), _E...), C₀.Bl .* _B[1], C₀.metadata))
418418
end)
419419
end
420420
end
421-
# cant reliably get Zygote's gradients to work through these
422-
# broadcasts, which on GPU use ForwardDiff, so write the adjoint by
423-
# hand for now. likely more performant, in any case.
424-
function bandpower_rescale(ℓedges, ℓ, Cℓ, A...)
425-
length(A)==length(ℓedges)-1 || error("Expected $(length(ℓedges)-1) bandpower parameters, not $(length(A)).")
421+
# this is written weird because the stuff inside the broadcast! needs
422+
# to work as a GPU kernel
423+
function bandpower_rescale!(ℓedges, ℓ, Cℓ, A...)
424+
length(A)==length(ℓedges)-1 || error("Expected $(length(ℓedges)-1) bandpower parameters, got $(length(A)).")
426425
eltype(A[1]) <: Real || error("Bandpower parameters must be real numbers.")
427-
broadcast(ℓ, Cℓ, A...) do ℓ, Cℓ, A...
428-
for i=1:length(ℓedges)-1
429-
(ℓedges[i] << ℓedges[i+1]) && return A[i] * Cℓ
426+
if length(A)>30
427+
# if more than 30 bandpowers, we need to chunk the rescaling
428+
# because of a maximum argument limit of CUDA kernels
429+
for p in partition(1:length(A), 30)
430+
bandpower_rescale!(ℓedges[p.start:(p.stop+1)], ℓ, Cℓ, A[p]...)
431+
end
432+
else
433+
broadcast!(Cℓ, ℓ, Cℓ, A...) do ℓ, Cℓ, A...
434+
for i=1:length(ℓedges)-1
435+
(ℓedges[i] << ℓedges[i+1]) && return A[i] * Cℓ
436+
end
437+
return Cℓ
430438
end
431-
return Cℓ
432439
end
440+
Cℓ
433441
end
434-
@adjoint function bandpower_rescale(ℓedges, ℓ, Cℓ, A...)
442+
# cant reliably get Zygote's gradients to work through these
443+
# broadcasts, which on GPU use ForwardDiff, so write the adjoint by
444+
# hand for now. likely more performant, in any case.
445+
@adjoint function bandpower_rescale!(ℓedges, ℓ, Cℓ, A...)
435446
function back(Δ)
436447
= map(1:length(A)) do i
437448
sum(
@@ -444,7 +455,7 @@ end
444455
end
445456
(nothing, nothing, nothing, Ā...)
446457
end
447-
bandpower_rescale(ℓedges, ℓ, Cℓ, A...), back
458+
bandpower_rescale!(ℓedges, ℓ, Cℓ, A...), back
448459
end
449460

450461

0 commit comments

Comments
 (0)