@@ -7,7 +7,7 @@ function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}) where {TR,T,N}
77 nx = cfg. shtype. contiguous_phi ? cfg. nphi : cfg. nlat_padded
88 ny = cfg. shtype. contiguous_phi ? cfg. nlat_padded : cfg. nphi
99
10- v = CuMatrix {Tv} (undef, nx, ny)
10+ v = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
1111 synth! (cfg, qlm, v)
1212 return v
1313end
@@ -21,8 +21,8 @@ function synth(cfg::SHTnsCfg{TR,T,N}, slm::CuVector{ComplexF64}, tlm::CuVector{C
2121 nx = cfg. shtype. contiguous_phi ? cfg. nphi : cfg. nlat_padded
2222 ny = cfg. shtype. contiguous_phi ? cfg. nlat_padded : cfg. nphi
2323
24- utheta = CuMatrix {Tv} (undef, nx, ny)
25- uphi = CuMatrix {Tv} (undef, nx, ny)
24+ utheta = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
25+ uphi = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
2626 synth! (cfg, slm, tlm, utheta, uphi)
2727 return utheta, uphi
2828end
@@ -35,27 +35,27 @@ function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{C
3535 nx = cfg. shtype. contiguous_phi ? cfg. nphi : cfg. nlat_padded
3636 ny = cfg. shtype. contiguous_phi ? cfg. nlat_padded : cfg. nphi
3737
38- ur = CuMatrix {Tv} (undef, nx, ny)
39- utheta = CuMatrix {Tv} (undef, nx, ny)
40- uphi = CuMatrix {Tv} (undef, nx, ny)
38+ ur = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
39+ utheta = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
40+ uphi = cfg . howmany > 1 ? CuArray {Tv} (undef, nx, ny, cfg . howmany) : CuMatrix {Tv} (undef, nx, ny)
4141 synth! (cfg, qlm, slm, tlm, ur, utheta, uphi)
4242 return ur, utheta, uphi
4343end
4444
4545
46- function synth! (cfg:: SHTnsCfg{Real,T,N} , qlm:: CuVector{ComplexF64} , v:: CuMatrix {Float64} ) where {T,N}
46+ function synth! (cfg:: SHTnsCfg{Real,T,N} , qlm:: CuVector{ComplexF64} , v:: CuArray {Float64} ) where {T,N}
4747 @assert cfg. shtype. gpu
4848 cu_SH_to_spat (cfg. cfg, qlm, v, cfg. lmax)
4949 return v
5050end
5151
52- function synth! (cfg:: SHTnsCfg{Real,T,N} , slm:: CuVector{ComplexF64} , tlm:: CuVector{ComplexF64} , utheta:: Tv , uphi:: Tv ) where {T,N,Tv<: CuMatrix {Float64} }
52+ function synth! (cfg:: SHTnsCfg{Real,T,N} , slm:: CuVector{ComplexF64} , tlm:: CuVector{ComplexF64} , utheta:: Tv , uphi:: Tv ) where {T,N,Tv<: CuArray {Float64} }
5353 @assert cfg. shtype. gpu
5454 cu_SHsphtor_to_spat (cfg. cfg, slm, tlm, utheta, uphi, cfg. lmax)
5555 return utheta, uphi
5656end
5757
58- function synth! (cfg:: SHTnsCfg{Real,T,N} , qlm:: CuVector{ComplexF64} , slm:: CuVector{ComplexF64} , tlm:: CuVector{ComplexF64} , ur:: Tv , utheta:: Tv , uphi:: Tv ) where {T,N,Tv<: CuMatrix {Float64} }
58+ function synth! (cfg:: SHTnsCfg{Real,T,N} , qlm:: CuVector{ComplexF64} , slm:: CuVector{ComplexF64} , tlm:: CuVector{ComplexF64} , ur:: Tv , utheta:: Tv , uphi:: Tv ) where {T,N,Tv<: CuArray {Float64} }
5959 @assert cfg. shtype. gpu
6060 cu_SHqst_to_spat (cfg. cfg, qlm, slm, tlm, ur, utheta, uphi, cfg. lmax)
6161 return ur, utheta, uphi
0 commit comments