Skip to content

Commit aa64465

Browse files
committed
fix batched in gpu transforms
1 parent d1f128a commit aa64465

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

ext/SHTnsCUDAExt/sht.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
1-
function cu_spat_to_SH(cfg, Vr::CuMatrix{Float64}, Qlm::CuVector{Complex{Float64}}, lmax)
1+
"""
    cu_spat_to_SH(cfg, Vr, Qlm, lmax)

Call the `cu_spat_to_SH` symbol of the library referenced by `libshtns[]`.

`Vr` (a device `Float64` array of any rank) and `Qlm` (a device `ComplexF64` vector)
are handed to the C side as `CuPtr`s, `cfg` as a `shtns_cfg`, and `lmax` as a `Clong`.
The foreign routine is declared as returning nothing, so this wrapper returns `nothing`.

NOTE(review): judging by the name this is the spatial -> spectral transform that fills
`Qlm` from `Vr` — confirm against the SHTns documentation.
NOTE(review): `lmax` is passed as `Clong`; verify this matches the integer width of the
C prototype.
"""
function cu_spat_to_SH(
    cfg,
    Vr::CuArray{Float64},
    Qlm::CuVector{ComplexF64},
    lmax,
)
    return ccall(
        (:cu_spat_to_SH, libshtns[]),
        Cvoid,
        (shtns_cfg, CuPtr{Float64}, CuPtr{ComplexF64}, Clong),
        cfg, Vr, Qlm, lmax
    )
end
44

5-
function cu_SH_to_spat(cfg, Qlm::CuVector{Complex{Float64}}, Vr::CuMatrix{Float64}, lmax)
5+
"""
    cu_SH_to_spat(cfg, Qlm, Vr, lmax)

Call the `cu_SH_to_spat` symbol of the library referenced by `libshtns[]`.

`Qlm` (a device `ComplexF64` vector) and `Vr` (a device `Float64` array of any rank)
are handed to the C side as `CuPtr`s, `cfg` as a `shtns_cfg`, and `lmax` as a `Clong`.
The foreign routine is declared as returning nothing, so this wrapper returns `nothing`.

NOTE(review): judging by the name this is the spectral -> spatial transform that fills
`Vr` from `Qlm` — confirm against the SHTns documentation.
"""
function cu_SH_to_spat(
    cfg,
    Qlm::CuVector{ComplexF64},
    Vr::CuArray{Float64},
    lmax,
)
    return ccall(
        (:cu_SH_to_spat, libshtns[]),
        Cvoid,
        (shtns_cfg, CuPtr{ComplexF64}, CuPtr{Float64}, Clong),
        cfg, Qlm, Vr, lmax
    )
end
88

9-
function cu_spat_to_SHsphtor(cfg, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, lmax)
9+
"""
    cu_spat_to_SHsphtor(cfg, Vt, Vp, Slm, Tlm, lmax)

Call the `cu_spat_to_SHsphtor` symbol of the library referenced by `libshtns[]`.

`Vt` and `Vp` (device `Float64` arrays of any rank) and `Slm` and `Tlm` (device
`ComplexF64` vectors) are handed to the C side as `CuPtr`s, `cfg` as a `shtns_cfg`, and
`lmax` as a `Clong`. The foreign routine is declared as returning nothing, so this
wrapper returns `nothing`.

NOTE(review): presumably the analysis of a tangential vector field (`Vt`, `Vp`) into
spheroidal/toroidal coefficients (`Slm`, `Tlm`) — confirm against the SHTns
documentation.
"""
function cu_spat_to_SHsphtor(
    cfg,
    Vt::CuArray{Float64},
    Vp::CuArray{Float64},
    Slm::CuVector{ComplexF64},
    Tlm::CuVector{ComplexF64},
    lmax,
)
    return ccall(
        (:cu_spat_to_SHsphtor, libshtns[]),
        Cvoid,
        (shtns_cfg, CuPtr{Float64}, CuPtr{Float64}, CuPtr{ComplexF64}, CuPtr{ComplexF64}, Clong),
        cfg, Vt, Vp, Slm, Tlm, lmax
    )
end
1212

13-
function cu_SHsphtor_to_spat(cfg, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, lmax)
13+
"""
    cu_SHsphtor_to_spat(cfg, Slm, Tlm, Vt, Vp, lmax)

Call the `cu_SHsphtor_to_spat` symbol of the library referenced by `libshtns[]`.

`Slm` and `Tlm` (device `ComplexF64` vectors) and `Vt` and `Vp` (device `Float64`
arrays of any rank) are handed to the C side as `CuPtr`s, `cfg` as a `shtns_cfg`, and
`lmax` as a `Clong`. The foreign routine is declared as returning nothing, so this
wrapper returns `nothing`.

NOTE(review): presumably the synthesis of a tangential vector field (`Vt`, `Vp`) from
spheroidal/toroidal coefficients (`Slm`, `Tlm`) — confirm against the SHTns
documentation.
"""
function cu_SHsphtor_to_spat(
    cfg,
    Slm::CuVector{ComplexF64},
    Tlm::CuVector{ComplexF64},
    Vt::CuArray{Float64},
    Vp::CuArray{Float64},
    lmax,
)
    return ccall(
        (:cu_SHsphtor_to_spat, libshtns[]),
        Cvoid,
        (shtns_cfg, CuPtr{ComplexF64}, CuPtr{ComplexF64}, CuPtr{Float64}, CuPtr{Float64}, Clong),
        cfg, Slm, Tlm, Vt, Vp, lmax
    )
end
1616

17-
function spat_to_SHqst(cfg, Vr::CuMatrix{Float64}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, lmax)
17+
"""
    cu_spat_to_SHqst(cfg, Vr, Vt, Vp, Qlm, Slm, Tlm, lmax)
    spat_to_SHqst(cfg, Vr, Vt, Vp, Qlm, Slm, Tlm, lmax)

Call the `cu_spat_to_SHqst` symbol of the library referenced by `libshtns[]`.

`Vr`, `Vt` and `Vp` (device `Float64` arrays of any rank) and `Qlm`, `Slm` and `Tlm`
(device `ComplexF64` vectors) are handed to the C side as `CuPtr`s, `cfg` as a
`shtns_cfg`, and `lmax` as a `Clong`. The foreign routine is declared as returning
nothing, so both methods return `nothing`.

The other GPU wrappers in this file are named after the C symbol they call (`cu_...`);
`cu_spat_to_SHqst` follows that convention. The unprefixed `spat_to_SHqst` method is
kept as a forwarder so existing callers continue to work.

NOTE(review): presumably the analysis of a 3-component vector field into `Qlm`, `Slm`,
`Tlm` — confirm against the SHTns documentation.
"""
function cu_spat_to_SHqst(
    cfg,
    Vr::CuArray{Float64},
    Vt::CuArray{Float64},
    Vp::CuArray{Float64},
    Qlm::CuVector{Complex{Float64}},
    Slm::CuVector{Complex{Float64}},
    Tlm::CuVector{Complex{Float64}},
    lmax,
)
    return ccall(
        (:cu_spat_to_SHqst, libshtns[]),
        Nothing,
        (shtns_cfg, CuPtr{Float64}, CuPtr{Float64}, CuPtr{Float64}, CuPtr{ComplexF64}, CuPtr{ComplexF64}, CuPtr{ComplexF64}, Clong),
        cfg, Vr, Vt, Vp, Qlm, Slm, Tlm, lmax
    )
end

# Original (unprefixed) entry point; forwards to the `cu_`-prefixed method above.
function spat_to_SHqst(cfg, Vr::CuArray{Float64}, Vt::CuArray{Float64}, Vp::CuArray{Float64}, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, lmax)
    return cu_spat_to_SHqst(cfg, Vr, Vt, Vp, Qlm, Slm, Tlm, lmax)
end
2020

21-
function SHqst_to_spat(cfg, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, Vr::CuMatrix{Float64}, Vt::CuMatrix{Float64}, Vp::CuMatrix{Float64}, lmax)
21+
"""
    cu_SHqst_to_spat(cfg, Qlm, Slm, Tlm, Vr, Vt, Vp, lmax)
    SHqst_to_spat(cfg, Qlm, Slm, Tlm, Vr, Vt, Vp, lmax)

Call the `cu_SHqst_to_spat` symbol of the library referenced by `libshtns[]`.

`Qlm`, `Slm` and `Tlm` (device `ComplexF64` vectors) and `Vr`, `Vt` and `Vp` (device
`Float64` arrays of any rank) are handed to the C side as `CuPtr`s, `cfg` as a
`shtns_cfg`, and `lmax` as a `Clong`. The foreign routine is declared as returning
nothing, so both methods return `nothing`.

The three-component GPU `synth!` method calls this wrapper as `cu_SHqst_to_spat`, and
the other wrappers in this file are likewise named after their C symbol, so the
`cu_`-prefixed method is defined here. The unprefixed `SHqst_to_spat` method is kept as
a forwarder so existing callers continue to work.

NOTE(review): presumably the synthesis of a 3-component vector field from `Qlm`, `Slm`,
`Tlm` — confirm against the SHTns documentation.
"""
function cu_SHqst_to_spat(
    cfg,
    Qlm::CuVector{Complex{Float64}},
    Slm::CuVector{Complex{Float64}},
    Tlm::CuVector{Complex{Float64}},
    Vr::CuArray{Float64},
    Vt::CuArray{Float64},
    Vp::CuArray{Float64},
    lmax,
)
    return ccall(
        (:cu_SHqst_to_spat, libshtns[]),
        Nothing,
        (shtns_cfg, CuPtr{ComplexF64}, CuPtr{ComplexF64}, CuPtr{ComplexF64}, CuPtr{Float64}, CuPtr{Float64}, CuPtr{Float64}, Clong),
        cfg, Qlm, Slm, Tlm, Vr, Vt, Vp, lmax
    )
end

# Original (unprefixed) entry point; forwards to the `cu_`-prefixed method above.
function SHqst_to_spat(cfg, Qlm::CuVector{Complex{Float64}}, Slm::CuVector{Complex{Float64}}, Tlm::CuVector{Complex{Float64}}, Vr::CuArray{Float64}, Vt::CuArray{Float64}, Vp::CuArray{Float64}, lmax)
    return cu_SHqst_to_spat(cfg, Qlm, Slm, Tlm, Vr, Vt, Vp, lmax)
end
2424

ext/SHTnsCUDAExt/synth.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}) where {TR,T,N}
77
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
88
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
99

10-
v = CuMatrix{Tv}(undef, nx, ny)
10+
v = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
1111
synth!(cfg, qlm, v)
1212
return v
1313
end
@@ -21,8 +21,8 @@ function synth(cfg::SHTnsCfg{TR,T,N}, slm::CuVector{ComplexF64}, tlm::CuVector{C
2121
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
2222
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
2323

24-
utheta = CuMatrix{Tv}(undef, nx, ny)
25-
uphi = CuMatrix{Tv}(undef, nx, ny)
24+
utheta = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
25+
uphi = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
2626
synth!(cfg, slm, tlm, utheta, uphi)
2727
return utheta, uphi
2828
end
@@ -35,27 +35,27 @@ function synth(cfg::SHTnsCfg{TR,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{C
3535
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
3636
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
3737

38-
ur = CuMatrix{Tv}(undef, nx, ny)
39-
utheta = CuMatrix{Tv}(undef, nx, ny)
40-
uphi = CuMatrix{Tv}(undef, nx, ny)
38+
ur = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
39+
utheta = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
40+
uphi = cfg.howmany > 1 ? CuArray{Tv}(undef, nx, ny, cfg.howmany) : CuMatrix{Tv}(undef, nx, ny)
4141
synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
4242
return ur, utheta, uphi
4343
end
4444

4545

46-
function synth!(cfg::SHTnsCfg{Real,T,N}, qlm::CuVector{ComplexF64}, v::CuMatrix{Float64}) where {T,N}
46+
"""
    synth!(cfg, qlm, v)

GPU method for configurations whose first type parameter is `Real`.

Asserts that `cfg.shtype.gpu` is set, then passes `cfg.cfg`, `qlm`, `v` and `cfg.lmax`
to `cu_SH_to_spat` and returns `v`.

`v` may be a device `Float64` array of any rank.
NOTE(review): presumably `v` is the output buffer overwritten by the transform —
confirm against the SHTns documentation.
"""
function synth!(
    cfg::SHTnsCfg{Real,T,N},
    qlm::CuVector{ComplexF64},
    v::CuArray{Float64},
) where {T,N}
    @assert cfg.shtype.gpu
    handle = cfg.cfg
    degree = cfg.lmax
    cu_SH_to_spat(handle, qlm, v, degree)
    return v
end
5151

52-
function synth!(cfg::SHTnsCfg{Real,T,N}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, utheta::Tv, uphi::Tv) where {T,N,Tv<:CuMatrix{Float64}}
52+
"""
    synth!(cfg, slm, tlm, utheta, uphi)

GPU method for configurations whose first type parameter is `Real`.

Asserts that `cfg.shtype.gpu` is set, then passes `cfg.cfg`, `slm`, `tlm`, `utheta`,
`uphi` and `cfg.lmax` to `cu_SHsphtor_to_spat` and returns the tuple `(utheta, uphi)`.

`utheta` and `uphi` must share the same concrete type, a subtype of `CuArray{Float64}`.
NOTE(review): presumably `utheta` and `uphi` are the output buffers overwritten by the
transform — confirm against the SHTns documentation.
"""
function synth!(
    cfg::SHTnsCfg{Real,T,N},
    slm::CuVector{ComplexF64},
    tlm::CuVector{ComplexF64},
    utheta::Tv,
    uphi::Tv,
) where {T,N,Tv<:CuArray{Float64}}
    @assert cfg.shtype.gpu
    handle = cfg.cfg
    degree = cfg.lmax
    cu_SHsphtor_to_spat(handle, slm, tlm, utheta, uphi, degree)
    return utheta, uphi
end
5757

58-
function synth!(cfg::SHTnsCfg{Real,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::Tv, utheta::Tv, uphi::Tv) where {T,N,Tv<:CuMatrix{Float64}}
58+
function synth!(cfg::SHTnsCfg{Real,T,N}, qlm::CuVector{ComplexF64}, slm::CuVector{ComplexF64}, tlm::CuVector{ComplexF64}, ur::Tv, utheta::Tv, uphi::Tv) where {T,N,Tv<:CuArray{Float64}}
5959
@assert cfg.shtype.gpu
6060
cu_SHqst_to_spat(cfg.cfg, qlm, slm, tlm, ur, utheta, uphi, cfg.lmax)
6161
return ur, utheta, uphi

0 commit comments

Comments
 (0)