Skip to content

Commit c2bea39

Browse files
authored
Merge pull request #182 from vpuri3/zygote
Test update_coeffs with Zygote
2 parents e5c6b7c + d265a15 commit c2bea39

File tree

5 files changed

+74
-47
lines changed

5 files changed

+74
-47
lines changed

src/basic.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -393,10 +393,12 @@ function Base.resize!(L::AddedOperator, n::Integer)
393393
end
394394

395395
function update_coefficients(L::AddedOperator, u, p, t)
396-
for i in 1:length(L.ops)
397-
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
396+
ops = ()
397+
for op in L.ops
398+
ops = (ops..., update_coefficients(op, u, p, t))
398399
end
399-
L
400+
401+
@set! L.ops = ops
400402
end
401403

402404
getops(L::AddedOperator) = L.ops
@@ -546,10 +548,12 @@ end
546548
LinearAlgebra.opnorm(L::ComposedOperator) = prod(opnorm, L.ops)
547549

548550
function update_coefficients(L::ComposedOperator, u, p, t)
549-
for i in 1:length(L.ops)
550-
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
551+
ops = ()
552+
for op in L.ops
553+
ops = (ops..., update_coefficients(op, u, p, t))
551554
end
552-
L
555+
556+
@set! L.ops = ops
553557
end
554558

555559
getops(L::ComposedOperator) = L.ops

src/interface.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,6 @@ DEFAULT_UPDATE_FUNC(A,u,p,t) = A
2020
update_coefficients(L,u,p,t) = L
2121
update_coefficients!(L,u,p,t) = L
2222

23-
function update_coefficients(L::AbstractSciMLOperator, u, p, t)
24-
@error """Out-of-place update method not implemented for $L.
25-
Please file an issue at https://github.com/SciML/SciMLOperators.jl
26-
with a minimal example."""
27-
end
28-
2923
function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
3024
for op in getops(L)
3125
update_coefficients!(op, u, p, t)

src/scalar.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,10 +181,12 @@ end
181181
Base.conj(L::AddedScalarOperator) = AddedScalarOperator(conj.(L.ops))
182182

183183
function update_coefficients(L::AddedScalarOperator, u, p, t)
184-
for i in 1:length(L.ops)
185-
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
184+
ops = ()
185+
for op in L.ops
186+
ops = (ops..., update_coefficients(op, u, p, t))
186187
end
187-
L
188+
189+
@set! L.ops = ops
188190
end
189191

190192
getops::AddedScalarOperator) = α.ops
@@ -232,10 +234,12 @@ Base.conj(L::ComposedScalarOperator) = ComposedScalarOperator(conj.(L.ops))
232234
Base.:-::AbstractSciMLScalarOperator{T}) where{T} = (-one(T)) * α
233235

234236
function update_coefficients(L::ComposedScalarOperator, u, p, t)
235-
for i in 1:length(L.ops)
236-
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
237+
ops = ()
238+
for op in L.ops
239+
ops = (ops..., update_coefficients(op, u, p, t))
237240
end
238-
L
241+
242+
@set! L.ops = ops
239243
end
240244

241245
getops::ComposedScalarOperator) = α.ops

src/tensor.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ end
8080
Base.conj(L::TensorProductOperator) = TensorProductOperator(conj.(L.ops)...; cache=L.cache)
8181

8282
function update_coefficients(L::TensorProductOperator, u, p, t)
83-
for i in 1:length(L.ops)
84-
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
83+
ops = ()
84+
for op in L.ops
85+
ops = (ops..., update_coefficients(op, u, p, t))
8586
end
86-
L
87+
88+
@set! L.ops = ops
8789
end
8890

8991
getops(L::TensorProductOperator) = L.ops

test/zygote.jl

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,66 +16,89 @@ n = 3
1616
N = n*n
1717
K = 12
1818

19+
t = rand()
1920
u0 = rand(N, K)
2021
ps = rand(N)
2122

22-
M = rand(N,N)
23-
24-
for (op_type, A) in
23+
s = rand()
24+
v = rand(N, K)
25+
M = rand(N, N)
26+
Mi= inv(M)
27+
28+
sca_update_func = (a, u, p, t) -> sum(p) * s
29+
vec_update_func = (b, u, p, t) -> Diagonal(p) * v
30+
mat_update_func = (A, u, p, t) -> Diagonal(p) * M
31+
inv_update_func = (A, u, p, t) -> Mi * inv(Diagonal(p))
32+
tsr_update_func = (A, u, p, t) -> reshape(p, n, n) |> copy
33+
34+
α = ScalarOperator(zero(Float32), update_func = sca_update_func)
35+
L_dia = DiagonalOperator(zeros(N, K); update_func = vec_update_func)
36+
L_mat = MatrixOperator(zeros(N, N); update_func = mat_update_func)
37+
L_mi = MatrixOperator(zeros(N, N); update_func = inv_update_func)
38+
L_aff = AffineOperator(L_mat, L_mat, zeros(N, K); update_func = vec_update_func)
39+
L_sca = α * L_mat
40+
L_inv = InvertibleOperator(L_mat, L_mi)
41+
L_fun = FunctionOperator((u,p,t) -> Diagonal(p) * u, u0, u0;
42+
op_inverse = (u,p,t) -> inv(Diagonal(p)) * u)
43+
44+
Ti = MatrixOperator(zeros(n, n); update_func = tsr_update_func)
45+
To = deepcopy(Ti)
46+
L_tsr = TensorProductOperator(To, Ti)
47+
48+
for (LType, L) in
2549
(
2650
(IdentityOperator, IdentityOperator(N)),
2751
(NullOperator, NullOperator(N)),
28-
(MatrixOperator, MatrixOperator(rand(N,N))),
29-
(AffineOperator, AffineOperator(rand(N,N), rand(N,N), rand(N,K))),
30-
(ScaledOperator, rand() * MatrixOperator(rand(N,N))),
31-
(InvertedOperator, InvertedOperator(rand(N,N) |> MatrixOperator)),
32-
(InvertibleOperator, InvertibleOperator(MatrixOperator(M), MatrixOperator(inv(M)))),
33-
(BatchedDiagonalOperator, DiagonalOperator(rand(N,K))),
34-
(AddedOperator, MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))),
35-
(ComposedOperator, MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))),
36-
(TensorProductOperator, TensorProductOperator(rand(n,n), rand(n,n))),
37-
(FunctionOperator, FunctionOperator((u,p,t)->M*u, u0, u0; op_inverse=(u,p,t)->M\u)),
52+
(MatrixOperator, L_mat),
53+
(AffineOperator, L_aff),
54+
(ScaledOperator, L_sca),
55+
(InvertedOperator, InvertedOperator(L_mat)),
56+
(InvertibleOperator, L_inv),
57+
(BatchedDiagonalOperator, L_dia),
58+
(AddedOperator, L_mat + L_dia),
59+
(ComposedOperator, L_mat * L_dia),
60+
(TensorProductOperator, L_tsr),
61+
(FunctionOperator, L_fun),
3862

3963
## ignore wrappers
40-
#(AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
41-
#(TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
64+
# (AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
65+
# (TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
4266

43-
(ScalarOperator, ScalarOperator(rand())),
44-
(AddedScalarOperator, ScalarOperator(rand()) + ScalarOperator(rand())),
45-
(ComposedScalarOperator, ScalarOperator(rand()) * ScalarOperator(rand())),
67+
(ScalarOperator, α),
68+
(AddedScalarOperator, α + α),
69+
(ComposedScalarOperator, α * α),
4670
)
4771

48-
@assert A isa op_type
72+
@assert L isa LType
4973

5074
loss_mul = function(p)
5175

5276
v = Diagonal(p) * u0
53-
54-
w = A * v
55-
77+
w = L(v, p, t)
5678
l = sum(w)
5779
end
5880

5981
loss_div = function(p)
6082

6183
v = Diagonal(p) * u0
6284

63-
w = A \ v
85+
L = update_coefficients(L, v, p, t)
86+
w = L \ v
6487

6588
l = sum(w)
6689
end
6790

68-
@testset "$op_type" begin
91+
@testset "$LType" begin
6992
l_mul = loss_mul(ps)
7093
g_mul = Zygote.gradient(loss_mul, ps)[1]
7194

72-
if A isa NullOperator
95+
if L isa NullOperator
7396
@test isa(g_mul, Nothing)
7497
else
7598
@test !isa(g_mul, Nothing)
7699
end
77100

78-
if has_ldiv(A)
101+
if has_ldiv(L)
79102
l_div = loss_div(ps)
80103
g_div = Zygote.gradient(loss_div, ps)[1]
81104

0 commit comments

Comments
 (0)