Skip to content

Commit 8aa167f

Browse files
authored
Avoid threadid using @spawn (#57)
* Partly remove threadid()
* Add more warnings
* Update fft to remove threadid
* Update rotate to remove threadid
* wip
* Finish removing threadid
* To v0.3
1 parent 3d2bd6d commit 8aa167f

File tree

10 files changed

+193
-160
lines changed

10 files changed

+193
-160
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SPECTrecon"
22
uuid = "ab1be465-a7f0-4423-9048-0ee774b70ed9"
33
authors = ["Zongyu Li and Jeff Fessler and group"]
4-
version = "0.2"
4+
version = "0.3"
55

66
[deps]
77
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -19,4 +19,4 @@ ImageFiltering = "0.6, 0.7"
1919
LinearInterpolators = "0.1"
2020
LinearMapsAA = "0.12"
2121
OffsetArrays = "1"
22-
julia = "1.10"
22+
julia = "1.11"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ The examples include an illustration
2424
of how to integrate deep learning
2525
into SPECT reconstruction.
2626

27-
Tested with Julia ≥ 1.10.
27+
Tested with Julia ≥ 1.11.
2828

2929
## Related packages
3030

src/SPECTrecon.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module SPECTrecon
88

99
include("foreach.jl")
1010
include("helper.jl")
11+
include("spawn.jl")
1112
include("plan-rotate.jl")
1213
include("rotatez.jl")
1314
include("plan-psf.jl")

src/backproject.jl

Lines changed: 72 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -13,111 +13,112 @@ function backproject!(
1313
viewidx::Int,
1414
)
1515

16-
Threads.@threads for z in 1:plan.imgsize[3] # 1:nz
17-
thid = Threads.threadid() # thread id
18-
16+
# rotate image and mumap using multiple processors (adjoint)
17+
nz = plan.imgsize[3] # prepare to loop over slices
18+
spawner(plan.nthread, nz) do buffer_id, iz
1919
# rotate mumap
20-
imrotate!((@view plan.mumapr[:, :, z]),
21-
(@view plan.mumap[:, :, z]),
22-
plan.viewangle[viewidx],
23-
plan.planrot[thid],
24-
)
25-
26-
end # COV_EXCL_LINE
20+
imrotate!(
21+
(@view plan.mumapr[:, :, iz]),
22+
(@view plan.mumap[:, :, iz]),
23+
plan.viewangle[viewidx],
24+
plan.planrot[buffer_id],
25+
)
26+
end
2727

2828
# adjoint of convolving img with psf and applying attenuation map
29-
Threads.@threads for y in 1:plan.imgsize[2] # 1:ny
30-
thid = Threads.threadid() # thread id
29+
ny = plan.imgsize[2] # prepare to loop over y planes
30+
spawner(plan.nthread, ny) do buffer_id, iy
3131
# account for half of the final slice thickness
32-
scale3dj!(plan.exp_mumapr[thid], plan.mumapr, y, -0.5)
33-
for j in 1:y
34-
plus3dj!(plan.exp_mumapr[thid], plan.mumapr, j)
35-
end
32+
scale3dj!(plan.exp_mumapr[buffer_id], plan.mumapr, iy, -0.5)
3633

37-
broadcast!(*, plan.exp_mumapr[thid], plan.exp_mumapr[thid], - plan.dy)
38-
39-
broadcast!(exp, plan.exp_mumapr[thid], plan.exp_mumapr[thid])
34+
for j in 1:iy
35+
plus3dj!(plan.exp_mumapr[buffer_id], plan.mumapr, j)
36+
end
4037

41-
fft_conv_adj!((@view plan.imgr[:, y, :]),
42-
view,
43-
(@view plan.psfs[:, :, y, viewidx]),
44-
plan.planpsf[thid],
45-
)
38+
broadcast!(*, plan.exp_mumapr[buffer_id], plan.exp_mumapr[buffer_id], - plan.dy)
39+
broadcast!(exp, plan.exp_mumapr[buffer_id], plan.exp_mumapr[buffer_id])
4640

47-
mul3dj!(plan.imgr, plan.exp_mumapr[thid], y)
48-
end # COV_EXCL_LINE
41+
fft_conv_adj!(
42+
(@view plan.imgr[:, iy, :]),
43+
view,
44+
(@view plan.psfs[:, :, iy, viewidx]),
45+
plan.planpsf[buffer_id],
46+
)
4947

50-
# adjoint of rotating image
51-
Threads.@threads for z in 1:plan.imgsize[3] # 1:nz
52-
thid = Threads.threadid()
48+
mul3dj!(plan.imgr, plan.exp_mumapr[buffer_id], iy)
49+
end
5350

54-
imrotate_adj!((@view image[:, :, z]),
55-
(@view plan.imgr[:, :, z]),
56-
plan.viewangle[viewidx],
57-
plan.planrot[thid],
58-
)
59-
end # COV_EXCL_LINE
51+
# adjoint of rotating image
52+
spawner(plan.nthread, nz) do buffer_id, iz
53+
imrotate_adj!(
54+
(@view image[:, :, iz]),
55+
(@view plan.imgr[:, :, iz]),
56+
plan.viewangle[viewidx],
57+
plan.planrot[buffer_id],
58+
)
59+
end
6060

6161
return image
6262
end
6363

6464

6565
"""
66-
backproject!(image, view, plan, thid, viewidx)
66+
backproject!(image, view, plan, buffer_id, viewidx)
6767
Backproject a single view.
6868
"""
6969
function backproject!(
7070
image::AbstractArray{<:RealU, 3},
7171
view::AbstractMatrix{<:RealU},
7272
plan::SPECTplan,
73-
thid::Int,
73+
buffer_id::Int,
7474
viewidx::Int,
7575
)
7676

77+
# rotate mumap
7778
for z in 1:plan.imgsize[3] # 1:nz
78-
# thid = Threads.threadid() # thread id
79-
80-
# rotate mumap
81-
imrotate!((@view plan.mumapr[thid][:, :, z]),
82-
(@view plan.mumap[:, :, z]),
83-
plan.viewangle[viewidx],
84-
plan.planrot[thid],
85-
)
79+
imrotate!(
80+
(@view plan.mumapr[buffer_id][:, :, z]),
81+
(@view plan.mumap[:, :, z]),
82+
plan.viewangle[viewidx],
83+
plan.planrot[buffer_id],
84+
)
8685

8786
end
8887

8988
# adjoint of convolving img with psf and applying attenuation map
9089
for y in 1:plan.imgsize[2] # 1:ny
91-
thid = Threads.threadid() # thread id
9290
# account for half of the final slice thickness
93-
scale3dj!(plan.exp_mumapr[thid], plan.mumapr[thid], y, -0.5)
91+
scale3dj!(plan.exp_mumapr[buffer_id], plan.mumapr[buffer_id], y, -0.5)
92+
9493
for j in 1:y
95-
plus3dj!(plan.exp_mumapr[thid], plan.mumapr[thid], j)
94+
plus3dj!(plan.exp_mumapr[buffer_id], plan.mumapr[buffer_id], j)
9695
end
9796

98-
broadcast!(*, plan.exp_mumapr[thid], plan.exp_mumapr[thid], - plan.dy)
97+
broadcast!(*, plan.exp_mumapr[buffer_id], plan.exp_mumapr[buffer_id], - plan.dy)
9998

100-
broadcast!(exp, plan.exp_mumapr[thid], plan.exp_mumapr[thid])
99+
broadcast!(exp, plan.exp_mumapr[buffer_id], plan.exp_mumapr[buffer_id])
101100

102-
fft_conv_adj!((@view plan.imgr[thid][:, y, :]),
103-
view,
104-
(@view plan.psfs[:, :, y, viewidx]),
105-
plan.planpsf[thid],
106-
)
101+
fft_conv_adj!(
102+
(@view plan.imgr[buffer_id][:, y, :]),
103+
view,
104+
(@view plan.psfs[:, :, y, viewidx]),
105+
plan.planpsf[buffer_id],
106+
)
107107

108-
mul3dj!(plan.imgr[thid], plan.exp_mumapr[thid], y)
108+
mul3dj!(plan.imgr[buffer_id], plan.exp_mumapr[buffer_id], y)
109109
end
110110

111111
# adjoint of rotating image
112112
for z in 1:plan.imgsize[3] # 1:nz
113-
imrotate_adj!((@view plan.imgr[thid][:, :, z]),
114-
(@view plan.imgr[thid][:, :, z]),
115-
plan.viewangle[viewidx],
116-
plan.planrot[thid],
117-
)
113+
imrotate_adj!(
114+
(@view plan.imgr[buffer_id][:, :, z]),
115+
(@view plan.imgr[buffer_id][:, :, z]),
116+
plan.viewangle[viewidx],
117+
plan.planrot[buffer_id],
118+
)
118119
end
119120

120-
broadcast!(+, image, image, plan.imgr[thid])
121+
broadcast!(+, image, image, plan.imgr[buffer_id])
121122

122123
return image
123124
end
@@ -138,11 +139,14 @@ function backproject!(
138139
# loop over each view index
139140
image .= zero(plan.T) # must be initialized as zero
140141
if plan.mode === :fast
141-
[plan.add_img[i] .= zero(plan.T) for i in 1:plan.nthread]
142-
Threads.@threads for (i, viewidx) in collect(enumerate(index))
143-
thid = Threads.threadid()
144-
backproject!(plan.add_img[thid], (@view views[:, :, i]), plan, thid, viewidx)
145-
end # COV_EXCL_LINE
142+
for i in 1:plan.nthread
143+
plan.add_img[i] .= zero(plan.T)
144+
end
145+
146+
spawner(plan.nthread, length(index)) do buffer_id, ii
147+
viewidx = index[ii]
148+
backproject!(plan.add_img[buffer_id], (@view views[:,:,ii]), plan, buffer_id, viewidx)
149+
end
146150

147151
for i in 1:plan.nthread
148152
broadcast!(+, image, image, plan.add_img[i])
@@ -187,6 +191,7 @@ function backproject(
187191
dy::RealU;
188192
interpmeth::Symbol = :two,
189193
mode::Symbol = :fast,
194+
# nthread::Int = Threads.nthreads(), # todo: option for plan
190195
kwargs...,
191196
)
192197

src/fft_convolve.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33
export fft_conv!, fft_conv_adj!
44
export fft_conv, fft_conv_adj
55

6+
67
"""
78
imfilterz!(plan)
89
FFT-based convolution of `plan.img_compl`
910
and kernel `plan.ker_compl` (not centered),
1011
storing result in `plan.workmat`.
12+
Over-writes `plan.ker_compl`.
1113
"""
1214
function imfilterz!(plan::PlanPSF)
1315
mul!(plan.img_compl, plan.fft_plan, plan.img_compl)
1416
mul!(plan.ker_compl, plan.fft_plan, plan.ker_compl)
1517
broadcast!(*, plan.img_compl, plan.img_compl, plan.ker_compl)
18+
# plan.img_compl .*= plan.ker_compl # todo - time it
1619
mul!(plan.img_compl, plan.ifft_plan, plan.img_compl)
1720
fftshift2!(plan.ker_compl, plan.img_compl)
1821
plan.workmat .= real.(plan.ker_compl)
@@ -161,15 +164,16 @@ function fft_conv!(
161164

162165
size(output) == size(image3) || throw(DimensionMismatch())
163166

164-
fun = y -> fft_conv!(
165-
(@view output[:, y, :]),
166-
(@view image3[:, y, :]),
167-
(@view ker3[:, :, y]),
168-
plans[Threads.threadid()],
167+
nbuffer = length(plans)
168+
ny = size(image3, 2)
169+
spawner(nbuffer, ny) do buffer_id, iy
170+
fft_conv!(
171+
(@view output[:, iy, :]),
172+
(@view image3[:, iy, :]),
173+
(@view ker3[:, :, iy]),
174+
plans[buffer_id],
169175
)
170-
171-
ntasks = length(plans)
172-
Threads.foreach(fun, foreach_setup(1:size(image3, 2)); ntasks)
176+
end
173177

174178
return output
175179
end
@@ -188,15 +192,16 @@ function fft_conv_adj!(
188192

189193
size(output) == size(image3) || throw(DimensionMismatch())
190194

191-
fun = y -> fft_conv_adj!(
192-
(@view output[:, y, :]),
193-
(@view image3[:, y, :]),
194-
(@view ker3[:, :, y]),
195-
plans[Threads.threadid()],
195+
nbuffer = length(plans)
196+
ny = size(image3, 2)
197+
spawner(nbuffer, ny) do buffer_id, iy
198+
fft_conv_adj!(
199+
(@view output[:, iy, :]),
200+
(@view image3[:, iy, :]),
201+
(@view ker3[:, :, iy]),
202+
plans[buffer_id],
196203
)
197-
198-
ntasks = length(plans)
199-
Threads.foreach(fun, foreach_setup(1:size(image3, 2)); ntasks)
204+
end
200205

201206
return output
202207
end
@@ -217,15 +222,16 @@ function fft_conv_adj2!(
217222
size(output, 1) == size(image2, 1) || throw("size 1")
218223
size(output, 3) == size(image2, 2) || throw("size 2")
219224

220-
fun = y -> fft_conv_adj!(
221-
(@view output[:, y, :]),
225+
nbuffer = length(plans)
226+
ny = size(image2, 2)
227+
spawner(nbuffer, ny) do buffer_id, iy
228+
fft_conv_adj!(
229+
(@view output[:, iy, :]),
222230
image2,
223-
(@view ker3[:, :, y]),
224-
plans[Threads.threadid()],
231+
(@view ker3[:, :, iy]),
232+
plans[buffer_id],
225233
)
226-
227-
ntasks = length(plans)
228-
Threads.foreach(fun, foreach_setup(1:size(output, 2)); ntasks)
234+
end
229235

230236
return output
231237
end

src/foreach.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
# foreach.jl
12

2-
3+
#=
34
"""
45
foreach_setup(z)
56
Return `Channel` for `foreach` threaded computation from iterable `z`.
@@ -9,3 +10,4 @@ function foreach_setup(z)
910
foreach(i -> put!(ch, i), z)
1011
end
1112
end
13+
=#

0 commit comments

Comments
 (0)