Skip to content

Commit 41874be

Browse files
authored
Add XAIBase package extension (#4)
Continues work started in Julia-XAI/XAIBase.jl#16 and Julia-XAI/VisionHeatmaps.jl#7 by moving `heatmap` methods on `Explanation` type to TextHeatmaps.jl via package extensions on XAIBase.
1 parent 5b982d9 commit 41874be

15 files changed

+187
-45
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
matrix:
2121
version:
2222
- '1.6'
23+
- '1'
2324
- 'nightly'
2425
os:
2526
- ubuntu-latest
@@ -33,7 +34,9 @@ jobs:
3334
arch: ${{ matrix.arch }}
3435
- uses: julia-actions/cache@v1
3536
- uses: julia-actions/julia-buildpkg@v1
37+
continue-on-error: ${{ matrix.version == 'nightly' }}
3638
- uses: julia-actions/julia-runtest@v1
39+
continue-on-error: ${{ matrix.version == 'nightly' }}
3740
- uses: julia-actions/julia-processcoverage@v1
3841
- uses: codecov/codecov-action@v3
3942
with:

Project.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
name = "TextHeatmaps"
22
uuid = "2dd6718a-6083-4824-b9f7-90e4a57f72d2"
33
authors = ["Adrian Hill <[email protected]>"]
4-
version = "1.1.0"
4+
version = "1.2.0-DEV"
55

66
[deps]
77
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
88
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
99
Crayons = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
1010
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
11+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
12+
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
13+
14+
[weakdeps]
15+
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
16+
17+
[extensions]
18+
TextHeatmapsXAIBaseExt = "XAIBase"
1119

1220
[compat]
1321
ColorSchemes = "3"
1422
Colors = "0.12"
1523
Crayons = "4"
1624
FixedPointNumbers = "0.8"
25+
Requires = "1"
26+
XAIBase = "3"
1727
julia = "1.6"

ext/TextHeatmapsXAIBaseExt.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
module TextHeatmapsXAIBaseExt
2+
3+
using TextHeatmaps, XAIBase
4+
5+
struct HeatmapConfig
6+
colorscheme::Symbol
7+
reduce::Symbol
8+
rangescale::Symbol
9+
end
10+
11+
const DEFAULT_COLORSCHEME = :seismic
12+
const DEFAULT_REDUCE = :sum
13+
const DEFAULT_RANGESCALE = :centered
14+
const DEFAULT_HEATMAP_PRESET = HeatmapConfig(
15+
DEFAULT_COLORSCHEME, DEFAULT_REDUCE, DEFAULT_RANGESCALE
16+
)
17+
18+
const HEATMAP_PRESETS = Dict{Symbol,HeatmapConfig}(
19+
:attribution => HeatmapConfig(:seismic, :sum, :centered),
20+
:sensitivity => HeatmapConfig(:grays, :norm, :extrema),
21+
:cam => HeatmapConfig(:jet, :sum, :extrema),
22+
)
23+
24+
# Select HeatmapConfig preset based on heatmapping style in Explanation
25+
function get_heatmapping_config(heatmap::Symbol)
26+
return get(HEATMAP_PRESETS, heatmap, DEFAULT_HEATMAP_PRESET)
27+
end
28+
29+
# Override HeatmapConfig preset with keyword arguments
30+
function get_heatmapping_config(expl::Explanation; kwargs...)
31+
c = get_heatmapping_config(expl.heatmap)
32+
33+
colorscheme = get(kwargs, :colorscheme, c.colorscheme)
34+
rangescale = get(kwargs, :rangescale, c.rangescale)
35+
reduce = get(kwargs, :reduce, c.reduce)
36+
return HeatmapConfig(colorscheme, reduce, rangescale)
37+
end
38+
39+
"""
40+
heatmap(explanation, text)
41+
42+
Visualize [`Explanation`](@ref) from XAIBase as text heatmap.
43+
Text should be a vector containing vectors of strings, one for each input in the batched explanation.
44+
45+
## Keyword arguments
46+
- `colorscheme::Union{ColorScheme,Symbol}`: color scheme from ColorSchemes.jl.
47+
Defaults to `:$DEFAULT_COLORSCHEME`.
48+
- `rangescale::Symbol`: selects how the color channel reduced heatmap is normalized
49+
before the color scheme is applied. Can be either `:extrema` or `:centered`.
50+
Defaults to `:$DEFAULT_RANGESCALE` for use with the default color scheme `:$DEFAULT_COLORSCHEME`.
51+
"""
52+
function TextHeatmaps.heatmap(
53+
expl::Explanation, texts::AbstractVector{<:AbstractVector{<:AbstractString}}; kwargs...
54+
)
55+
ndims(expl.val) != 2 && throw(
56+
ArgumentError(
57+
"To heatmap text, `explanation.val` must be 2D array of shape `(input_length, batchsize)`. Got array of shape $(size(x)) instead.",
58+
),
59+
)
60+
batchsize = size(expl.val, 2)
61+
textsize = length(texts)
62+
batchsize != textsize && throw(
63+
ArgumentError("Batchsize $batchsize doesn't match number of texts $textsize.")
64+
)
65+
66+
c = get_heatmapping_config(expl; kwargs...)
67+
return [
68+
TextHeatmaps.heatmap(v, t; colorscheme=c.colorscheme, rangescale=c.rangescale) for
69+
(v, t) in zip(eachcol(expl.val), texts)
70+
]
71+
end
72+
73+
function TextHeatmaps.heatmap(
74+
expl::Explanation, text::AbstractVector{<:AbstractString}; kwargs...
75+
)
76+
return heatmap(expl, [text]; kwargs...)
77+
end
78+
79+
end # module

src/TextHeatmaps.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,19 @@ using Crayons: Crayon
44
using FixedPointNumbers: N0f8
55
using Colors: Colorant, RGB, hex
66
using ColorSchemes: ColorScheme, colorschemes, get, seismic
7+
using Requires: @require
78

89
include("heatmap.jl")
910

11+
if !isdefined(Base, :get_extension)
12+
using Requires
13+
function __init__()
14+
@require XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7" include(
15+
"../ext/TextHeatmapsXAIBaseExt.jl"
16+
)
17+
end
18+
end
19+
1020
export heatmap
1121

1222
end # module

src/heatmap.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ struct TextHeatmap{
3939
end
4040

4141
function TextHeatmap(
42-
val, words; colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME, rangescale=DEFAULT_RANGESCALE
42+
val,
43+
words;
44+
colorscheme::Union{ColorScheme,Symbol}=DEFAULT_COLORSCHEME,
45+
rangescale=DEFAULT_RANGESCALE,
4346
)
4447
if size(val) != size(words)
4548
throw(ArgumentError("Sizes of values and words don't match"))

test/Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
44
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
55
FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
6+
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
67
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9+
XAIBase = "9b48221d-a747-4c1b-9860-46a1d8ba24a7"
810

9-
[compat]
10-
Aqua = "0.7"
11-
ColorSchemes = "3"
12-
Colors = "0.12"
13-
FixedPointNumbers = "0.8"
14-
ReferenceTests = "0.10"

test/references/Gradient1.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Test Text Heatmap

test/references/Gradient2.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
another dummy input

test/references/LRP1.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Test Text Heatmap

test/references/LRP1_extrema.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Test Text Heatmap

0 commit comments

Comments
 (0)