Skip to content

Commit d1f83d7

Browse files
committed
add tests
1 parent 9fa5308 commit d1f83d7

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

test/func.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,54 @@ N = 8
99
K = 12
1010
NK = N * K
1111

12+
@testset "(Unbatched) FunctionOperator ND array" begin
13+
N1, N2, N3 = 3, 4, 5
14+
M1, M2, M3 = 4, 5, 6
15+
16+
p = nothing
17+
t = 0.0
18+
α = rand()
19+
β = rand()
20+
21+
for (sz_in, sz_out) in (
22+
((N1, N2, N3), (N1, N2, N3)), # equal size
23+
((N1, N2, N3), (M1, M2, M3)), # different size
24+
)
25+
N = prod(sz_in)
26+
M = prod(sz_out)
27+
28+
A = rand(M, N)
29+
u = rand(sz_in... )
30+
v = rand(sz_out...)
31+
32+
_mul(A, u) = reshape(A * vec(u), sz_out)
33+
f(u, p, t) = _mul(A, u)
34+
f(du, u, p, t) = (mul!( vec(du), A, vec(u)); du)
35+
36+
kw = (;) # FunctionOp kwargs
37+
38+
if sz_in == sz_out
39+
F = lu(A)
40+
_div(A, v) = reshape(A \ vec(v), sz_in)
41+
fi(u, p, t) = _div(A, u)
42+
fi(du, u, p, t) = (ldiv!(vec(du), F, vec(u)); du)
43+
44+
kw = (; op_inverse = fi)
45+
end
46+
47+
L = FunctionOperator(f, u, v; kw...)
48+
L = cache_operator(L, u)
49+
50+
@test _mul(A, u) L(u, p, t) L * u mul!(zero(v), L, u)
51+
@test α * _mul(A, u)+ β * v mul!(copy(v), L, u, α, β)
52+
53+
if sz_in == sz_out
54+
@test _div(A, v) L \ v ldiv!(zero(u), L, v) ldiv!(L, copy(v))
55+
end
56+
end
57+
58+
end
59+
1260
@testset "(Unbatched) FunctionOperator" begin
1361
u = rand(N, K)
1462
p = nothing

0 commit comments

Comments
 (0)