@@ -16,66 +16,89 @@ n = 3
1616N = n* n
1717K = 12
1818
19+ t = rand ()
1920u0 = rand (N, K)
2021ps = rand (N)
2122
22- M = rand (N,N)
23-
24- for (op_type, A) in
23+ s = rand ()
24+ v = rand (N, K)
25+ M = rand (N, N)
26+ Mi= inv (M)
27+
28+ sca_update_func = (a, u, p, t) -> sum (p) * s
29+ vec_update_func = (b, u, p, t) -> Diagonal (p) * v
30+ mat_update_func = (A, u, p, t) -> Diagonal (p) * M
31+ inv_update_func = (A, u, p, t) -> Mi * inv (Diagonal (p))
32+ tsr_update_func = (A, u, p, t) -> reshape (p, n, n) |> copy
33+
34+ α = ScalarOperator (zero (Float32), update_func = sca_update_func)
35+ L_dia = DiagonalOperator (zeros (N, K); update_func = vec_update_func)
36+ L_mat = MatrixOperator (zeros (N, N); update_func = mat_update_func)
37+ L_mi = MatrixOperator (zeros (N, N); update_func = inv_update_func)
38+ L_aff = AffineOperator (L_mat, L_mat, zeros (N, K); update_func = vec_update_func)
39+ L_sca = α * L_mat
40+ L_inv = InvertibleOperator (L_mat, L_mi)
41+ L_fun = FunctionOperator ((u,p,t) -> Diagonal (p) * u, u0, u0;
42+ op_inverse = (u,p,t) -> inv (Diagonal (p)) * u)
43+
44+ Ti = MatrixOperator (zeros (n, n); update_func = tsr_update_func)
45+ To = deepcopy (Ti)
46+ L_tsr = TensorProductOperator (To, Ti)
47+
48+ for (LType, L) in
2549 (
2650 (IdentityOperator, IdentityOperator (N)),
2751 (NullOperator, NullOperator (N)),
28- (MatrixOperator, MatrixOperator ( rand (N,N)) ),
29- (AffineOperator, AffineOperator ( rand (N,N), rand (N,N), rand (N,K)) ),
30- (ScaledOperator, rand () * MatrixOperator ( rand (N,N)) ),
31- (InvertedOperator, InvertedOperator (rand (N,N) |> MatrixOperator )),
32- (InvertibleOperator, InvertibleOperator ( MatrixOperator (M), MatrixOperator ( inv (M))) ),
33- (BatchedDiagonalOperator, DiagonalOperator ( rand (N,K)) ),
34- (AddedOperator, MatrixOperator ( rand (N,N)) + MatrixOperator ( rand (N,N)) ),
35- (ComposedOperator, MatrixOperator ( rand (N,N)) * MatrixOperator ( rand (N,N)) ),
36- (TensorProductOperator, TensorProductOperator ( rand (n,n), rand (n,n)) ),
37- (FunctionOperator, FunctionOperator ((u,p,t) -> M * u, u0, u0; op_inverse = (u,p,t) -> M \ u) ),
52+ (MatrixOperator, L_mat ),
53+ (AffineOperator, L_aff ),
54+ (ScaledOperator, L_sca ),
55+ (InvertedOperator, InvertedOperator (L_mat )),
56+ (InvertibleOperator, L_inv ),
57+ (BatchedDiagonalOperator, L_dia ),
58+ (AddedOperator, L_mat + L_dia ),
59+ (ComposedOperator, L_mat * L_dia ),
60+ (TensorProductOperator, L_tsr ),
61+ (FunctionOperator, L_fun ),
3862
3963 # # ignore wrappers
40- # (AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
41- # (TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
64+ # (AdjointOperator, AdjointOperator(rand(N,N) |> MatrixOperator) |> adjoint),
65+ # (TransposedOperator, TransposedOperator(rand(N,N) |> MatrixOperator) |> transpose),
4266
43- (ScalarOperator, ScalarOperator ( rand ()) ),
44- (AddedScalarOperator, ScalarOperator ( rand ()) + ScalarOperator ( rand ()) ),
45- (ComposedScalarOperator, ScalarOperator ( rand ()) * ScalarOperator ( rand ()) ),
67+ (ScalarOperator, α ),
68+ (AddedScalarOperator, α + α ),
69+ (ComposedScalarOperator, α * α ),
4670 )
4771
48- @assert A isa op_type
72+ @assert L isa LType
4973
5074 loss_mul = function (p)
5175
5276 v = Diagonal (p) * u0
53-
54- w = A * v
55-
77+ w = L (v, p, t)
5678 l = sum (w)
5779 end
5880
5981 loss_div = function (p)
6082
6183 v = Diagonal (p) * u0
6284
63- w = A \ v
85+ L = update_coefficients (L, v, p, t)
86+ w = L \ v
6487
6588 l = sum (w)
6689 end
6790
68- @testset " $op_type " begin
91+ @testset " $LType " begin
6992 l_mul = loss_mul (ps)
7093 g_mul = Zygote. gradient (loss_mul, ps)[1 ]
7194
72- if A isa NullOperator
95+ if L isa NullOperator
7396 @test isa (g_mul, Nothing)
7497 else
7598 @test ! isa (g_mul, Nothing)
7699 end
77100
78- if has_ldiv (A )
101+ if has_ldiv (L )
79102 l_div = loss_div (ps)
80103 g_div = Zygote. gradient (loss_div, ps)[1 ]
81104
0 commit comments