Skip to content

Commit 7cd3adf

Browse files
committed
Fixed Zygote tests
1 parent 81adab6 commit 7cd3adf

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

src/basic.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,19 @@ end
170170
# Out-of-place: v is action vector, u is update vector
171171
function (nn::NullOperator)(v::AbstractVecOrMat, u, p, t; kwargs...)
172172
@assert size(v, 1) == nn.len
173-
update_coefficients(nn, u, p, t; kwargs...)
174173
zero(v)
175174
end
176175

177176
# In-place: w is destination, v is action vector, u is update vector
178177
function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t; kwargs...)
179178
@assert size(v, 1) == nn.len
180-
update_coefficients!(nn, u, p, t; kwargs...)
181179
lmul!(false, w)
182180
w
183181
end
184182

185183
# In-place with scaling: w = α*(nn*v) + β*w
186184
function (nn::NullOperator)(w::AbstractVecOrMat, v::AbstractVecOrMat, u, p, t, α, β; kwargs...)
187185
@assert size(v, 1) == nn.len
188-
update_coefficients!(nn, u, p, t; kwargs...)
189186
lmul!(β, w)
190187
w
191188
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ using SafeTestsets
1717
include("total.jl")
1818
end
1919
@time @safetestset "Zygote.jl" begin
20-
# include("zygote.jl")
20+
include("zygote.jl")
2121
end
2222
end

test/zygote.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,26 @@ for (LType, L) in ((IdentityOperator, IdentityOperator(N)),
6666
(AddedScalarOperator, α + α),
6767
(ComposedScalarOperator, α * α))
6868
@assert L isa LType
69+
70+
# Cache the operator for efficient application
71+
L_cached = cache_operator(L, u0)
6972

73+
# Updated loss function using the new interface:
74+
# v is the action vector, u0 is the update vector
7075
loss_mul = function (p)
7176
v = Diagonal(p) * u0
72-
w = L(v, p, t)
77+
# Use new interface: L(v, u, p, t)
78+
w = L_cached(v, u0, p, t)
7379
l = sum(w)
7480
end
7581

7682
loss_div = function (p)
7783
v = Diagonal(p) * u0
78-
79-
L = update_coefficients(L, v, p, t)
80-
w = L \ v
81-
84+
85+
# Update coefficients first, then apply inverse
86+
L_updated = update_coefficients(L_cached, u0, p, t)
87+
w = L_updated \ v
88+
8289
l = sum(w)
8390
end
8491

@@ -99,4 +106,4 @@ for (LType, L) in ((IdentityOperator, IdentityOperator(N)),
99106
@test !isa(g_div, Nothing)
100107
end
101108
end
102-
end
109+
end

0 commit comments

Comments
 (0)