Skip to content

Commit 6eb98ca

Browse files
committed
add some SHTnsType settings
1 parent 9ffc35a commit 6eb98ca

File tree

2 files changed: +44 -14 lines changed

2 files changed: +44 -14 lines changed

src/SHTns.jl

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,16 +82,22 @@ for (type, enumtype) in [(:Gauss, :sht_gauss), (:RegFast, :sht_reg_fast), (:RegD
8282
8383
"""
8484
Base.@kwdef struct $(type)<:SHTnsType
85-
contiguous_lat::Bool=false
85+
contiguous_lat::Bool=true
8686
contiguous_phi::Bool=false
8787
padding::Bool=false
88+
gpu::Bool=false
89+
southpolefirst::Bool=false
90+
float32::Bool=false
8891
end
8992

9093
function Base.convert(::Type{shtns_type}, x::$(type))
9194
shtype = $(enumtype)
92-
x.contiguous_lat && (shtype += SHT_THETA_CONTIGUOUS)
9395
x.contiguous_phi && (shtype += SHT_PHI_CONTIGUOUS)
9496
x.padding && (shtype += SHT_ALLOW_PADDING)
97+
x.gpu && (shtype += SHT_ALLOW_GPU)
98+
x.contiguous_lat && (shtype += SHT_THETA_CONTIGUOUS)
99+
x.southpolefirst && (shtype += SHT_SOUTH_POLE_FIRST)
100+
x.float32 && (shtype += SHT_FP32)
95101
return shtype
96102
end
97103
end
@@ -223,6 +229,7 @@ const SHT_SCALAR_ONLY = UInt32(256 * 16)
223229
const SHT_LOAD_SAVE_CFG = UInt32(256 * 64)
224230
const SHT_ALLOW_GPU = UInt32(256 * 128)
225231
const SHT_ALLOW_PADDING = UInt32(256 * 256)
232+
const SHT_FP32 = UInt32(256 * 1024)
226233

227234
include("sht.jl")
228235
include("tools.jl")

src/synth.jl

Lines changed: 35 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,11 @@ DOCSTRING
66
function synth(cfg::SHTnsCfg, qlm)
77
@assert cfg.nlat != 0
88
@assert length(qlm) == cfg.nlm
9-
v = Matrix{Float64}(undef, cfg.nlat_padded, cfg.nphi)
9+
10+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
11+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
12+
13+
v = Matrix{Float64}(undef, nx, ny)
1014
synth!(cfg, qlm, v)
1115
return v
1216
end
@@ -20,7 +24,11 @@ function synth_cplx(cfg::SHTnsCfg, qlm)
2024
@assert cfg.nlat != 0
2125
@assert length(qlm) == cfg.nlm_cplx
2226
@assert cfg.lmax == cfg.mmax
23-
v = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
27+
28+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
29+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
30+
31+
v = Matrix{ComplexF64}(undef, nx, ny)
2432
synth!(cfg, qlm, v)
2533
return v
2634
end
@@ -38,8 +46,12 @@ DOCSTRING
3846
function synth(cfg::SHTnsCfg, slm, tlm)
3947
@assert cfg.nlat != 0
4048
@assert length(slm) == length(tlm) == cfg.nlm
41-
utheta = Matrix{Float64}(undef, cfg.nlat_padded, cfg.nphi)
42-
uphi = Matrix{Float64}(undef, cfg.nlat_padded, cfg.nphi)
49+
50+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
51+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
52+
53+
utheta = Matrix{Float64}(undef, nx, ny)
54+
uphi = Matrix{Float64}(undef, nx, ny)
4355
synth!(cfg, slm, tlm, utheta, uphi)
4456
return utheta, uphi
4557
end
@@ -58,8 +70,12 @@ function synth_cplx(cfg::SHTnsCfg, slm, tlm)
5870
@assert cfg.nlat != 0
5971
@assert length(slm) == length(tlm) == cfg.nlm_cplx
6072
@assert cfg.lmax == cfg.mmax
61-
utheta = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
62-
uphi = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
73+
74+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
75+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
76+
77+
utheta = Matrix{ComplexF64}(undef, nx, ny)
78+
uphi = Matrix{ComplexF64}(undef, nx, ny)
6379
synth!(cfg, slm, tlm, utheta, uphi)
6480
return utheta, uphi
6581
end
@@ -78,9 +94,12 @@ DOCSTRING
7894
function synth(cfg::SHTnsCfg, qlm, slm, tlm)
7995
@assert cfg.nlat != 0
8096
@assert length(qlm) == length(slm) == length(tlm) == cfg.nlm
81-
ur = Matrix{Float64}(undef, cfg.nlat, cfg.nphi)
82-
utheta = Matrix{Float64}(undef, cfg.nlat_padded, cfg.nphi)
83-
uphi = Matrix{Float64}(undef, cfg.nlat_padded, cfg.nphi)
97+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
98+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
99+
100+
ur = Matrix{Float64}(undef, nx, ny)
101+
utheta = Matrix{Float64}(undef, nx, ny)
102+
uphi = Matrix{Float64}(undef, nx, ny)
84103
synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
85104
return ur, utheta, uphi
86105
end
@@ -100,9 +119,13 @@ function synth_cplx(cfg::SHTnsCfg, qlm, slm, tlm)
100119
@assert cfg.nlat != 0
101120
@assert length(qlm) == length(slm) == length(tlm) == cfg.nlm_cplx
102121
@assert cfg.lmax == cfg.mmax
103-
ur = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
104-
utheta = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
105-
uphi = Matrix{ComplexF64}(undef, cfg.nlat_padded, cfg.nphi)
122+
123+
nx = cfg.shtype.contiguous_phi ? cfg.nphi : cfg.nlat_padded
124+
ny = cfg.shtype.contiguous_phi ? cfg.nlat_padded : cfg.nphi
125+
126+
ur = Matrix{ComplexF64}(undef, nx, ny)
127+
utheta = Matrix{ComplexF64}(undef, nx, ny)
128+
uphi = Matrix{ComplexF64}(undef, nx, ny)
106129
synth!(cfg, qlm, slm, tlm, ur, utheta, uphi)
107130
return ur, utheta, uphi
108131
end

0 commit comments

Comments (0)