@@ -12,19 +12,16 @@ using ArgCheck
1212using FillArrays
1313using StaticArrays
1414
15+ const CuOrFillArray{T,N} = Union{CuArray{T,N},FillArrays. AbstractFill{T,N}}
1516
16- const CuOrFillArray{T, N} = Union{CuArray{T, N}, FillArrays. AbstractFill{T, N}}
17-
18-
19- const CuOrFillVector{T} = CuOrFillArray{T, 1 }
20-
17+ const CuOrFillVector{T} = CuOrFillArray{T,1 }
2118
2219function raster_pullback_kernel! (
2320 :: Type{T} ,
2421 ds_dout,
2522 points:: AbstractVector{<:StaticVector{N_in}} ,
26- rotations:: AbstractVector{<:StaticMatrix{N_out, N_in, TR}} ,
27- translations:: AbstractVector{<:StaticVector{N_out, TT}} ,
23+ rotations:: AbstractVector{<:StaticMatrix{N_out,N_in,TR}} ,
24+ translations:: AbstractVector{<:StaticVector{N_out,TT}} ,
2825 out_weights,
2926 point_weights,
3027 shifts,
@@ -35,8 +32,7 @@ function raster_pullback_kernel!(
3532 ds_dtranslation,
3633 ds_dout_weight,
3734 ds_dpoint_weight,
38-
39- ) where {T, TR, TT, N_in, N_out}
35+ ) where {T,TR,TT,N_in,N_out}
4036 n_voxel = blockDim (). z
4137 points_per_workgroup = blockDim (). x
4238 batchsize_per_workgroup = blockDim (). y
@@ -74,24 +70,27 @@ function raster_pullback_kernel!(
7470 origin = (- @SVector ones (TT, N_out)) - translation
7571
7672 coord_reference_voxel, deltas = DiffPointRasterisation. reference_coordinate_and_deltas (
77- point,
78- rotation,
79- origin,
80- scale,
73+ point, rotation, origin, scale
74+ )
75+ voxel_idx = CartesianIndex (
76+ CartesianIndex ( Tuple (coord_reference_voxel)) + CartesianIndex (shift), batch_idx
8177 )
82- voxel_idx = CartesianIndex (CartesianIndex (Tuple (coord_reference_voxel)) + CartesianIndex (shift), batch_idx)
83-
8478
8579 ds_dweight_local = zero (T)
8680 if voxel_idx in CartesianIndices (ds_dout)
8781 @inbounds ds_dweight_local = DiffPointRasterisation. voxel_weight (
88- deltas,
89- shift,
90- ds_dout[voxel_idx],
82+ deltas, shift, ds_dout[voxel_idx]
9183 )
9284
9385 factor = ds_dout[voxel_idx] * out_weight * point_weight
94- ds_dcoord_part = SVector (factor .* ntuple (n -> DiffPointRasterisation. interpolation_weight (n, N_out, deltas, shift), Val (N_out)))
86+ ds_dcoord_part = SVector (
87+ factor .* ntuple (
88+ n -> DiffPointRasterisation. interpolation_weight (
89+ n, N_out, deltas, shift
90+ ),
91+ Val (N_out),
92+ ),
93+ )
9594 @inbounds ds_dpoint_rot_shared[:, s, b] .= ds_dcoord_part .* scale
9695 else
9796 @inbounds ds_dpoint_rot_shared[:, s, b] .= zero (T)
@@ -136,7 +135,7 @@ function raster_pullback_kernel!(
136135 j = 1
137136 while j <= N_in
138137 val = coef * point[j]
139- @inbounds CUDA. @atomic ds_drotation[dim, j, batch_idx] += val
138+ @inbounds CUDA. @atomic ds_drotation[dim, j, batch_idx] += val
140139 j += 1
141140 end
142141 end
@@ -161,7 +160,7 @@ function raster_pullback_kernel!(
161160 sync_threads ()
162161 idx = 2 * stride * (b - 1 ) + 1
163162 if idx <= batchsize_per_workgroup
164- dim = s
163+ dim = s
165164 while dim <= N_in
166165 other_val_p = if idx + stride <= batchsize_per_workgroup
167166 ds_dpoint_shared[dim, idx + stride]
@@ -181,7 +180,7 @@ function raster_pullback_kernel!(
181180 sync_threads ()
182181 idx = 2 * stride * (thread - 1 ) + 1
183182 if idx <= n_threads_per_workgroup
184- other_val_w = if idx + stride <= n_threads_per_workgroup
183+ other_val_w = if idx + stride <= n_threads_per_workgroup
185184 ds_dpoint_weight_shared[idx + stride]
186185 else
187186 zero (T)
@@ -207,74 +206,103 @@ function raster_pullback_kernel!(
207206 @inbounds CUDA. @atomic ds_dpoint_weight[point_idx] += val_w
208207 end
209208
210- nothing
209+ return nothing
211210end
212211
213212# single image
214- raster_pullback! (
215- ds_dout:: CuArray{<:Number, N_out} ,
216- points:: AbstractVector{<:StaticVector{N_in, <:Number}} ,
217- rotation:: StaticMatrix{N_out, N_in, <:Number} ,
218- translation:: StaticVector{N_out, <:Number} ,
213+ function raster_pullback! (
214+ ds_dout:: CuArray{<:Number,N_out} ,
215+ points:: AbstractVector{<:StaticVector{N_in,<:Number}} ,
216+ rotation:: StaticMatrix{N_out,N_in,<:Number} ,
217+ translation:: StaticVector{N_out,<:Number} ,
219218 background:: Number ,
220219 out_weight:: Number ,
221220 point_weight:: CuOrFillVector{<:Number} ,
222221 ds_dpoints:: AbstractMatrix{<:Number} ,
223222 ds_dpoint_weight:: AbstractVector{<:Number} ;
224- kwargs...
225- ) where {N_in, N_out} = error (" Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays" )
223+ kwargs... ,
224+ ) where {N_in,N_out}
225+ return error (
226+ " Not implemented: raster_pullback! for single image not implemented on GPU. Consider using CPU arrays" ,
227+ )
228+ end
226229
227230# batch of images
228231function DiffPointRasterisation. raster_pullback! (
229- ds_dout:: CuArray{<:Number, N_out_p1} ,
230- points:: CuVector{<:StaticVector{N_in, <:Number}} ,
231- rotation:: CuVector{<:StaticMatrix{N_out, N_in, <:Number}} ,
232- translation:: CuVector{<:StaticVector{N_out, <:Number}} ,
232+ ds_dout:: CuArray{<:Number,N_out_p1} ,
233+ points:: CuVector{<:StaticVector{N_in,<:Number}} ,
234+ rotation:: CuVector{<:StaticMatrix{N_out,N_in,<:Number}} ,
235+ translation:: CuVector{<:StaticVector{N_out,<:Number}} ,
233236 background:: CuOrFillVector{<:Number} ,
234237 out_weight:: CuOrFillVector{<:Number} ,
235238 point_weight:: CuOrFillVector{<:Number} ,
236239 ds_dpoints:: CuMatrix{TP} ,
237- ds_drotation:: CuArray{TR, 3} ,
240+ ds_drotation:: CuArray{TR,3} ,
238241 ds_dtranslation:: CuMatrix{TT} ,
239242 ds_dbackground:: CuVector{<:Number} ,
240243 ds_dout_weight:: CuVector{OW} ,
241244 ds_dpoint_weight:: CuVector{PW} ,
242- ) where {N_in, N_out, N_out_p1, TP<: Number , TR<: Number , TT<: Number , OW<: Number , PW<: Number }
245+ ) where {N_in,N_out,N_out_p1,TP<: Number ,TR<: Number ,TT<: Number ,OW<: Number ,PW<: Number }
243246 T = promote_type (eltype (ds_dout), TP, TR, TT, OW, PW)
244247 batch_axis = axes (ds_dout, N_out_p1)
245248 @argcheck N_out == N_out_p1 - 1
246- @argcheck batch_axis == axes (rotation, 1 ) == axes (translation, 1 ) == axes (background, 1 ) == axes (out_weight, 1 )
247- @argcheck batch_axis == axes (ds_drotation, 3 ) == axes (ds_dtranslation, 2 ) == axes (ds_dbackground, 1 ) == axes (ds_dout_weight, 1 )
249+ @argcheck batch_axis ==
250+ axes (rotation, 1 ) ==
251+ axes (translation, 1 ) ==
252+ axes (background, 1 ) ==
253+ axes (out_weight, 1 )
254+ @argcheck batch_axis ==
255+ axes (ds_drotation, 3 ) ==
256+ axes (ds_dtranslation, 2 ) ==
257+ axes (ds_dbackground, 1 ) ==
258+ axes (ds_dout_weight, 1 )
248259 @argcheck N_out == N_out_p1 - 1
249260
250261 n_points = length (points)
251262 @argcheck length (ds_dpoint_weight) == n_points
252263 batch_size = length (batch_axis)
253264
254- ds_dbackground = vec (sum! (reshape (ds_dbackground, ntuple (_ -> 1 , Val (N_out))... , batch_size), ds_dout))
265+ ds_dbackground = vec (
266+ sum! (reshape (ds_dbackground, ntuple (_ -> 1 , Val (N_out))... , batch_size), ds_dout)
267+ )
255268
256- scale = SVector {N_out, T} (size (ds_dout)[1 : end - 1 ]) / T (2 )
257- shifts= DiffPointRasterisation. voxel_shifts (Val (N_out))
269+ scale = SVector {N_out,T} (size (ds_dout)[1 : ( end - 1 ) ]) / T (2 )
270+ shifts = DiffPointRasterisation. voxel_shifts (Val (N_out))
258271
259272 ds_dpoints = fill! (ds_dpoints, zero (TP))
260273 ds_drotation = fill! (ds_drotation, zero (TR))
261274 ds_dtranslation = fill! (ds_dtranslation, zero (TT))
262275 ds_dout_weight = fill! (ds_dout_weight, zero (OW))
263276 ds_dpoint_weight = fill! (ds_dpoint_weight, zero (PW))
264277
265- args = (T, ds_dout, points, rotation, translation, out_weight, point_weight, shifts, scale, ds_dpoints, ds_drotation, ds_dtranslation, ds_dout_weight, ds_dpoint_weight)
278+ args = (
279+ T,
280+ ds_dout,
281+ points,
282+ rotation,
283+ translation,
284+ out_weight,
285+ point_weight,
286+ shifts,
287+ scale,
288+ ds_dpoints,
289+ ds_drotation,
290+ ds_dtranslation,
291+ ds_dout_weight,
292+ ds_dpoint_weight,
293+ )
266294
267295 ndrange = (n_points, batch_size, 2 ^ N_out)
268296
269297 workgroup_size (threads) = (1 , min (threads ÷ (2 ^ N_out), batch_size), 2 ^ N_out)
270298
271299 function shmem (threads)
272- _, bs_p_wg, n_voxel = workgroup_size (threads)
273- ((N_out + 1 ) * n_voxel + N_in) * bs_p_wg * sizeof (T)
300+ _, bs_p_wg, n_voxel = workgroup_size (threads)
301+ return ((N_out + 1 ) * n_voxel + N_in) * bs_p_wg * sizeof (T)
274302 # ((N_out + 1) * threads + N_in * bs_p_wg) * sizeof(T)
275303 end
276304
277- let kernel = @cuda launch= false raster_pullback_kernel! (args... )
305+ let kernel = @cuda launch = false raster_pullback_kernel! (args... )
278306 config = CUDA. launch_configuration (kernel. fun; shmem)
279307 workgroup_sz = workgroup_size (config. threads)
280308 blocks = cld .(ndrange, workgroup_sz)
@@ -292,9 +320,16 @@ function DiffPointRasterisation.raster_pullback!(
292320 )
293321end
294322
323+ function DiffPointRasterisation. default_ds_dpoints_batched (
324+ points:: CuVector{<:AbstractVector{TP}} , N_in, batch_size
325+ ) where {TP<: Number }
326+ return similar (points, TP, (N_in, length (points)))
327+ end
295328
296- DiffPointRasterisation. default_ds_dpoints_batched (points:: CuVector{<:AbstractVector{TP}} , N_in, batch_size) where {TP<: Number } = similar (points, TP, (N_in, length (points)))
297-
298- DiffPointRasterisation. default_ds_dpoint_weight_batched (points:: CuVector{<:AbstractVector{<:Number}} , T, batch_size) = similar (points, T)
329+ function DiffPointRasterisation. default_ds_dpoint_weight_batched (
330+ points:: CuVector{<:AbstractVector{<:Number}} , T, batch_size
331+ )
332+ return similar (points, T)
333+ end
299334
300335end # module
0 commit comments