Skip to content

Commit 084c962

Browse files
committed
Restructure src folder
Create `src/lrp` subfolder
1 parent 1386f4e commit 084c962

File tree

8 files changed

+37
-35
lines changed

8 files changed

+37
-35
lines changed

src/ExplainableAI.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,16 @@ include("compat.jl")
2121
include("neuron_selection.jl")
2222
include("analyze_api.jl")
2323
include("types.jl")
24-
include("flux.jl")
24+
include("flux_utils.jl")
2525
include("utils.jl")
26-
include("canonize.jl")
2726
include("input_augmentation.jl")
2827
include("gradient.jl")
29-
include("lrp_checks.jl")
30-
include("lrp_rules.jl")
31-
include("lrp.jl")
28+
include("lrp/canonize.jl")
29+
include("lrp/checks.jl")
30+
include("lrp/rules.jl")
31+
include("lrp/lrp.jl")
3232
include("heatmap.jl")
33-
include("imagenet.jl")
33+
include("preprocessing.jl")
3434
export analyze
3535

3636
# Analyzers
File renamed without changes.
File renamed without changes.

src/lrp_checks.jl renamed to src/lrp/checks.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,30 @@ function check_model(::Val{:LRP}, c::Chain; verbose=true)
107107
throw(ArgumentError("Unknown or unsupported activation functions found in model"))
108108
end
109109
end
110+
111+
# Utils for printing model check summary using PrettyTable.jl
112+
_print_name(layer) = "$layer"
113+
_print_name(layer::Parallel) = "Parallel(...)"
114+
_print_activation(layer) = hasproperty(layer, ) ? "$(layer.σ)" : ""
115+
_print_activation(layer::Parallel) = ""
116+
117+
function _show_check_summary(
118+
c::Chain, layer_names, layer_checks, activation_names, activation_checks
119+
)
120+
hl_pass = Highlighter((data, i, j) -> j in (3, 5) && data[i, j]; foreground=:green)
121+
hl_fail = Highlighter((data, i, j) -> j in (3, 5) && !data[i, j]; foreground=:red)
122+
data = hcat(
123+
collect(1:length(c)),
124+
layer_names,
125+
collect(layer_checks),
126+
activation_names,
127+
collect(activation_checks),
128+
)
129+
pretty_table(
130+
data;
131+
header=["", "Layer", "Layer supported", "Activation", "Act. supported"],
132+
alignment=[:r, :l, :r, :c, :r],
133+
highlighters=(hl_pass, hl_fail),
134+
)
135+
return nothing
136+
end
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/utils.jl

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,14 @@ ones_like(x::AbstractArray) = ones(eltype(x), size(x))
7676
ones_like(x::Number) = oneunit(x)
7777

7878
function keep_positive!(x::AbstractArray{T}) where {T}
79-
x[x .< 0] .= zero(T)
79+
z = zero(T)
80+
x[x .< 0] .= z
8081
return x
8182
end
8283
function keep_negative!(x::AbstractArray{T}) where {T}
83-
x[x .> 0] .= zero(T)
84+
z = zero(T)
85+
x[x .> 0] .= z
8486
return x
8587
end
8688
keep_positive(x) = keep_positive!(deepcopy(x))
8789
keep_negative(x) = keep_negative!(deepcopy(x))
88-
89-
# Utils for printing model check summary using PrettyTable.jl
90-
_print_name(layer) = "$layer"
91-
_print_name(layer::Parallel) = "Parallel(...)"
92-
_print_activation(layer) = hasproperty(layer, ) ? "$(layer.σ)" : ""
93-
_print_activation(layer::Parallel) = ""
94-
95-
function _show_check_summary(
96-
c::Chain, layer_names, layer_checks, activation_names, activation_checks
97-
)
98-
hl_pass = Highlighter((data, i, j) -> j in (3, 5) && data[i, j]; foreground=:green)
99-
hl_fail = Highlighter((data, i, j) -> j in (3, 5) && !data[i, j]; foreground=:red)
100-
data = hcat(
101-
collect(1:length(c)),
102-
layer_names,
103-
collect(layer_checks),
104-
activation_names,
105-
collect(activation_checks),
106-
)
107-
pretty_table(
108-
data;
109-
header=["", "Layer", "Layer supported", "Activation", "Act. supported"],
110-
alignment=[:r, :l, :r, :c, :r],
111-
highlighters=(hl_pass, hl_fail),
112-
)
113-
return nothing
114-
end

0 commit comments

Comments
 (0)