Skip to content

Commit 7158ae8

Browse files
committed
add SHTnsCUDAExt and working wrappers
1 parent 6eb98ca commit 7158ae8

File tree

6 files changed

+254
-2
lines changed

6 files changed

+254
-2
lines changed

Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ version = "0.2.0"
66
[deps]
77
SHTns_jll = "daf09cc5-9ab3-509e-9618-0b89086eb825"
88

9+
[weakdeps]
10+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
11+
12+
[extensions]
13+
SHTnsCUDAExt = "CUDA"
14+
915
[compat]
1016
julia = "1.6"
1117

ext/SHTnsCUDAExt/SHTnsCUDAExt.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module SHTnsCUDAExt
2+
3+
using CUDA
4+
using SHTns
5+
6+
import SHTns: libshtns
7+
import SHTns: synth, synth!, analys, analys!
8+
9+
__init__() = @assert CUDA.functional()
10+
11+
include("sht.jl")
12+
include("synth.jl")
13+
include("analys.jl")
14+
15+
end #module

ext/SHTnsCUDAExt/analys.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
function analys(cfg::SHTnsCfg, v::CuMatrix{Float64})
2+
@assert cfg.shtype.gpu
3+
@assert cfg.nlat != 0
4+
qlm = CuVector{ComplexF64}(undef, cfg.nlm)
5+
analys!(cfg, copy(v), qlm)
6+
return qlm
7+
end
8+
9+
function analys(cfg::SHTnsCfg, v::CuMatrix{ComplexF64})
10+
@assert cfg.shtype.gpu
11+
@assert cfg.nlat != 0
12+
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
13+
analys!(cfg, copy(v), qlm)
14+
return qlm
15+
end
16+
17+
function analys(cfg::SHTnsCfg, utheta::CuMatrix{Float64}, uphi::CuMatrix{Float64})
18+
@assert cfg.shtype.gpu
19+
@assert cfg.nlat != 0
20+
slm = CuVector{ComplexF64}(undef, cfg.nlm)
21+
tlm = CuVector{ComplexF64}(undef, cfg.nlm)
22+
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
23+
return slm, tlm
24+
end
25+
26+
function analys(cfg::SHTnsCfg, utheta::CuMatrix{ComplexF64}, uphi::CuMatrix{ComplexF64})
27+
@assert cfg.shtype.gpu
28+
@assert cfg.nlat != 0
29+
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
30+
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
31+
analys!(cfg, copy(utheta), copy(uphi), slm, tlm)
32+
return slm, tlm
33+
end
34+
35+
function analys(cfg::SHTnsCfg, ur::CuMatrix{Float64}, utheta::CuMatrix{Float64}, uphi::CuMatrix{Float64})
36+
@assert cfg.shtype.gpu
37+
@assert cfg.nlat != 0
38+
qlm = CuVector{ComplexF64}(undef, cfg.nlm)
39+
slm = CuVector{ComplexF64}(undef, cfg.nlm)
40+
tlm = CuVector{ComplexF64}(undef, cfg.nlm)
41+
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
42+
return qlm, slm, tlm
43+
end
44+
45+
function analys(cfg::SHTnsCfg, ur::CuMatrix{ComplexF64}, utheta::CuMatrix{ComplexF64}, uphi::CuMatrix{ComplexF64})
46+
@assert cfg.shtype.gpu
47+
@assert cfg.nlat != 0
48+
qlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
49+
slm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
50+
tlm = CuVector{ComplexF64}(undef, cfg.nlm_cplx)
51+
analys!(cfg, copy(ur), copy(utheta), copy(uphi), qlm, slm, tlm)
52+
return qlm, slm, tlm
53+
end
54+
55+
function analys!(cfg::SHTnsCfg, v::CuMatrix{Float64}, qlm::CuVector{ComplexF64})
56+
@assert cfg.shtype.gpu
57+
return cu_spat_to_SH(cfg.cfg, v, qlm, cfg.lmax)
58+
end
59+
60+
61+
function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{Float64}}
62+
@assert cfg.shtype.gpu
63+
return cu_spat_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm, cfg.lmax)
64+
end
65+
66+
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{Float64}}
67+
@assert cfg.shtype.gpu
68+
return cu_spat_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm, cfg.lmax)
69+
end
70+
71+
#complex to complex not available for CUDA (status: SHTns v3.7)
72+
73+
# function analys!(cfg::SHTnsCfg, v::CuMatrix{ComplexF64}, qlm::CuVector{ComplexF64})
74+
# return cu_spat_cplx_to_SH(cfg.cfg, v, qlm)
75+
# end
76+
77+
# function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{ComplexF64}}
78+
# return cu_spat_cplx_to_SHsphtor(cfg.cfg, utheta, uphi, slm, tlm)
79+
# end
80+
81+
# function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}) where {T<:CuMatrix{ComplexF64}}
82+
# return cu_spat_cplx_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
83+
# end

ext/SHTnsCUDAExt/sht.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
function cu_spat_to_SH(cfg, Vr::CuMatrix{Float64}, Qlm::CuVector{Complex{Float64}}, lmax)
2+
ccall((:cu_spat_to_SH, libshtns[]), Nothing, (shtns_cfg, CuPtr{Float64}, CuPtr{Complex{Float64}}, Clong), cfg, Vr, Qlm, lmax)
3+
end
4+
5+
function cu_SH_to_spat(cfg, Qlm::CuVector{Complex{Float64}}, Vr::CuMatrix{Float64}, lmax)
6+
ccall((:cu_SH_to_spat, libshtns[]), Nothing, (shtns_cfg, CuPtr{Complex{Float64}}, CuPtr{Float64}, Clong), cfg, Qlm, Vr, lmax)
7+
end
8+
9+
function cu_spat_to_SHsphtor(cfg, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, lmax)
10+
ccall((:cu_spat_to_SHsphtor, libshtns[]), Nothing, (shtns_cfg,CuPtr{Float64},CuPtr{Float64},CuPtr{ComplexF64},CuPtr{ComplexF64}, Clong), cfg, Vt, Vp, Slm, Tlm, lmax)
11+
end
12+
13+
function cu_SHsphtor_to_spat(cfg, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, lmax)
14+
ccall((:cu_SHsphtor_to_spat, libshtns[]), Nothing, (shtns_cfg,CuPtr{ComplexF64},CuPtr{ComplexF64},CuPtr{Float64},CuPtr{Float64}, Clong), cfg, Slm, Tlm, Vt, Vp, lmax)
15+
end
16+
17+
function spat_to_SHqst(cfg, Vr::CuMatrix{Float64}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, lmax)
18+
ccall((:cu_spat_to_SHqst, libshtns[]), Nothing, (shtns_cfg,CuPtr{Float64},CuPtr{Float64},CuPtr{Float64},CuPtr{ComplexF64},CuPtr{ComplexF64},CuPtr{ComplexF64}, Clong), cfg, Vr, Vt, Vp, Qlm, Slm, Tlm, lmax)
19+
end
20+
21+
function SHqst_to_spat(cfg, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, Vr::CuMatrix{Float64}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, lmax)
22+
ccall((:cu_SHqst_to_spat, libshtns[]), Nothing, (shtns_cfg,CuPtr{ComplexF64},CuPtr{ComplexF64},CuPtr{ComplexF64},CuPtr{Float64},CuPtr{Float64},CuPtr{Float64}, Clong), cfg, Qlm, Slm, Tlm, Vr, Vt, Vp, lmax)
23+
end
24+
25+

ext/SHTnsCUDAExt/synth.jl

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
function synth(cfg::SHTnsCfg, qlm::CuVector{ComplexF64})
2+
@assert cfg.shtype.gpu
3+
@assert cfg.nlat != 0
4+
@assert length(qlm) == cfg.nlm
5+
6+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
7+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
8+
9+
v = CuMatrix{Float64}(undef, nx, ny)
10+
synth!(cfg, qlm, v)
11+
return v
12+
end
13+
14+
function synth(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
15+
@assert cfg.shtype.gpu
16+
@assert cfg.nlat != 0
17+
@assert length(slm) == length(tlm) == cfg.nlm
18+
19+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
20+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
21+
22+
utheta = CuMatrix{Float64}(undef, nx, ny)
23+
uphi = CuMatrix{Float64}(undef, nx, ny)
24+
synth!(cfg, slm, tlm, utheta, uphi)
25+
return utheta, uphi
26+
end
27+
28+
function synth(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
29+
@assert cfg.shtype.gpu
30+
@assert cfg.nlat != 0
31+
@assert length(qlm) == length(slm) == length(tlm) == cfg.nlm
32+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
33+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
34+
35+
ur = CuMatrix{Float64}(undef, nx, ny)
36+
utheta = CuMatrix{Float64}(undef, nx, ny)
37+
uphi = CuMatrix{Float64}(undef, nx, ny)
38+
synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
39+
return ur, utheta, uphi
40+
end
41+
42+
43+
function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, v::CuMatrix{Float64})
44+
@assert cfg.shtype.gpu
45+
cu_SH_to_spat(cfg.cfg, qlm, v, cfg.lmax)
46+
return v
47+
end
48+
49+
function synth!(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, utheta::T, uphi::T) where {T<:CuMatrix{Float64}}
50+
@assert cfg.shtype.gpu
51+
cu_SHsphtor_to_spat(cfg.cfg, slm, tlm, utheta, uphi, cfg.lmax)
52+
return utheta, uphi
53+
end
54+
55+
function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::T, utheta::T, uphi::T) where {T<:CuMatrix{Float64}}
56+
@assert cfg.shtype.gpu
57+
cu_SHqst_to_spat(cfg.cfg, qlm, slm, tlm, ur, utheta, uphi, cfg.lmax)
58+
return ur, utheta, uphi
59+
end
60+
61+
62+
#complex to complex not available for CUDA (status: SHTns v3.7)
63+
64+
# function synth_cplx(cfg::SHTnsCfg, qlm::CuVector{ComplexF64})
65+
# @assert cfg.shtype.gpu
66+
# @assert cfg.nlat != 0
67+
# @assert length(qlm) == cfg.nlm_cplx
68+
# @assert cfg.lmax == cfg.mmax
69+
70+
# nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
71+
# ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
72+
73+
# v = CuMatrix{ComplexF64}(undef, nx, ny)
74+
# synth!(cfg, qlm, v)
75+
# return v
76+
# end
77+
78+
# function synth_cplx(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
79+
# @assert cfg.shtype.gpu
80+
# @assert cfg.nlat != 0
81+
# @assert length(slm) == length(tlm) == cfg.nlm_cplx
82+
# @assert cfg.lmax == cfg.mmax
83+
84+
# nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
85+
# ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
86+
87+
# utheta = CuMatrix{ComplexF64}(undef, nx, ny)
88+
# uphi = CuMatrix{ComplexF64}(undef, nx, ny)
89+
# synth!(cfg, slm, tlm, utheta, uphi)
90+
# return utheta, uphi
91+
# end
92+
93+
# function synth_cplx(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64})
94+
# @assert cfg.shtype.gpu
95+
# @assert cfg.nlat != 0
96+
# @assert length(qlm) == length(slm) == length(tlm) == cfg.nlm_cplx
97+
# @assert cfg.lmax == cfg.mmax
98+
99+
# nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
100+
# ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
101+
102+
# ur = CuMatrix{ComplexF64}(undef, nx, ny)
103+
# utheta = CuMatrix{ComplexF64}(undef, nx, ny)
104+
# uphi = CuMatrix{ComplexF64}(undef, nx, ny)
105+
# synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
106+
# return ur, utheta, uphi
107+
# end
108+
109+
# function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, v::CuMatrix{ComplexF64})
110+
# cu_SH_to_spat_cplx(cfg.cfg, qlm, v, cfg.lmax)
111+
# return v
112+
# end
113+
114+
# function synth!(cfg::SHTnsCfg, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, utheta::T, uphi::T) where {T<:CuMatrix{ComplexF64}}
115+
# cu_SHsphtor_to_spat_cplx(cfg.cfg, slm, tlm, utheta, uphi, cfg.lmax)
116+
# return utheta, uphi
117+
# end
118+
119+
# function synth!(cfg::SHTnsCfg, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::T, utheta::T, uphi::T) where {T<:CuMatrix{ComplexF64}}
120+
# cu_SHqst_to_spat_cplx(cfg.cfg, qlm, slm, tlm, ur, utheta, uphi, cfg.lmax)
121+
# return ur, utheta, uphi
122+
# end
123+

src/analys.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ function analys!(cfg::SHTnsCfg, utheta::T, uphi::T, slm, tlm) where {T<:Abstract
6363
end
6464

6565
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm, slm, tlm) where {T<:AbstractMatrix{Float64}}
66-
return spat_cplx_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
66+
return spat_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
6767
end
6868

6969
function analys!(cfg::SHTnsCfg, ur::T, utheta::T, uphi::T, qlm, slm, tlm) where {T<:AbstractMatrix{ComplexF64}}
70-
return spat_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
70+
return spat_cplx_to_SHqst(cfg.cfg, ur, utheta, uphi, qlm, slm, tlm)
7171
end
7272

7373

0 commit comments

Comments
 (0)