|
1 | 1 | using ExplainabilityMethods
|
2 |
| -using ExplainabilityMethods: SumReducer |
| 2 | +using ExplainabilityMethods: MaxAbsNormalizer, RangeNormalizer |
| 3 | +using ExplainabilityMethods: MaxAbsReducer, SumReducer |
3 | 4 |
|
4 | 5 | # Defaults assume nchannels=3
|
5 | 6 | A = rand(Float32, 3, 4, 5)
|
6 | 7 | B = reshape(A, 3, 1, 4, 5, 1)
|
7 |
| -@test heatmap(A) ≈ heatmap(B) |
| 8 | +@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) ≈ |
| 9 | + heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) |
| 10 | +@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer()) ≈ |
| 11 | + heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer()) |
8 | 12 |
|
9 | 13 | A = rand(Float32, 3, 4, 5)
|
10 | 14 | B = reshape(A, 3, 1, 4, 1, 5)
|
11 |
| -@test heatmap(A; reducer=SumReducer()) ≈ heatmap(B; reducer=SumReducer()) |
| 15 | +@test heatmap(A; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) ≈ |
| 16 | + heatmap(B; normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) |
| 17 | +@test heatmap(A; normalizer=RangeNormalizer(), reducer=SumReducer()) ≈ |
| 18 | + heatmap(B; normalizer=RangeNormalizer(), reducer=SumReducer()) |
12 | 19 |
|
13 | 20 | # Test with single channel
|
14 | 21 | A = rand(Float32, 4, 5)
|
15 | 22 | B = reshape(A, 4, 1, 1, 5, 1)
|
16 |
| -@test heatmap(A; nchannels=1) ≈ heatmap(B; nchannels=1) |
| 23 | +@test heatmap(A; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) ≈ |
| 24 | + heatmap(B; nchannels=1, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) |
| 25 | +@test heatmap(A; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer()) ≈ |
| 26 | + heatmap(B; nchannels=1, normalizer=RangeNormalizer(), reducer=SumReducer()) |
17 | 27 |
|
18 | 28 | # Test with 2 color channels
|
19 | 29 | A = rand(Float32, 2, 3, 4)
|
20 | 30 | B = reshape(A, 2, 1, 1, 3, 4)
|
21 |
| -@test heatmap(A; nchannels=2) ≈ heatmap(B; nchannels=2) |
| 31 | +@test heatmap(A; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) ≈ |
| 32 | + heatmap(B; nchannels=2, normalizer=MaxAbsNormalizer(), reducer=MaxAbsReducer()) |
| 33 | +@test heatmap(A; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer()) ≈ |
| 34 | + heatmap(B; nchannels=2, normalizer=RangeNormalizer(), reducer=SumReducer()) |
0 commit comments