Skip to content

Commit 9e53c23

Browse files
committed
add batched transforms, include real or complex transform in cfg type
1 parent 7158ae8 commit 9e53c23

File tree

8 files changed

+228
-157
lines changed

8 files changed

+228
-157
lines changed

ext/SHTnsCUDAExt/analys.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,83 @@
1-
function analys(cfg::SHTnsCfg, v::CuMatrix{Float64})
1+
function analys(cfg::SHTnsCfg, v::CuArray{Float64})
22
@assert cfg.shtype.gpu
33
@assert cfg.nlat != 0
4-
qlm = CuVector{ComplexF64}(undef, cfg.nlm)
4+
qlm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
55
analys!(cfg, copy(v), qlm)
66
return qlm
77
end
88

9-
function analys(cfg::SHTnsCfg, v::CuMatrix{ComplexF64})
9+
function analys(cfg::SHTnsCfg, v::CuArray{ComplexF64})
1010
@assert cfg.shtype.gpu
1111
@assert cfg.nlat != 0
12-
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
12+
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
1313
analys!(cfg, copy(v), qlm)
1414
return qlm
1515
end
1616

17-
function analys(cfg::SHTnsCfg, utheta::CuMatrix{Float64}, uphi::CuMatrix{Float64})
17+
function analys(cfg::SHTnsCfg, utheta::CuArray{Float64}, uphi::CuArray{Float64})
1818
@assert cfg.shtype.gpu
1919
@assert cfg.nlat != 0
20-
slm = CuVector{ComplexF64}(undef, cfg.nlm)
21-
tlm = CuVector{ComplexF64}(undef, cfg.nlm)
20+
slm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
21+
tlm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
2222
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
2323
return slm, tlm
2424
end
2525

26-
function analys(cfg::SHTnsCfg, utheta::CuMatrix{ComplexF64}, uphi::CuMatrix{ComplexF64})
26+
function analys(cfg::SHTnsCfg, utheta::CuArray{ComplexF64}, uphi::CuArray{ComplexF64})
2727
@assert cfg.shtype.gpu
2828
@assert cfg.nlat != 0
29-
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
30-
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
29+
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
30+
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
3131
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
3232
return slm, tlm
3333
end
3434

35-
function analys(cfg::SHTnsCfg, ur::CuMatrix{Float64}, utheta::CuMatrix{Float64}, uphi::CuMatrix{Float64})
35+
function analys(cfg::SHTnsCfg, ur::CuArray{Float64}, utheta::CuArray{Float64}, uphi::CuArray{Float64})
3636
@assert cfg.shtype.gpu
3737
@assert cfg.nlat != 0
38-
qlm = CuVector{ComplexF64}(undef, cfg.nlm)
39-
slm = CuVector{ComplexF64}(undef, cfg.nlm)
40-
tlm = CuVector{ComplexF64}(undef, cfg.nlm)
38+
qlm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
39+
slm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
40+
tlm = CuVector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
4141
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
4242
return qlm, slm, tlm
4343
end
4444

45-
function analys(cfg::SHTnsCfg, ur::CuMatrix{ComplexF64}, utheta::CuMatrix{ComplexF64}, uphi::CuMatrix{ComplexF64})
45+
function analys(cfg::SHTnsCfg, ur::CuArray{ComplexF64}, utheta::CuArray{ComplexF64}, uphi::CuArray{ComplexF64})
4646
@assert cfg.shtype.gpu
4747
@assert cfg.nlat != 0
48-
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
49-
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
50-
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
48+
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
49+
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
50+
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
5151
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
5252
return qlm, slm, tlm
5353
end
5454

55-
function analys!(cfg::SHTnsCfg, v::CuMatrix{Float64}, qlm::CuVector{ComplexF64})
55+
function analys!(cfg::SHTnsCfg, v::CuArray{Float64}, qlm::CuVector{ComplexF64})
5656
@assert cfg.shtype.gpu
5757
return cu_spat_to_SH(cfg.cfg, v, qlm, cfg.lmax)
5858
end
5959

6060

61-
function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{Float64}}
61+
function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuArray{Float64}}
6262
@assert cfg.shtype.gpu
6363
return cu_spat_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm, cfg.lmax)
6464
end
6565

66-
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{Float64}}
66+
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuArray{Float64}}
6767
@assert cfg.shtype.gpu
6868
return cu_spat_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm, cfg.lmax)
6969
end
7070

7171
#complex to complex not available for CUDA (status: SHTns v3.7)
7272

73-
# function analys!(cfg::SHTnsCfg, v::CuMatrix{ComplexF64}, qlm::CuVector{ComplexF64})
73+
# function analys!(cfg::SHTnsCfg, v::CuArray{ComplexF64}, qlm::CuVector{ComplexF64})
7474
# return cu_spat_cplx_to_SH(cfg.cfg, v, qlm)
7575
# end
7676

77-
# function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{ComplexF64}}
77+
# function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuArray{ComplexF64}}
7878
# return cu_spat_cplx_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm)
7979
# end
8080

81-
# function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{ComplexF64}}
81+
# function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuArray{ComplexF64}}
8282
# return cu_spat_cplx_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
8383
# end

ext/SHTnsCUDAExt/synth.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,61 @@
1-
function synth(cfg::SHTnsCfg, qlm::CuVector{ComplexF64})
1+
function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}) where {TR,T,N}
2+
Tv = TR == Real ? Float64 : ComplexF64
23
@assert cfg.shtype.gpu
34
@assert cfg.nlat != 0
4-
@assert length(qlm) == cfg.nlm
5+
@assert length(qlm) == nlm(cfg)*cfg.howmany
56

67
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
78
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
89

9-
v = CuMatrix{Float64}(undef, nx, ny)
10+
v = CuMatrix{Tv}(undef, nx, ny)
1011
synth!(cfg, qlm, v)
1112
return v
1213
end
1314

14-
function synth(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
15+
function synth(cfg::SHTnsCfg{TR,T,N}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {TR,T,N}
16+
Tv = TR == Real ? Float64 : ComplexF64
1517
@assert cfg.shtype.gpu
1618
@assert cfg.nlat != 0
17-
@assert length(slm) == length(tlm) == cfg.nlm
19+
@assert length(slm) == length(tlm) == nlm(cfg)*cfg.howmany
1820

1921
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
2022
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
2123

22-
utheta = CuMatrix{Float64}(undef, nx, ny)
23-
uphi = CuMatrix{Float64}(undef, nx, ny)
24+
utheta = CuMatrix{Tv}(undef, nx, ny)
25+
uphi = CuMatrix{Tv}(undef, nx, ny)
2426
synth!(cfg, slm, tlm, utheta, uphi)
2527
return utheta, uphi
2628
end
2729

28-
function synth(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
30+
function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {TR,T,N}
31+
Tv = TR == Real ? Float64 : ComplexF64
2932
@assert cfg.shtype.gpu
3033
@assert cfg.nlat != 0
31-
@assert length(qlm) == length(slm) == length(tlm) == cfg.nlm
34+
@assert length(qlm) == length(slm) == length(tlm) == nlm(cfg)*cfg.howmany
3235
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
3336
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
3437

35-
ur = CuMatrix{Float64}(undef, nx, ny)
36-
utheta = CuMatrix{Float64}(undef, nx, ny)
37-
uphi = CuMatrix{Float64}(undef, nx, ny)
38+
ur = CuMatrix{Tv}(undef, nx, ny)
39+
utheta = CuMatrix{Tv}(undef, nx, ny)
40+
uphi = CuMatrix{Tv}(undef, nx, ny)
3841
synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
3942
return ur, utheta, uphi
4043
end
4144

4245

43-
function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, v::CuMatrix{Float64})
46+
function synth!(cfg::SHTnsCfg{Real,T,N}, qlm::CuVector{ComplexF64}, v::CuMatrix{Float64}) where {T,N}
4447
@assert cfg.shtype.gpu
4548
cu_SH_to_spat(cfg.cfg, qlm, v, cfg.lmax)
4649
return v
4750
end
4851

49-
function synth!(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, utheta::T, uphi::T) where {T<:CuMatrix{Float64}}
52+
function synth!(cfg::SHTnsCfg{Real,T,N}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, utheta::Tv, uphi::Tv) where {T,N,Tv<:CuMatrix{Float64}}
5053
@assert cfg.shtype.gpu
5154
cu_SHsphtor_to_spat(cfg.cfg, slm, tlm, utheta, uphi, cfg.lmax)
5255
return utheta, uphi
5356
end
5457

55-
function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::T, utheta::T, uphi::T) where {T<:CuMatrix{Float64}}
58+
function synth!(cfg::SHTnsCfg{Real,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::Tv, utheta::Tv, uphi::Tv) where {T,N,Tv<:CuMatrix{Float64}}
5659
@assert cfg.shtype.gpu
5760
cu_SHqst_to_spat(cfg.cfg, qlm, slm, tlm, ur, utheta, uphi, cfg.lmax)
5861
return ur, utheta, uphi

src/SHTns.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ end
121121
122122
Configuration of spherical harmonic transform.
123123
"""
124-
mutable struct SHTnsCfg{N<:SHTnsNorm, T<:SHTnsType}
124+
mutable struct SHTnsCfg{TR<:Union{Real,Complex}, N<:SHTnsNorm, T<:SHTnsType}
125125
cfg::Ptr{shtns_info}
126126
norm::N
127127
shtype::T
@@ -140,23 +140,32 @@ mutable struct SHTnsCfg{N<:SHTnsNorm, T<:SHTnsType}
140140
st::Vector{Float64}
141141
nlat_padded::Int
142142
nlm_cplx::Int
143+
howmany::Int
143144
function SHTnsCfg(lmax, mmax, mres, nlat, nphi;
144145
shtype::T=QuickInit(),
145146
norm::N=Orthonormal(),
146147
eps=1e-10,
147148
robert_form=false,
149+
howmany = 1,
150+
transform::Union{Type{Real}, Type{Complex}} = Real
148151
) where {T<:SHTnsType, N<:SHTnsNorm}
149152

150153
_init_checks(shtype, lmax, mmax, mres, nlat, nphi)
151154
cfg = shtns_create(lmax, mmax, mres, norm)
152155
robert_form && shtns_robert_form(cfg,1)
156+
if howmany > 1
157+
@assert transform == Real "Only real transform is supported for batched transforms"
158+
info = unsafe_load(cfg)
159+
spec_dist = transform == Real ? info.nlm : info.nlm_cplx
160+
shtns_set_many(cfg, howmany, spec_dist)
161+
end
153162
shtns_set_grid(cfg, shtype, eps, nlat, nphi)
154163
info = unsafe_load(cfg)
155164
li = Vector{Int}(unsafe_wrap(Vector{Cushort},info.li,Int(info.nlm)))
156165
mi = Vector{Int}(unsafe_wrap(Vector{Cushort},info.mi,Int(info.nlm)))
157166
ct = Vector{Float64}(unsafe_wrap(Vector{Cdouble},info.ct,Int(info.nlat)))
158167
st = Vector{Float64}(unsafe_wrap(Vector{Cdouble},info.st,Int(info.nlat)))
159-
stream = new{N,T}(cfg, norm, shtype, robert_form, info.nlm, info.lmax, info.mmax, info.mres, info.nlat_2, info.nlat, info.nphi, info.nspat, li, mi, ct, st, info.nlat_padded, info.nlm_cplx)
168+
stream = new{transform,N,T}(cfg, norm, shtype, robert_form, info.nlm, info.lmax, info.mmax, info.mres, info.nlat_2, info.nlat, info.nphi, info.nspat, li, mi, ct, st, info.nlat_padded, info.nlm_cplx, howmany)
160169
finalizer(x->shtns_destroy(x.cfg), stream)
161170
return stream
162171
end
@@ -166,6 +175,8 @@ mutable struct SHTnsCfg{N<:SHTnsNorm, T<:SHTnsType}
166175
eps=1e-10,
167176
robert_form=false,
168177
nl_order = 0,
178+
howmany = 1,
179+
transform::Union{Type{Real}, Type{Complex}} = Real
169180
) where {T<:SHTnsType, N<:SHTnsNorm}
170181

171182
@assert lmax > 1
@@ -175,13 +186,19 @@ mutable struct SHTnsCfg{N<:SHTnsNorm, T<:SHTnsType}
175186
cfg = shtns_create(lmax, mmax, mres, norm)
176187
robert_form && shtns_robert_form(cfg,1)
177188
info = unsafe_load(cfg)
189+
if howmany > 1
190+
@assert transform == Real "Only real transform is supported for batched transforms"
191+
info = unsafe_load(cfg)
192+
spec_dist = transform == Real ? info.nlm : info.nlm_cplx
193+
shtns_set_many(cfg, howmany, spec_dist)
194+
end
178195
shtns_set_grid_auto(cfg, shtype, eps, nl_order, Ref(info.nlat), Ref(info.nphi))
179196
info = unsafe_load(cfg)
180197
li = Vector{Int}(unsafe_wrap(Vector{Cushort},info.li,Int(info.nlm)))
181198
mi = Vector{Int}(unsafe_wrap(Vector{Cushort},info.mi,Int(info.nlm)))
182199
ct = Vector{Float64}(unsafe_wrap(Vector{Cdouble},info.ct,Int(info.nlat)))
183200
st = Vector{Float64}(unsafe_wrap(Vector{Cdouble},info.st,Int(info.nlat)))
184-
stream = new{N,T}(cfg, norm, shtype, robert_form, info.nlm, info.lmax, info.mmax, info.mres, info.nlat_2, info.nlat, info.nphi, info.nspat, li, mi, ct, st, info.nlat_padded, info.nlm_cplx)
201+
stream = new{transform,N,T}(cfg, norm, shtype, robert_form, info.nlm, info.lmax, info.mmax, info.mres, info.nlat_2, info.nlat, info.nphi, info.nspat, li, mi, ct, st, info.nlat_padded, info.nlm_cplx, howmany)
185202
finalizer(x->shtns_destroy(x.cfg), stream)
186203
return stream
187204
end

src/analys.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,72 @@
1-
function analys(cfg::SHTnsCfg, v::Matrix{Float64})
1+
function analys(cfg::SHTnsCfg{Real,T,N}, v::Array{Float64}) where {T,N}
22
@assert cfg.nlat != 0
3-
qlm = Vector{ComplexF64}(undef, cfg.nlm)
3+
qlm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
44
analys!(cfg, copy(v), qlm)
55
return qlm
66
end
77

8-
function analys(cfg::SHTnsCfg, v::Matrix{ComplexF64})
8+
function analys(cfg::SHTnsCfg{Complex,T,N}, v::Array{ComplexF64}) where {T,N}
99
@assert cfg.nlat != 0
10-
qlm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
10+
qlm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
1111
analys!(cfg, copy(v), qlm)
1212
return qlm
1313
end
1414

15-
function analys(cfg::SHTnsCfg, utheta::Matrix{Float64}, uphi::Matrix{Float64})
15+
function analys(cfg::SHTnsCfg{Real,T,N}, utheta::Array{Float64}, uphi::Array{Float64}) where {T,N}
1616
@assert cfg.nlat != 0
17-
slm = Vector{ComplexF64}(undef, cfg.nlm)
18-
tlm = Vector{ComplexF64}(undef, cfg.nlm)
17+
slm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
18+
tlm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
1919
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
2020
return slm, tlm
2121
end
2222

23-
function analys(cfg::SHTnsCfg, utheta::Matrix{ComplexF64}, uphi::Matrix{ComplexF64})
23+
function analys(cfg::SHTnsCfg{Complex,T,N}, utheta::Array{ComplexF64}, uphi::Array{ComplexF64}) where {T,N}
2424
@assert cfg.nlat != 0
25-
slm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
26-
tlm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
25+
slm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
26+
tlm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
2727
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
2828
return slm, tlm
2929
end
3030

31-
function analys(cfg::SHTnsCfg, ur::Matrix{Float64}, utheta::Matrix{Float64}, uphi::Matrix{Float64})
31+
function analys(cfg::SHTnsCfg{Real,T,N}, ur::Array{Float64}, utheta::Array{Float64}, uphi::Array{Float64}) where {T,N}
3232
@assert cfg.nlat != 0
33-
qlm = Vector{ComplexF64}(undef, cfg.nlm)
34-
slm = Vector{ComplexF64}(undef, cfg.nlm)
35-
tlm = Vector{ComplexF64}(undef, cfg.nlm)
33+
qlm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
34+
slm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
35+
tlm = Vector{ComplexF64}(undef, cfg.nlm*cfg.howmany)
3636
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
3737
return qlm, slm, tlm
3838
end
3939

40-
function analys(cfg::SHTnsCfg, ur::Matrix{ComplexF64}, utheta::Matrix{ComplexF64}, uphi::Matrix{ComplexF64})
40+
function analys(cfg::SHTnsCfg{Complex,T,N}, ur::Array{ComplexF64}, utheta::Array{ComplexF64}, uphi::Array{ComplexF64}) where {T,N}
4141
@assert cfg.nlat != 0
42-
qlm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
43-
slm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
44-
tlm = Vector{ComplexF64}(undef, cfg.nlm_cplx)
42+
qlm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
43+
slm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
44+
tlm = Vector{ComplexF64}(undef, cfg.nlm_cplx*cfg.howmany)
4545
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
4646
return qlm, slm, tlm
4747
end
4848

49-
function analys!(cfg::SHTnsCfg, v::AbstractMatrix{Float64}, qlm)
49+
function analys!(cfg::SHTnsCfg{Real,T,N}, v::Array{Float64}, qlm) where {T,N}
5050
return spat_to_SH(cfg.cfg, v, qlm)
5151
end
5252

53-
function analys!(cfg::SHTnsCfg, v::AbstractMatrix{ComplexF64}, qlm)
53+
function analys!(cfg::SHTnsCfg{Complex,T,N}, v::Array{ComplexF64}, qlm) where {T,N}
5454
return spat_cplx_to_SH(cfg.cfg, v, qlm)
5555
end
5656

57-
function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm, tlm) where {T<:AbstractMatrix{Float64}}
57+
function analys!(cfg::SHTnsCfg{Real,T,N}, utheta::Tv, uphi::Tv, slm, tlm) where {T,N,Tv<:Array{Float64}}
5858
return spat_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm)
5959
end
6060

61-
function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm, tlm) where {T<:AbstractMatrix{ComplexF64}}
61+
function analys!(cfg::SHTnsCfg{Complex,T,N}, utheta::Tv, uphi::Tv, slm, tlm) where {T,N,Tv<:Array{ComplexF64}}
6262
return spat_cplx_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm)
6363
end
6464

65-
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm, slm, tlm) where {T<:AbstractMatrix{Float64}}
65+
function analys!(cfg::SHTnsCfg{Real,T,N}, ur::Tv, utheta::Tv, uphi::Tv, qlm, slm, tlm) where {T,N,Tv<:Array{Float64}}
6666
return spat_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
6767
end
6868

69-
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm, slm, tlm) where {T<:AbstractMatrix{ComplexF64}}
69+
function analys!(cfg::SHTnsCfg{Complex,T,N}, ur::Tv, utheta::Tv, uphi::Tv, qlm, slm, tlm) where {T,N,Tv<:Array{ComplexF64}}
7070
return spat_cplx_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
7171
end
7272

src/sht.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,7 @@ function SHqst_to_lat(cfg, Qlm, Slm, Tlm, cost, vr, vt, vp, nphi, ltr, mtr)
308308
ccall((:SHqst_to_lat, libshtns[]), Nothing, (shtns_cfg, Ptr{ComplexF64}, Ptr{ComplexF64}, Ptr{ComplexF64}, Float64, Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Int32, Int32, Int32), cfg, Qlm, Slm, Tlm, cost, vr, vt, vp, nphi, ltr, mtr)
309309
end
310310

311+
function shtns_set_many(shtns, howmany, spec_dist)
312+
hm = ccall((:shtns_set_many, libshtns[]), Int32, (shtns_cfg, Int32, Int64), shtns, howmany, spec_dist)
313+
@assert hm == howmany
314+
end

0 commit comments

Comments
 (0)