@@ -16,46 +16,80 @@ function GradientOp(::Type{T}; shape::NTuple{N,Int}, dims=1:length(shape), kwarg
16
16
return GradientOpImpl (T, shape, dims; kwargs... )
17
17
end
18
18
19
- function GradientOpImpl (T:: Type , shape:: NTuple{N,Int} , dims; kwargs... ) where N
20
- return vcat ([GradientOpImpl (T, shape, dim; kwargs... ) for dim ∈ dims]. .. )
19
+ function GradientOpImpl (T:: Type , shape:: NTuple{N,Int} , dims; S = Vector{T}) where N
20
+ dirs = CartesianIndex{N}[]
21
+ cols = Int64[]
22
+ for dim in dims
23
+ δ = zeros (Int32, N)
24
+ δ[dim] = 1
25
+ δ = NTuple {N} (δ)
26
+ di = CartesianIndex (δ)
27
+ push! (dirs, di)
28
+ push! (cols, div ((shape[dim]- 1 )* prod (shape), shape[dim]))
29
+ end
30
+ dim_ends = accumulate (+ , cols)
31
+
32
+ nrow = sum (cols)
33
+ ncol = prod (shape)
34
+
35
+ tmp = S (undef, ncol)
36
+
37
+ return LinearOperator {T} (nrow, ncol, false , false ,
38
+ (res,x) -> (grad! (res,x,shape,dirs, dims, dim_ends)),
39
+ (res,x) -> (grad_t! (res,x,shape,dirs, dims, dim_ends, tmp)),
40
+ (res,x) -> (grad_t! (res,x,shape,dirs, dims, dim_ends, tmp)),
41
+ S = S)
21
42
end
22
43
23
44
function GradientOpImpl (T:: Type , shape:: NTuple{N,Int} , dim:: Int ; S = Vector{T}) where N
24
45
nrow = div ( (shape[dim]- 1 )* prod (shape), shape[dim] )
25
46
ncol = prod (shape)
47
+ δ = zeros (Int, length (shape))
48
+ δ[dim] = 1
49
+ δ = Tuple (δ)
50
+ dir = CartesianIndex (δ)
26
51
return LinearOperator {T} (nrow, ncol, false , false ,
27
- (res,x) -> (grad! (res,x,shape,dim )),
28
- (res,x) -> (grad_t! (res,x,shape,dim )),
29
- (res,x) -> (grad_t! (res,x,shape,dim )),
52
+ (res,x) -> (grad! (res,x,shape,dir )),
53
+ (res,x) -> (grad_t! (res,x,shape,dir )),
54
+ (res,x) -> (grad_t! (res,x,shape,dir )),
30
55
S = S)
31
56
end
32
57
58
+ function grad! (res:: T , img:: U , shape, dirs, dims, dim_ends) where {T<: AbstractVector , U<: AbstractVector }
59
+ dim_start = 1
60
+
61
+ for (i, dir) in enumerate (dirs)
62
+ grad! (view (res, dim_start: dim_ends[i]), img, shape, dir)
63
+ dim_start = dim_ends[i] + 1
64
+ end
65
+ end
66
+
33
67
# directional gradients
34
- function grad! (res:: T , img:: U , shape, dim ) where {T<: AbstractVector , U<: AbstractVector }
68
+ function grad! (res:: T , img:: U , shape:: NTuple{N,Int64} , di :: CartesianIndex{N} ) where {N, T<: AbstractVector , U<: AbstractVector }
35
69
img_ = reshape (img,shape)
36
70
37
- δ = zeros (Int, length (shape))
38
- δ[dim] = 1
39
- δ = Tuple (δ)
40
- di = CartesianIndex (δ)
41
-
42
- res_ = reshape (res, shape .- δ)
71
+ res_ = reshape (res, shape .- Tuple (di))
43
72
44
73
Threads. @threads for i ∈ CartesianIndices (res_)
45
74
@inbounds res_[i] = img_[i] - img_[i + di]
46
75
end
47
76
end
48
77
78
+ function grad_t! (res:: T , g:: U , shape, dirs, dims, dims_end, tmp) where {T<: AbstractVector , U<: AbstractVector }
79
+ dim_start = 1
49
80
50
- # adjoint of directional gradients
51
- function grad_t! (res:: T , g:: U , shape:: NTuple{N,Int64} , dim:: Int64 ) where {T<: AbstractVector , U<: AbstractVector , N}
52
- δ = zeros (Int, length (shape))
53
- δ[dim] = 1
54
- δ = Tuple (δ)
55
- di = CartesianIndex (δ)
81
+ fill! (res, zero (eltype (res)))
82
+ for (i, dir) in enumerate (dirs)
83
+ grad_t! (tmp, view (g, dim_start: dims_end[i]), shape, dir)
84
+ dim_start = dims_end[i] + 1
85
+ res .= res .+ tmp
86
+ end
87
+ end
56
88
89
+ # adjoint of directional gradients
90
+ function grad_t! (res:: T , g:: U , shape:: NTuple{N,Int64} , di:: CartesianIndex{N} ) where {N, T<: AbstractVector , U<: AbstractVector }
57
91
res_ = reshape (res,shape)
58
- g_ = reshape (g, shape .- δ )
92
+ g_ = reshape (g, shape .- Tuple (di) )
59
93
60
94
res_ .= 0
61
95
Threads. @threads for i ∈ CartesianIndices (g_)
0 commit comments