Skip to content

Commit aac575e

Browse files
committed
add cpu!, gpu!, and cu! macros
1 parent 9df6548 commit aac575e

File tree

3 files changed

+44
-2
lines changed

3 files changed

+44
-2
lines changed

src/CMBLensing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ import Statistics: std
6868

6969

7070
export
71-
@⌛, @show⌛, @ismain, @namedtuple, @repeated, @unpack, animate,
71+
@⌛, @show⌛, @ismain, @namedtuple, @repeated, @unpack, @cpu!, animate,
7272
argmaxf_lnP, BandPassOp, BaseDataSet, batch, batch_index, batch_length, beamCℓs, cache,
7373
CachedLenseFlow, camb, cov_to_Cℓ, cpu, Cℓ_2D, Cℓ_to_Cov, DataSet, DerivBasis,
7474
diag, Diagonal, DiagOp, dot, EBFourier, EBMap, expnorm, Field, FieldArray, fieldinfo,

src/gpu.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using CUDA: cufunc, curand_rng
44
using CUDA.CUSPARSE: CuSparseMatrix, CuSparseMatrixCSR, CuSparseMatrixCOO
55
using CUDA.CUSOLVER: CuQR
66

7-
export cuda_gc, gpu
7+
export cuda_gc, gpu, @gpu!, @cu!
88

99
# Alias for a `BaseField` whose fourth type parameter `A` is constrained to be a `CuArray`.
const CuBaseField{B,M,T,A<:CuArray} = BaseField{B,M,T,A}
1010

@@ -20,10 +20,40 @@ end
2020
function is_gpu_backed(::BaseField{B,M,T,A}) where {B,M,T,A<:CuArray}
    # This method only matches fields whose `A` parameter is a `CuArray`.
    return true
end
2121
function global_rng_for(::Type{<:CuArray})
    # For `CuArray` types, hand back whatever `curand_rng()` returns.
    return curand_rng()
end
2222

23+
# handy conversion functions and macros
24+
@doc doc"""
2325
26+
@gpu! x y
27+
28+
Equivalent to `x = gpu(x)`, `y = gpu(y)`, etc... for any number of
29+
listed variables. See [`gpu`](@ref). Returns nothing.
30+
"""
31+
macro gpu!(vars...)
32+
# Emit one `var = gpu(var)` assignment per macro argument. Each `var` is wrapped in
# `esc` so the assignment targets the caller's variable; the trailing `nothing`
# makes the generated block evaluate to `nothing`.
:(begin; $((:($(esc(var)) = gpu($(esc(var)))) for var in vars)...); nothing; end)
33+
end
34+
@doc doc"""
35+
36+
gpu(x)
37+
38+
Recursively moves x to GPU, but unlike `CUDA.cu`, doesn't also convert
39+
to Float32. Equivalent to `adapt_structure(CuArray, x)`. Returns the converted object.
40+
"""
2441
function gpu(x)
    # Delegate to `adapt_structure` with `CuArray` as the target and return its result.
    return adapt_structure(CuArray, x)
end
2542

2643

44+
@doc doc"""
45+
46+
@cu! x y
47+
48+
Equivalent to `x = cu(x)`, `y = cu(y)`, etc... for any number of
49+
listed variables. See `CUDA.cu`. Returns nothing.
50+
"""
51+
macro cu!(vars...)
52+
# Emit one `var = cu(var)` assignment per macro argument. Each `var` is wrapped in
# `esc` so the assignment targets the caller's variable; the trailing `nothing`
# makes the generated block evaluate to `nothing`.
:(begin; $((:($(esc(var)) = cu($(esc(var)))) for var in vars)...); nothing; end)
53+
end
54+
55+
56+
2757
function adapt_structure(::CUDA.Float32Adaptor, proj::ProjLambert)
    # Route the `Float32Adaptor` case through the `CuArray{Float32}` method.
    return adapt_structure(CuArray{Float32}, proj)
end
2858

2959

src/util.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,18 @@ Recursively move an object to CPU memory. See also [`gpu`](@ref).
225225
"""
226226
function cpu(x)
    return adapt_structure(Array, x)
end
227227

228+
@doc doc"""
229+
230+
@cpu! x y
231+
232+
Equivalent to `x = cpu(x)`, `y = cpu(y)`, etc... for any number of
233+
listed variables. See [`cpu`](@ref). Returns nothing.
234+
"""
235+
macro cpu!(vars...)
236+
# Emit one `var = cpu(var)` assignment per macro argument. Each `var` is wrapped in
# `esc` so the assignment targets the caller's variable; the trailing `nothing`
# makes the generated block evaluate to `nothing`.
:(begin; $((:($(esc(var)) = cpu($(esc(var)))) for var in vars)...); nothing; end)
237+
end
238+
239+
228240
@doc doc"""
229241
230242
gpu(x)

0 commit comments

Comments
 (0)