Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 08de33b

Browse files
committed
add gradient test for transformations
fix
1 parent d49f1ac commit 08de33b

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

test/Transform/chebyshev_transform.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@
88
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
99
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
1010
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
11+
12+
@test_broken g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
13+
@test_broken size(g[1]) == (30, 40, 50, ch, batch)
1114
end

test/Transform/fourier_transform.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,7 @@
88
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
99
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
1010
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
11+
12+
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱)
13+
@test size(g[1]) == (30, 40, 50, ch, batch)
1114
end

0 commit comments

Comments
 (0)