Skip to content

Commit 79737f5

Browse files
Merge pull request #5 from JuliaImageRecon/Fix-adjoint-gradient-bug
GradOp Cleanup + docstring
2 parents 7d57dad + 7c4f2f2 commit 79737f5

File tree

1 file changed

+16
-22
lines changed

1 file changed

+16
-22
lines changed

src/GradientOp.jl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,26 @@
1-
function LinearOperatorCollection.GradientOp(::Type{T};
2-
shape::Tuple, dims=nothing) where T <: Number
3-
if dims == nothing
4-
return GradientOpImpl(T, shape)
5-
else
6-
return GradientOpImpl(T, shape, dims)
7-
end
8-
end
9-
101
"""
11-
GradOp(T::Type, shape::NTuple{N,Int64})
2+
GradientOp(T::Type; shape::Tuple, dims=1:length(shape))
123
13-
Nd gradient operator for an array of size `shape`
14-
"""
15-
function GradientOpImpl(T::Type, shape)
16-
shape = typeof(shape) <: Number ? (shape,) : shape # convert Number to Tuple
17-
return vcat([GradientOpImpl(T, shape, i) for i eachindex(shape)]...)
18-
end
4+
directional gradient operator along the dimensions `dims` for an array of size `shape`.
195
20-
"""
21-
GradOp(T::Type, shape::NTuple{N,Int64}, dims)
6+
# Required Argument
7+
* `T` - type of elements, .e.g. `Float64` for `ComplexF32`
8+
9+
# Required Keyword argument
10+
* `shape::NTuple{N,Int}` - shape of the array (e.g., image)
2211
23-
directional gradient operator along the dimensions `dims`
24-
for an array of size `shape`
12+
# Optional Keyword argument
13+
* `dims` - dimension(s) along which the gradient is applied; default is `1:length(shape)`
2514
"""
26-
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dims) where N
15+
function GradientOp(::Type{T}; shape::NTuple{N,Int}, dims=1:length(shape)) where {T <: Number, N}
16+
return GradientOpImpl(T, shape, dims)
17+
end
18+
19+
function GradientOpImpl(T::Type, shape::NTuple{N,Int}, dims) where N
2720
return vcat([GradientOpImpl(T, shape, dim) for dim dims]...)
2821
end
29-
function GradientOpImpl(T::Type, shape::NTuple{N,Int64}, dim::Integer) where N
22+
23+
function GradientOpImpl(T::Type, shape::NTuple{N,Int}, dim::Int) where N
3024
nrow = div( (shape[dim]-1)*prod(shape), shape[dim] )
3125
ncol = prod(shape)
3226
return LinearOperator{T}(nrow, ncol, false, false,

0 commit comments

Comments
 (0)