|
| 1 | +# |
| 2 | +using SciMLOperators, Zygote, LinearAlgebra |
| 3 | +using Random |
| 4 | + |
| 5 | +using SciMLOperators |
| 6 | +using SciMLOperators: AbstractSciMLOperator, |
| 7 | + IdentityOperator, NullOperator, |
| 8 | + AdjointOperator, TransposedOperator, |
| 9 | + InvertedOperator, InvertibleOperator, |
| 10 | + BatchedDiagonalOperator, AddedOperator, ComposedOperator, |
| 11 | + AddedScalarOperator, ComposedScalarOperator, ScaledOperator, |
| 12 | + has_mul, has_ldiv |
| 13 | + |
| 14 | +Random.seed!(0) |
| 15 | +n = 3 |
| 16 | +N = n*n |
| 17 | +K = 12 |
| 18 | + |
| 19 | +u0 = rand(N, K) |
| 20 | +ps = rand(N) |
| 21 | + |
| 22 | +M = rand(N,N) |
| 23 | + |
| 24 | +for (op_type, A) in |
| 25 | + ( |
| 26 | + (IdentityOperator, IdentityOperator{N}()), |
| 27 | + (NullOperator, NullOperator{N}()), |
| 28 | + (MatrixOperator, MatrixOperator(rand(N,N))), |
| 29 | + (AffineOperator, AffineOperator(rand(N,N), rand(N,N), rand(N,K))), |
| 30 | + (ScaledOperator, rand() * MatrixOperator(rand(N,N))), |
| 31 | + (InvertedOperator, InvertedOperator(rand(N,N) |> MatrixOperator)), |
| 32 | + (InvertibleOperator, InvertibleOperator(rand(N,N) |> MatrixOperator)), |
| 33 | + (BatchedDiagonalOperator, DiagonalOperator(rand(N,K))), |
| 34 | + (AddedOperator, MatrixOperator(rand(N,N)) + MatrixOperator(rand(N,N))), |
| 35 | + (ComposedOperator, MatrixOperator(rand(N,N)) * MatrixOperator(rand(N,N))), |
| 36 | + (TensorProductOperator, TensorProductOperator(rand(n,n), rand(n,n))), |
| 37 | + (FunctionOperator, FunctionOperator((u,p,t)->M*u, op_inverse=(u,p,t)->M\u, |
| 38 | + T=Float64, isinplace=false, size=(N,N), |
| 39 | + input_prototype=u0, output_prototype=u0)), |
| 40 | + |
| 41 | + ## ignore wrappers |
| 42 | + #(AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint), |
| 43 | + #(TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose), |
| 44 | + |
| 45 | + (ScalarOperator, ScalarOperator(rand())), |
| 46 | + (AddedScalarOperator, ScalarOperator(rand()) + ScalarOperator(rand())), |
| 47 | + (ComposedScalarOperator, ScalarOperator(rand()) * ScalarOperator(rand())), |
| 48 | + ) |
| 49 | + |
| 50 | + @assert A isa op_type |
| 51 | + |
| 52 | + loss_mul = function(p) |
| 53 | + |
| 54 | + v = Diagonal(p) * u0 |
| 55 | + |
| 56 | + w = A * v |
| 57 | + |
| 58 | + l = sum(w) |
| 59 | + end |
| 60 | + |
| 61 | + loss_div = function(p) |
| 62 | + |
| 63 | + v = Diagonal(p) * u0 |
| 64 | + |
| 65 | + w = A \ v |
| 66 | + |
| 67 | + l = sum(w) |
| 68 | + end |
| 69 | + |
| 70 | + @testset "$op_type" begin |
| 71 | + l_mul = loss_mul(ps) |
| 72 | + g_mul = Zygote.gradient(loss_mul, ps)[1] |
| 73 | + |
| 74 | + if A isa NullOperator |
| 75 | + @test isa(g_mul, Nothing) |
| 76 | + else |
| 77 | + @test !isa(g_mul, Nothing) |
| 78 | + end |
| 79 | + |
| 80 | + if has_ldiv(A) |
| 81 | + l_div = loss_div(ps) |
| 82 | + g_div = Zygote.gradient(loss_div, ps)[1] |
| 83 | + |
| 84 | + @test !isa(g_div, Nothing) |
| 85 | + end |
| 86 | + end |
| 87 | +end |
| 88 | + |
0 commit comments