@@ -404,7 +404,7 @@ function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert{T}, (Cℓ, ℓedges, θname)::
404404 @eval Main let ℓedges= $ ((T .(ℓedges)). .. ,), C₀= $ C₀
405405 ParamDependentOp (function (;$ θname= ones ($ T,length (ℓedges)- 1 ),_... )
406406 _A = $ preprocess .(Ref ((nothing ,C₀. metadata)), $ T .($ ensure1d ($ θname)))
407- Diagonal (FlatFourier ($ bandpower_rescale (ℓedges, C₀. ℓmag, C₀. arr, _A... ), C₀. metadata))
407+ Diagonal (FlatFourier ($ bandpower_rescale! (ℓedges, C₀. ℓmag, copy ( C₀. arr) , _A... ), C₀. metadata))
408408 end )
409409 end
410410end
@@ -414,24 +414,35 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert{T}, (CℓEE, ℓedges, θname)
414414 ParamDependentOp (function (;$ θname= ones ($ T,length (ℓedges)- 1 ),_... )
415415 _E = $ preprocess .(Ref ((nothing ,C₀. metadata)), $ T .($ ensure1d ($ θname)))
416416 _B = $ preprocess .(Ref ((nothing ,C₀. metadata)), one .($ T .($ ensure1d ($ θname))))
417- Diagonal (FlatEBFourier ($ bandpower_rescale (ℓedges, C₀. ℓmag, C₀. El, _E... ), C₀. Bl .* _B[1 ], C₀. metadata))
417+ Diagonal (FlatEBFourier ($ bandpower_rescale! (ℓedges, C₀. ℓmag, copy ( C₀. El) , _E... ), C₀. Bl .* _B[1 ], C₀. metadata))
418418 end )
419419 end
420420end
421- # cant reliably get Zygote's gradients to work through these
422- # broadcasts, which on GPU use ForwardDiff, so write the adjoint by
423- # hand for now. likely more performant, in any case.
424- function bandpower_rescale (ℓedges, ℓ, Cℓ, A... )
425- length (A)== length (ℓedges)- 1 || error (" Expected $(length (ℓedges)- 1 ) bandpower parameters, not $(length (A)) ." )
421+ # this is written weird because the stuff inside the broadcast! needs
422+ # to work as a GPU kernel
423+ function bandpower_rescale! (ℓedges, ℓ, Cℓ, A... )
424+ length (A)== length (ℓedges)- 1 || error (" Expected $(length (ℓedges)- 1 ) bandpower parameters, got $(length (A)) ." )
426425 eltype (A[1 ]) <: Real || error (" Bandpower parameters must be real numbers." )
427- broadcast (ℓ, Cℓ, A... ) do ℓ, Cℓ, A...
428- for i= 1 : length (ℓedges)- 1
429- (ℓedges[i] < ℓ < ℓedges[i+ 1 ]) && return A[i] * Cℓ
426+ if length (A)> 30
427+ # if more than 30 bandpowers, we need to chunk the rescaling
428+ # because of a maximum argument limit of CUDA kernels
429+ for p in partition (1 : length (A), 30 )
430+ bandpower_rescale! (ℓedges[p. start: (p. stop+ 1 )], ℓ, Cℓ, A[p]. .. )
431+ end
432+ else
433+ broadcast! (Cℓ, ℓ, Cℓ, A... ) do ℓ, Cℓ, A...
434+ for i= 1 : length (ℓedges)- 1
435+ (ℓedges[i] < ℓ < ℓedges[i+ 1 ]) && return A[i] * Cℓ
436+ end
437+ return Cℓ
430438 end
431- return Cℓ
432439 end
440+ Cℓ
433441end
434- @adjoint function bandpower_rescale (ℓedges, ℓ, Cℓ, A... )
442+ # cant reliably get Zygote's gradients to work through these
443+ # broadcasts, which on GPU use ForwardDiff, so write the adjoint by
444+ # hand for now. likely more performant, in any case.
445+ @adjoint function bandpower_rescale! (ℓedges, ℓ, Cℓ, A... )
435446 function back (Δ)
436447 Ā = map (1 : length (A)) do i
437448 sum (
444455 end
445456 (nothing , nothing , nothing , Ā... )
446457 end
447- bandpower_rescale (ℓedges, ℓ, Cℓ, A... ), back
458+ bandpower_rescale! (ℓedges, ℓ, Cℓ, A... ), back
448459end
449460
450461
0 commit comments