-
Notifications
You must be signed in to change notification settings - Fork 50
implement generic ROF model using Chambolle04 primal-dual method #233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
bfbaf69
a9ed869
3c441fa
a7fc3c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
module Models | ||
|
||
using ImageBase | ||
using ImageBase.ImageCore.MappedArrays: of_eltype | ||
using ImageBase.FiniteDiff | ||
|
||
# Introduced in ColorVectorSpace v0.9.3 | ||
# https://github.com/JuliaGraphics/ColorVectorSpace.jl/pull/172 | ||
using ImageBase.ImageCore.ColorVectorSpace.Future: abs2 | ||
|
||
""" | ||
This submodule provides predefined image-related models and its solvers that can be reused | ||
by many image processing tasks. | ||
|
||
- solve the Rudin Osher Fatemi (ROF) model using the primal-dual method: [`solve_ROF_PD`](@ref) | ||
""" | ||
Models | ||
|
||
export solve_ROF_PD | ||
|
||
|
||
##### implementation details | ||
|
||
""" | ||
solve_ROF_PD(img::AbstractArray, λ; kwargs...) | ||
|
||
Perform Rudin-Osher-Fatemi (ROF) filtering, more commonly known as Total Variation (TV) | ||
denoising or TV regularization. This algorithm is based on the primal-dual method. | ||
|
||
This function applies to generic n-dimensional colorant array and is also CUDA-compatible. | ||
|
||
# Arguments | ||
|
||
- `img`: the input image, usually is a noisy image. | ||
- `λ`: the regularization coefficient. Larger `λ` would produce more smooth image. | ||
johnnychen94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Parameters | ||
|
||
- `num_iters::Int`: The number of iterations before stopping. | ||
|
||
# Examples | ||
|
||
```julia | ||
using ImageFiltering | ||
using ImageFiltering.Models: solve_ROF_PD | ||
using ImageQualityIndexes | ||
using TestImages | ||
|
||
img_ori = float.(testimage("cameraman")) | ||
img_noisy = img_ori .+ 0.1 .* randn(size(img_ori)) | ||
assess_psnr(img_noisy, img_ori) # ~20 dB | ||
|
||
img_smoothed = solve_ROF_PD(img_noisy, 0.015, 50) | ||
assess_psnr(img_smoothed, img_ori) # ~27 dB | ||
|
||
# larger λ produces over-smoothed result | ||
img_smoothed = solve_ROF_PD(img_noisy, 5, 50) | ||
assess_psnr(img_smoothed, img_ori) # ~21 dB | ||
``` | ||
|
||
# Extended help | ||
|
||
Mathematically, this function solves the following ROF model using the primal-dual method: | ||
|
||
```math | ||
\\min_u \\lVert u - g \\rVert^2 + \\lambda\\lvert\\nabla u\\rvert | ||
``` | ||
|
||
# References | ||
|
||
- [1] Chambolle, A. (2004). "An algorithm for total variation minimization and applications". _Journal of Mathematical Imaging and Vision_. 20: 89–97 | ||
- [2] https://en.wikipedia.org/wiki/Total_variation_denoising | ||
""" | ||
function solve_ROF_PD(img::AbstractArray, λ::Real, num_iters::Integer) | ||
# Total Variation regularized image denoising using the primal dual algorithm | ||
# Implement according to reference [1] | ||
|
||
# use Float32 for better GPU performance | ||
τ = Float32(1/4) # see 2nd remark after proof of Theorem 3.1. | ||
johnnychen94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
λ = Float32(λ) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you're going to do this, then best to do in a stub method to reduce latency. I.e., function myfunc(x::Int, args...)
# big method, slow to compile, so we compile it only for `x::Int`
end
myfunc(x::Integer, args...) = myfunc(Int(x), args...) # very fast to compile, we can make lots of instances There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is definitely a good suggestion, just that I don't have a good estimation of how much benefit we get by doing this. Are there any utils from SnoopCompile to watch it more closely other than check the first-time-to-plot with Edit: It seems yes, I found https://timholy.github.io/SnoopCompile.jl/stable/pgdsgui/#pgds There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's also a variant of something in the style guide: https://docs.julialang.org/en/v1/manual/style-guide/#Handle-excess-argument-diversity-in-the-caller |
||
FT = float32(eltype(img)) | ||
|
||
# use the same symbol in the paper | ||
if FT == eltype(img) | ||
g = img | ||
else | ||
g = FT.(img) | ||
end | ||
u = similar(g) | ||
p = fgradient(g) | ||
div_p = similar(g) | ||
∇u = map(similar, p) | ||
∇u_mag = similar(g, eltype(eltype(g))) | ||
|
||
# This iterates Eq. (9) of [1] | ||
# TODO(johnnychen94): set better stop criterion | ||
for _ in 1:num_iters | ||
fdiv!(div_p, p) | ||
# multiply term inside ∇ by -λ. Thm. 3.1 relates this to `u` via Eq. 7. | ||
@. u = g - λ*div_p | ||
fgradient!(∇u, u) | ||
_l2norm_vec!(∇u_mag, ∇u) # |∇(g - λdiv p)| | ||
# Eq. (9): update p | ||
for i in 1:length(p) | ||
@. p[i] = (p[i] - (τ/λ)*∇u[i])/(1 + (τ/λ) * ∇u_mag) | ||
end | ||
end | ||
return u | ||
end | ||
|
||
|
||
function _l2norm_vec!(out, Vs::Tuple) | ||
all(v->axes(out) == axes(v), Vs) || throw(ArgumentError("All axes of input data should be the same.")) | ||
@. out = abs2(Vs[1]) | ||
for v in Vs[2:end] | ||
@. out += abs2(v) | ||
end | ||
@. out = sqrt(out) | ||
return out | ||
end | ||
|
||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[deps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
ImageBase = "c817782e-172a-44cc-b673-b171935fbb9e" | ||
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5" | ||
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19" | ||
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1" | ||
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
using ImageFiltering.Models | ||
|
||
@testset "solve_ROF_PD" begin | ||
# This testset is modified from its CPU version | ||
|
||
@testset "Numerical" begin | ||
# 2D Gray | ||
img = restrict(testimage("cameraman")) | ||
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img)) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.05, 20) | ||
@test ndims(img_smoothed) == 2 | ||
@test eltype(img_smoothed) <: Gray | ||
@test assess_psnr(img_smoothed, img) > 31.67 | ||
@test assess_ssim(img_smoothed, img) > 0.90 | ||
|
||
img_noisy_cu = CuArray(float32.(img_noisy)) | ||
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.05, 20) | ||
@test img_smoothed_cu isa CuArray | ||
@test eltype(eltype(img_smoothed_cu)) == Float32 | ||
@test Array(img_smoothed_cu) ≈ img_smoothed | ||
|
||
# 2D RGB | ||
img = restrict(testimage("lighthouse")) | ||
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.03, 20) | ||
@test ndims(img_smoothed) == 2 | ||
@test eltype(img_smoothed) <: RGB | ||
@test assess_psnr(img_smoothed, img) > 32.15 | ||
@test assess_ssim(img_smoothed, img) > 0.90 | ||
|
||
img_noisy_cu = CuArray(float32.(img_noisy)) | ||
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.03, 20) | ||
@test img_smoothed_cu isa CuArray | ||
@test eltype(eltype(img_smoothed_cu)) == Float32 | ||
@test Array(img_smoothed_cu) ≈ img_smoothed | ||
|
||
# 3D Gray | ||
img = Gray.(restrict(testimage("mri"), (1, 2))) | ||
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img)) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20) | ||
@test ndims(img_smoothed) == 3 | ||
@test eltype(img_smoothed) <: Gray | ||
@test assess_psnr(img_smoothed, img) > 31.78 | ||
@test assess_ssim(img_smoothed, img) > 0.85 | ||
|
||
img_noisy_cu = CuArray(float32.(img_noisy)) | ||
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.02, 20) | ||
@test img_smoothed_cu isa CuArray | ||
@test eltype(eltype(img_smoothed_cu)) == Float32 | ||
@test Array(img_smoothed_cu) ≈ img_smoothed | ||
|
||
# 3D RGB | ||
img = RGB.(restrict(testimage("mri"), (1, 2))) | ||
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20) | ||
@test ndims(img_smoothed) == 3 | ||
@test eltype(img_smoothed) <: RGB | ||
@test assess_psnr(img_smoothed, img) > 31.17 | ||
@test assess_ssim(img_smoothed, img) > 0.79 | ||
|
||
img_noisy_cu = CuArray(float32.(img_noisy)) | ||
img_smoothed_cu = solve_ROF_PD(img_noisy_cu, 0.02, 20) | ||
@test img_smoothed_cu isa CuArray | ||
@test eltype(eltype(img_smoothed_cu)) == Float32 | ||
@test Array(img_smoothed_cu) ≈ img_smoothed | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# This file is maintained in a way to support CUDA-only test via | ||
# `julia --project=test/cuda -e 'include("runtests.jl")'` | ||
using ImageFiltering | ||
using CUDA | ||
using TestImages | ||
using ImageBase | ||
using ImageQualityIndexes | ||
using Test | ||
using Random | ||
|
||
CUDA.allowscalar(false) | ||
|
||
@testset "ImageFiltering" begin | ||
if CUDA.functional() | ||
include("models.jl") | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
using ImageFiltering.Models | ||
|
||
@testset "solve_ROF_PD" begin | ||
# Note: random seed really matters a lot | ||
|
||
@testset "Numerical" begin | ||
# 2D Gray | ||
img = restrict(testimage("cameraman")) | ||
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img)) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.05, 20) | ||
@test ndims(img_smoothed) == 2 | ||
@test eltype(img_smoothed) <: Gray | ||
@test assess_psnr(img_smoothed, img) > 31.67 | ||
@test assess_ssim(img_smoothed, img) > 0.90 | ||
|
||
# 2D RGB | ||
img = restrict(testimage("lighthouse")) | ||
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.03, 20) | ||
@test ndims(img_smoothed) == 2 | ||
@test eltype(img_smoothed) <: RGB | ||
@test assess_psnr(img_smoothed, img) > 32.15 | ||
@test assess_ssim(img_smoothed, img) > 0.90 | ||
|
||
# 3D Gray | ||
img = Gray.(restrict(testimage("mri"), (1, 2))) | ||
img_noisy = img .+ 0.05randn(MersenneTwister(0), size(img)) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20) | ||
@test ndims(img_smoothed) == 3 | ||
@test eltype(img_smoothed) <: Gray | ||
@test assess_psnr(img_smoothed, img) > 31.78 | ||
@test assess_ssim(img_smoothed, img) > 0.85 | ||
|
||
# 3D RGB | ||
img = RGB.(restrict(testimage("mri"), (1, 2))) | ||
img_noisy = img .+ colorview(RGB, ntuple(i->0.05.*randn(MersenneTwister(i), size(img)), 3)...) | ||
img_smoothed = solve_ROF_PD(img_noisy, 0.02, 20) | ||
@test ndims(img_smoothed) == 3 | ||
@test eltype(img_smoothed) <: RGB | ||
@test assess_psnr(img_smoothed, img) > 31.17 | ||
@test assess_ssim(img_smoothed, img) > 0.79 | ||
end | ||
|
||
@testset "FixedPointNumbers" begin | ||
A = rand(N0f8, 20, 20) | ||
@test solve_ROF_PD(A, 0.01, 5) ≈ solve_ROF_PD(float32.(A), 0.01, 5) | ||
end | ||
|
||
@testset "OffsetArray" begin | ||
Ao = OffsetArray(rand(N0f8, 20, 20), -1, -1) | ||
out = solve_ROF_PD(Ao, 0.01, 5) | ||
@test axes(out) == axes(Ao) | ||
end | ||
end |
Uh oh!
There was an error while loading. Please reload this page.