@@ -146,15 +146,24 @@ end
146
146
xp = Tracker. param (x)
147
147
Tracker. back! (sum (mapcols (F (wp), xp)))
148
148
@test Tracker. grad (xp) ≈ gradx
149
- @test_broken Tracker. grad (wp) ≈ gradw # zero
149
+ @test Tracker. grad (wp) == 0 .* gradw # bug or a feature?
150
150
151
- grad_mapcols = Zygote. gradient (() -> sum (mapcols (F (w), x)), Zygote. Params ([w,x]))
151
+ # fp = F(wp)
152
+ # wp.grad .= 0; xp.grad .= 0;
153
+ # Tracker.back!(sum(mapcols(fp, xp)))
154
+ # @test Tracker.grad(xp) ≈ gradx
155
+ # @test_broken Tracker.grad(wp) ≈ gradw # zero
156
+
157
+ f = F (w)
158
+ grad_mapcols = Zygote. gradient (() -> sum (mapcols (f, x)), Zygote. Params ([w,x]))
152
159
@test grad_mapcols[x] ≈ gradx
153
- @test_broken grad_mapcols[w] ≈ gradw # grad_mapcols[w] === nothing
160
+ @test grad_mapcols[w] == nothing # bug or a feature?
154
161
155
- grad_slicemap = Zygote. gradient (() -> sum (slicemap (F (w) , x, dims= 1 )), Zygote. Params ([w,x]))
162
+ grad_slicemap = Zygote. gradient (() -> sum (slicemap (f , x, dims= 1 )), Zygote. Params ([w,x]))
156
163
@test grad_slicemap[x] ≈ gradx
157
- @test_broken grad_slicemap[w] ≈ gradw # wrong numbers
164
+ @test grad_slicemap[w] ≈ gradw
158
165
@test gradw ≈ Zygote. gradient (w -> sum (slicemap (F (w), x, dims= 1 )), w)[1 ]
166
+ # Using F(w) with Params() gives wrong answers:
167
+ # https://github.com/FluxML/Zygote.jl/issues/522#issuecomment-605935652
159
168
160
169
end
0 commit comments