
Commit 01d720c

Merge pull request #9 from JuliaImageRecon/gpuOps
GPU Support for Operators
2 parents 97e14da + 8e14f5d

File tree

22 files changed: +827 -229 lines changed

.buildkite/pipeline.yml

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+steps:
+  - label: "Nvidia GPUs -- LinearOperators.jl"
+    plugins:
+      - JuliaCI/julia#v1:
+          version: "1.10"
+    agents:
+      queue: "juliagpu"
+      cuda: "*"
+    command: |
+      julia --color=yes --project -e '
+        using Pkg
+        Pkg.add("TestEnv")
+        using TestEnv
+        TestEnv.activate();
+        Pkg.add("CUDA")
+        Pkg.add("CuNFFT")
+        Pkg.instantiate()
+        include("test/gpu/cuda.jl")'
+    timeout_in_minutes: 30
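The same entry point can be exercised locally on a CUDA machine; a minimal sketch that just mirrors the pipeline's commands, assuming the repository root is the active project:

using Pkg
Pkg.add("TestEnv")
using TestEnv
TestEnv.activate()
Pkg.add("CUDA")     # CUDA backend for the GPU tests
Pkg.add("CuNFFT")   # CUDA-accelerated NFFT plans
Pkg.instantiate()
include("test/gpu/cuda.jl")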

.github/workflows/Breakage.yml

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+name: Breakage
+# Based on: https://github.com/JuliaSmoothOptimizers/LinearOperators.jl/blob/main/.github/workflows/Breakage.yml
+on:
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  break:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        pkg: [
+          "JuliaImageRecon/RegularizedLeastSquares.jl",
+          "MagneticResonanceImaging/MRIReco.jl"
+        ]
+        pkgversion: [latest, stable]
+
+    steps:
+      - uses: actions/checkout@v2
+
+      # Install Julia
+      - uses: julia-actions/setup-julia@v2
+        with:
+          version: 1
+          arch: x64
+      - uses: actions/cache@v1
+        env:
+          cache-name: cache-artifacts
+        with:
+          path: ~/.julia/artifacts
+          key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
+          restore-keys: |
+            ${{ runner.os }}-test-${{ env.cache-name }}-
+            ${{ runner.os }}-test-
+            ${{ runner.os }}-
+      - uses: julia-actions/julia-buildpkg@v1
+
+      # Breakage test
+      - name: 'Breakage of ${{ matrix.pkg }}, ${{ matrix.pkgversion }} version'
+        env:
+          URL: ${{ matrix.pkg }}
+          VERSION: ${{ matrix.pkgversion }}
+        run: |
+          set -v
+          mkdir -p ./pr
+          echo "${{ github.event.number }}" > ./pr/NR
+          git clone https://github.com/$URL
+          export PKG=$(echo $URL | cut -f2 -d/)
+          cd $PKG
+          if [ $VERSION == "stable" ]; then
+            TAG=$(git tag -l "v*" --sort=-creatordate | head -n1)
+            if [ -z "$TAG" ]; then
+              TAG="no_tag"
+            else
+              git checkout $TAG
+            fi
+          else
+            TAG=$VERSION
+          fi
+          export TAG
+          julia -e 'using Pkg;
+            PKG, TAG, VERSION = ENV["PKG"], ENV["TAG"], ENV["VERSION"]
+            joburl = joinpath(ENV["GITHUB_SERVER_URL"], ENV["GITHUB_REPOSITORY"], "actions/runs", ENV["GITHUB_RUN_ID"])
+            TAG == "no_tag" && error("Not tag for $VERSION")
+            pkg"activate .";
+            pkg"instantiate";
+            pkg"dev ../";
+            if TAG == "latest"
+              global TAG = chomp(read(`git rev-parse --short HEAD`, String))
+            end
+            pkg"build";
+            pkg"test";'
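The workflow clones each downstream package, checks out its latest stable tag when requested, `dev`s this repository into it, and runs its test suite. A rough local equivalent, sketched with an illustrative path (it assumes the downstream package is cloned next to this repository and is run from the repository root):

using Pkg
Pkg.activate("../RegularizedLeastSquares.jl")  # illustrative downstream clone
Pkg.instantiate()
Pkg.develop(path=".")   # counterpart of the workflow's pkg"dev ../"
Pkg.build()
Pkg.test()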

Project.toml

Lines changed: 13 additions & 1 deletion

@@ -13,24 +13,36 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
+NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
+Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
+FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+RadonKA = "86de8297-835b-47df-b249-c04e8db91db5"
 
 [compat]
 julia = "1.9"
+GPUArrays = "8, 9, 10"
+JLArrays = "0.1"
 NFFT = "0.13"
 LinearOperators = "2.3.3"
+RadonKA = "0.6"
 Wavelets = "0.9, 0.10"
 Reexport = "1.0"
 FFTW = "1.0"
 
 [weakdeps]
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 NFFT = "efe261a4-0d2b-5849-be55-fc731d526b0d"
 Wavelets = "29a6e085-ba6d-5f35-a997-948ac2efa89a"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+RadonKA = "86de8297-835b-47df-b249-c04e8db91db5"
 
 [targets]
-test = ["Test", "FFTW", "Wavelets", "NFFT"]
+test = ["Test", "FFTW", "Wavelets", "NFFT", "JLArrays", "RadonKA"]
 
 [extensions]
 LinearOperatorNFFTExt = ["NFFT", "FFTW"]
 LinearOperatorFFTWExt = "FFTW"
 LinearOperatorWaveletExt = "Wavelets"
+LinearOperatorGPUArraysExt = "GPUArrays"
+LinearOperatorRadonKAExt = "RadonKA"
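GPUArrays and RadonKA join the weak dependencies, so their extensions load only when the corresponding package is present in the session. A small sketch of how to verify that the GPU extension activates, using JLArrays (newly added to the test targets) as a CPU-backed GPUArrays implementation, so no GPU hardware is needed:

using LinearOperatorCollection
using JLArrays  # loads GPUArrays, which triggers LinearOperatorGPUArraysExt

ext = Base.get_extension(LinearOperatorCollection, :LinearOperatorGPUArraysExt)
@assert ext !== nothing  # extension was loaded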

ext/LinearOperatorFFTWExt/FFTOp.jl

Lines changed: 24 additions & 37 deletions

@@ -1,6 +1,6 @@
 export FFTOpImpl
 
-mutable struct FFTOpImpl{T} <: FFTOp{T}
+mutable struct FFTOpImpl{T, vecT, P <: AbstractFFTs.Plan{T}, IP <: AbstractFFTs.Plan{T}} <: FFTOp{T}
   nrow :: Int
   ncol :: Int
   symmetric :: Bool
@@ -14,10 +14,10 @@ mutable struct FFTOpImpl{T} <: FFTOp{T}
   args5 :: Bool
   use_prod5! :: Bool
   allocated5 :: Bool
-  Mv5 :: Vector{T}
-  Mtu5 :: Vector{T}
-  plan
-  iplan
+  Mv5 :: vecT
+  Mtu5 :: vecT
+  plan :: P
+  iplan :: IP
   shift::Bool
   unitary::Bool
 end
@@ -34,13 +34,14 @@ returns an operator which performs an FFT on Arrays of type T
 * `shape::Tuple` - size of the array to transform
 * (`shift=true`) - if true, fftshifts are performed
 * (`unitary=true`) - if true, FFT is normalized such that it is unitary
+* (`S = Vector{T}`) - type of temporary vector, change to use on GPU
+* (`kwargs...`) - keyword arguments given to fft plan
 """
-function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, cuda::Bool=false) where D
+function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::Bool=true, unitary::Bool=true, S = Array{Complex{real(T)}}, kwargs...) where D
 
-  #tmpVec = cuda ? CuArray{T}(undef,shape) : Array{Complex{real(T)}}(undef, shape)
-  tmpVec = Array{Complex{real(T)}}(undef, shape)
-  plan = plan_fft!(tmpVec; flags=FFTW.MEASURE)
-  iplan = plan_bfft!(tmpVec; flags=FFTW.MEASURE)
+  tmpVec = similar(S(undef, 0), shape...)
+  plan = plan_fft!(tmpVec; kwargs...)
+  iplan = plan_bfft!(tmpVec; kwargs...)
 
   if unitary
     facF = T(1.0/sqrt(prod(shape)))
@@ -50,39 +51,25 @@ function LinearOperatorCollection.FFTOp(T::Type; shape::NTuple{D,Int64}, shift::
     facB = T(1.0)
   end
 
-  let shape_=shape, plan_=plan, iplan_=iplan, tmpVec_=tmpVec, facF_=facF, facB_=facB
+  let shape_ = shape, plan_ = plan, iplan_ = iplan, tmpVec_ = tmpVec, facF_ = facF, facB_ = facB
 
-    if shift
-      return FFTOpImpl{T}(prod(shape), prod(shape), false, false
-          , (res, x) -> fft_multiply_shift!(res, plan_, x, shape_, facF_, tmpVec_)
-          , nothing
-          , (res, x) -> fft_multiply_shift!(res, iplan_, x, shape_, facB_, tmpVec_)
-          , 0, 0, 0, true, false, true, T[], T[]
-          , plan
-          , iplan
-          , shift
-          , unitary)
-    else
-      return FFTOpImpl{T}(prod(shape), prod(shape), false, false
-          , (res, x) -> fft_multiply!(res, plan_, x, facF_, tmpVec_)
-          , nothing
-          , (res, x) -> fft_multiply!(res, iplan_, x, facB_, tmpVec_)
-          , 0, 0, 0, true, false, true, T[], T[]
-          , plan
-          , iplan
-          , shift
-          , unitary)
-    end
+    fun! = fft_multiply!
+    if shift
+      fun! = fft_multiply_shift!
+    end
+
+    return FFTOpImpl(prod(shape), prod(shape), false, false, (res, x) -> fun!(res, plan_, x, shape_, facF_, tmpVec_),
+        nothing, (res, x) -> fun!(res, iplan_, x, shape_, facB_, tmpVec_),
+        0, 0, 0, true, false, true, similar(tmpVec, 0), similar(tmpVec, 0), plan, iplan, shift, unitary)
   end
 end
 
-function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, factor::T, tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
-  tmpVec[:] .= x
-  plan * tmpVec
+function fft_multiply!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, ::NTuple{D}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
+  plan * copyto!(tmpVec, x)
   res .= factor .* vec(tmpVec)
 end
 
-function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, shape::NTuple{D}, factor::T, tmpVec::Array{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
+function fft_multiply_shift!(res::AbstractVector{T}, plan::P, x::AbstractVector{Tr}, shape::NTuple{D}, factor::T, tmpVec::AbstractArray{T,D}) where {T, Tr, P<:AbstractFFTs.Plan, D}
   ifftshift!(tmpVec, reshape(x,shape))
   plan * tmpVec
   fftshift!(reshape(res,shape), tmpVec)
@@ -91,5 +78,5 @@ end
 
 
 function Base.copy(S::FFTOpImpl)
-  return FFTOp(eltype(S); shape=size(S.plan), shift=S.shift, unitary=S.unitary)
+  return FFTOp(eltype(S); shape=size(S.plan), shift=S.shift, unitary=S.unitary, S = LinearOperators.storage_type(S)) # TODO loses kwargs...
 end
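With this change the storage type `S` decides where the FFT runs, and planner options travel through `kwargs...`. A minimal usage sketch; the CUDA half is an assumption in that it requires a working CUDA.jl install, which supplies `plan_fft!` for `CuArray`s:

using LinearOperatorCollection, FFTW

# CPU, forwarding the planner flag the constructor used to hard-code:
op = FFTOp(ComplexF32; shape=(32, 32), flags=FFTW.MEASURE)

using CUDA  # assumption: CUDA-capable setup
gpu_op = FFTOp(ComplexF32; shape=(32, 32), S=CuArray{ComplexF32})
x = CuArray(rand(ComplexF32, 32 * 32))
y = gpu_op * x    # forward FFT runs on the device
x2 = gpu_op' * y  # unitary by default, so this recovers x up to rounding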

ext/LinearOperatorFFTWExt/LinearOperatorFFTWExt.jl

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 module LinearOperatorFFTWExt
 
-using LinearOperatorCollection, FFTW
+using LinearOperatorCollection, FFTW, FFTW.AbstractFFTs
 
 include("FFTOp.jl")
 include("DCTOp.jl")
ext/LinearOperatorGPUArraysExt/GradientOp.jl

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
+function LinearOperatorCollection.grad!(res::vecT, img::vecT, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {vecT <: AbstractGPUVector, N}
+  res = reshape(res, shape .- Tuple(di))
+
+  if length(res) > 0
+    gpu_call(grad_kernel!, res, reshape(img,shape), di)
+  end
+
+  return res
+end
+
+function grad_kernel!(ctx, res, img, di)
+  idx = @cartesianidx(res)
+  @inbounds res[idx] = img[idx] - img[idx + di]
+  return nothing
+end
+
+# adjoint of directional gradients
+function LinearOperatorCollection.grad_t!(res::vecT, g::vecT, shape::NTuple{N,Int64}, di::CartesianIndex{N}) where {T, vecT <: AbstractGPUVector{T}, N}
+  res_ = reshape(res,shape)
+  g_ = reshape(g, shape .- Tuple(di))
+
+  fill!(res, zero(T))
+  if length(g_) > 0
+    gpu_call(grad_t_kernel_1!, res_, g_, di, elements = length(g))
+    gpu_call(grad_t_kernel_2!, res_, g_, di, elements = length(g))
+  end
+end
+
+function grad_t_kernel_1!(ctx, res, g, di)
+  idx = @cartesianidx(g)
+  @inbounds res[idx] += g[idx]
+  return nothing
+end
+
+function grad_t_kernel_2!(ctx, res, g, di)
+  idx = @cartesianidx(g)
+  @inbounds res[idx + di] -= g[idx]
+  return nothing
+end
+
+function LinearOperatorCollection.grad_t!(res::vecT, g::vecT, shape::NTuple{N,Int64}, dirs, dims, dim_ends, tmp) where {T, vecT <: AbstractGPUVector{T}, N}
+  dim_start = 1
+  res = reshape(res, shape)
+
+  fill!(res, zero(eltype(res)))
+  for (i, di) in enumerate(dirs)
+    g_ = reshape(view(g, dim_start:dim_ends[i]), shape .- Tuple(di))
+    if length(g_) > 0
+      gpu_call(grad_t_kernel_1!, res, g_, di, elements = length(g))
+      gpu_call(grad_t_kernel_2!, res, g_, di, elements = length(g))
+    end
+    dim_start = dim_ends[i] + 1
+  end
+end
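These methods override the package's `grad!`/`grad_t!` primitives for `AbstractGPUVector`s via GPUArrays' `gpu_call`, so a `GradientOp` built on GPU storage dispatches to them. A sketch under the assumption that `GradientOp` accepts the same storage keyword `S` as `FFTOp` above; JLArrays stands in for a real GPU backend:

using LinearOperatorCollection
using JLArrays  # CPU-backed AbstractGPUArray, exercises the kernels without a GPU

op = GradientOp(Float32; shape=(8, 8), S=JLArray{Float32})  # `S` kwarg assumed
img = JLArray(rand(Float32, 64))
g = op * img  # directional finite differences via grad_kernel!
r = op' * g   # adjoint via grad_t_kernel_1! / grad_t_kernel_2!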
ext/LinearOperatorGPUArraysExt/LinearOperatorGPUArraysExt.jl

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+module LinearOperatorGPUArraysExt
+
+using LinearOperatorCollection, GPUArrays
+
+include("GradientOp.jl")
+
+
+end # module

ext/LinearOperatorNFFTExt/LinearOperatorNFFTExt.jl

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 module LinearOperatorNFFTExt
 
-using LinearOperatorCollection, NFFT, FFTW
+using LinearOperatorCollection, NFFT, NFFT.AbstractNFFTs, FFTW, FFTW.AbstractFFTs
 
 include("NFFTOp.jl")
 
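Bringing `NFFT.AbstractNFFTs` (and `FFTW.AbstractFFTs`) into scope lets the extension dispatch on abstract plan types, so any AbstractNFFTs backend can supply the plan, e.g. CuNFFT as installed by the Buildkite pipeline above. A hypothetical CPU-path sketch; the `nodes` keyword and shapes are assumptions about the NFFTOp API, not taken from this diff:

using LinearOperatorCollection, NFFT, FFTW

nodes = rand(Float32, 2, 64) .- 0.5f0  # 64 k-space nodes in [-0.5, 0.5)^2
op = NFFTOp(ComplexF32; shape=(16, 16), nodes=nodes)
y = op * vec(rand(ComplexF32, 16, 16)) # non-equidistant FFT of a 16x16 image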
