Skip to content

Commit 5851335

Browse files
Merge pull request #94 from vpuri3/zygote
Zygote gradient tests
2 parents c25e53b + 3003f30 commit 5851335

File tree

6 files changed

+122
-4
lines changed

6 files changed

+122
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLOperators"
22
uuid = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
33
authors = ["xtalax <[email protected]>"]
4-
version = "0.1.9"
4+
version = "0.1.10"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/basic.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ struct AddedOperator{T,
293293
ops::O
294294

295295
function AddedOperator(ops)
296+
@assert !isempty(ops)
296297
T = promote_type(eltype.(ops)...)
297298
new{T,typeof(ops)}(ops)
298299
end
@@ -414,6 +415,7 @@ struct ComposedOperator{T,O,C} <: AbstractSciMLOperator{T}
414415
isset::Bool
415416

416417
function ComposedOperator(ops, cache, isset::Bool)
418+
@assert !isempty(ops)
417419
for i in reverse(2:length(ops))
418420
opcurr = ops[i]
419421
opnext = ops[i-1]
@@ -518,8 +520,27 @@ for fact in (
518520
end
519521

520522
# operator application
521-
Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u)
522-
Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u)
523+
# https://github.com/SciML/SciMLOperators.jl/pull/94
524+
#Base.:*(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u)
525+
#Base.:\(L::ComposedOperator, u::AbstractVecOrMat) = foldl((acc, op) -> op \ acc, L.ops; init=u)
526+
527+
function Base.:\(L::ComposedOperator, u::AbstractVecOrMat)
528+
v = u
529+
for op in L.ops
530+
v = op \ v
531+
end
532+
533+
v
534+
end
535+
536+
function Base.:*(L::ComposedOperator, u::AbstractVecOrMat)
537+
v = u
538+
for op in reverse(L.ops)
539+
v = op * v
540+
end
541+
542+
v
543+
end
523544

524545
function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
525546
vec = zero(u)

src/scalar.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct AddedScalarOperator{T,O} <: AbstractSciMLScalarOperator{T}
115115
ops::O
116116

117117
function AddedScalarOperator(ops::NTuple{N,AbstractSciMLScalarOperator}) where{N}
118+
@assert !isempty(ops)
118119
T = promote_type(eltype.(ops)...)
119120
new{T,typeof(ops)}(ops)
120121
end
@@ -141,12 +142,14 @@ for op in (
141142
end
142143

143144
function Base.convert(::Type{Number}, α::AddedScalarOperator{T}) where{T}
144-
sum(op -> convert(Number, op), α.ops; init=zero(T))
145+
sum(op -> convert(Number, op), α.ops)
145146
end
146147

147148
Base.conj(L::AddedScalarOperator) = AddedScalarOperator(conj.(L.ops))
148149

149150
getops::AddedScalarOperator) = α.ops
151+
has_ldiv::AddedScalarOperator) = !iszero(convert(Number, α))
152+
has_ldiv!::AddedScalarOperator) = has_ldiv(α)
150153

151154
"""
152155
Lazy composition of Scalar Operators
@@ -155,6 +158,7 @@ struct ComposedScalarOperator{T,O} <: AbstractSciMLScalarOperator{T}
155158
ops::O
156159

157160
function ComposedScalarOperator(ops::NTuple{N,AbstractSciMLScalarOperator}) where{N}
161+
@assert !isempty(ops)
158162
T = promote_type(eltype.(ops)...)
159163
new{T,typeof(ops)}(ops)
160164
end
@@ -188,4 +192,6 @@ Base.conj(L::ComposedScalarOperator) = ComposedScalarOperator(conj.(L.ops))
188192
Base.:-::AbstractSciMLScalarOperator{T}) where{T} = (-one(T)) * α
189193

190194
getops::ComposedScalarOperator) = α.ops
195+
has_ldiv::ComposedScalarOperator) = all(has_ldiv, α.ops)
196+
has_ldiv!::ComposedScalarOperator) = all(has_ldiv!, α.ops)
191197
#

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
55
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
66
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
77
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,7 @@ if GROUP == "All" || GROUP == "OperatorInterface"
1111
@time @safetestset "Matrix Operators" begin include("matrix.jl") end
1212
@time @safetestset "Function Operator" begin include("func.jl") end
1313
@time @safetestset "Full tests" begin include("total.jl") end
14+
15+
@time @safetestset "Zygote.jl" begin include("zygote.jl") end
1416
end
1517
end

test/zygote.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#
2+
using SciMLOperators, Zygote, LinearAlgebra
3+
using Random
4+
5+
using SciMLOperators
6+
using SciMLOperators: AbstractSciMLOperator,
7+
IdentityOperator, NullOperator,
8+
AdjointOperator, TransposedOperator,
9+
InvertedOperator, InvertibleOperator,
10+
BatchedDiagonalOperator, AddedOperator, ComposedOperator,
11+
AddedScalarOperator, ComposedScalarOperator, ScaledOperator,
12+
has_mul, has_ldiv
13+
14+
Random.seed!(0)
15+
n = 3
16+
N = n*n
17+
K = 12
18+
19+
u0 = rand(N, K)
20+
ps = rand(N)
21+
22+
M = rand(N,N)
23+
24+
for (op_type, A) in
25+
(
26+
(IdentityOperator, IdentityOperator{N}()),
27+
(NullOperator, NullOperator{N}()),
28+
(MatrixOperator, MatrixOperator(rand(N,N))),
29+
(AffineOperator, AffineOperator(rand(N,N), rand(N,N), rand(N,K))),
30+
(ScaledOperator, rand() * MatrixOperator(rand(N,N))),
31+
(InvertedOperator, InvertedOperator(rand(N,N) |> MatrixOperator)),
32+
(InvertibleOperator, InvertibleOperator(rand(N,N) |> MatrixOperator)),
33+
(BatchedDiagonalOperator, DiagonalOperator(rand(N,K))),
34+
(AddedOperator, MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))),
35+
(ComposedOperator, MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))),
36+
(TensorProductOperator, TensorProductOperator(rand(n,n), rand(n,n))),
37+
(FunctionOperator, FunctionOperator((u,p,t)->M*u, op_inverse=(u,p,t)->M\u,
38+
T=Float64, isinplace=false, size=(N,N),
39+
input_prototype=u0, output_prototype=u0)),
40+
41+
## ignore wrappers
42+
#(AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
43+
#(TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
44+
45+
(ScalarOperator, ScalarOperator(rand())),
46+
(AddedScalarOperator, ScalarOperator(rand()) + ScalarOperator(rand())),
47+
(ComposedScalarOperator, ScalarOperator(rand()) * ScalarOperator(rand())),
48+
)
49+
50+
@assert A isa op_type
51+
52+
loss_mul = function(p)
53+
54+
v = Diagonal(p) * u0
55+
56+
w = A * v
57+
58+
l = sum(w)
59+
end
60+
61+
loss_div = function(p)
62+
63+
v = Diagonal(p) * u0
64+
65+
w = A \ v
66+
67+
l = sum(w)
68+
end
69+
70+
@testset "$op_type" begin
71+
l_mul = loss_mul(ps)
72+
g_mul = Zygote.gradient(loss_mul, ps)[1]
73+
74+
if A isa NullOperator
75+
@test isa(g_mul, Nothing)
76+
else
77+
@test !isa(g_mul, Nothing)
78+
end
79+
80+
if has_ldiv(A)
81+
l_div = loss_div(ps)
82+
g_div = Zygote.gradient(loss_div, ps)[1]
83+
84+
@test !isa(g_div, Nothing)
85+
end
86+
end
87+
end
88+

0 commit comments

Comments
 (0)