Skip to content

Commit 5ebafe3

Browse files
authored
Add ImageNet preprocessing utilities (#80)
1 parent 6d300f9 commit 5ebafe3

File tree

6 files changed

+50
-2
lines changed

6 files changed

+50
-2
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
88
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
99
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1010
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
11+
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1314
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"

src/ExplainableAI.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Tullio
1010

1111
# Heatmapping:
1212
using ImageCore
13+
using ImageTransformations: imresize
1314
using ColorSchemes
1415

1516
# Model checks:
@@ -29,7 +30,7 @@ include("lrp_checks.jl")
2930
include("lrp_rules.jl")
3031
include("lrp.jl")
3132
include("heatmap.jl")
32-
33+
include("imagenet.jl")
3334
export analyze
3435

3536
# Analyzers
@@ -53,5 +54,5 @@ export heatmap
5354

5455
# utils
5556
export strip_softmax, flatten_model, flatten_chain, canonize
56-
57+
export preprocess_imagenet
5758
end # module

src/imagenet.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Image preprocessing for ImageNet models.
2+
# Code adapted from Metalhead 0.5.3's deprecated utils.jl
3+
# TODO: Remove once matching functionality is in either Metalhead.jl or MLDatasets.jl
4+
5+
# Coefficients taken from PyTorch's ImageNet normalization code
6+
const PYTORCH_MEAN = [0.485f0, 0.456f0, 0.406f0]
7+
const PYTORCH_STD = [0.229f0, 0.224f0, 0.225f0]
8+
const IMGSIZE = (224, 224)
9+
10+
# Take rectangle of pixels of shape `outsize` at the center of image `im`
11+
adjust(i::Integer) = ifelse(iszero(i % 2), 1, 0)
12+
function center_crop_view(im::AbstractMatrix, outsize=IMGSIZE)
13+
im = imresize(im; ratio=maximum(outsize .// size(im)))
14+
h2, w2 = div.(outsize, 2) # half height, half width of view
15+
h_adjust, w_adjust = adjust.(outsize)
16+
return @view im[
17+
((div(end, 2) - h2):(div(end, 2) + h2 - h_adjust)) .+ 1,
18+
((div(end, 2) - w2):(div(end, 2) + w2 - w_adjust)) .+ 1,
19+
]
20+
end
21+
22+
"""
23+
preprocess_imagenet(img)
24+
25+
Preprocess an image for use with Metalhead.jl's ImageNet models using PyTorch weights.
26+
Uses PyTorch's normalization constants.
27+
"""
28+
function preprocess_imagenet(im::AbstractMatrix{<:AbstractRGB}, T=Float32::Type{<:Real})
29+
im = center_crop_view(im)
30+
im = (channelview(im) .- PYTORCH_MEAN) ./ PYTORCH_STD
31+
return convert.(T, PermutedDimsArray(im, (3, 2, 1))) # Convert Image.jl's CHW to WHC
32+
end
589 KB
Binary file not shown.

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ using ReferenceTests
2020
@info "Running tests on heatmaps..."
2121
include("test_heatmaps.jl")
2222
end
23+
@testset "ImageNet preprocessing" begin
24+
@info "Running tests on ImageNet preprocessing..."
25+
include("test_imagenet.jl")
26+
end
2327
@testset "Canonize" begin
2428
@info "Running tests on model canonization..."
2529
include("test_canonize.jl")

test/test_imagenet.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using ExplainableAI
2+
using ImageCore
3+
4+
A = RGB{Float32}[
5+
RGB{Float32}(0.44557732f0, 0.25328094f0, 0.53720146f0) RGB{Float32}(0.99433f0, 0.37066674f0, 0.8781263f0) RGB{Float32}(0.59815156f0, 0.21008879f0, 0.07259983f0)
6+
RGB{Float32}(0.6966612f0, 0.27341717f0, 0.40360665f0) RGB{Float32}(0.12119287f0, 0.63196003f0, 0.32167268f0) RGB{Float32}(0.31825548f0, 0.7599565f0, 0.20566207f0)
7+
]
8+
x = preprocess_imagenet(A)
9+
@test_reference "references/utils/preprocess_imagnet.jld2" Dict("x" => x) by =
10+
(r, a) -> isapprox(r["x"], a["x"]; rtol=0.05)

0 commit comments

Comments
 (0)