Skip to content

Commit 76f2879

Browse files
Clean tests
1 parent 3c9d1b7 commit 76f2879

File tree

1 file changed

+51
-49
lines changed

1 file changed

+51
-49
lines changed

test/func.jl

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -256,55 +256,57 @@ end
256256
f(du, u, p, t; scale = 1.0) = mul!(du, Diagonal(p * t * scale), u)
257257
f(u, p, t; scale = 1.0) = Diagonal(p * t * scale) * u
258258

259-
L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
260-
accepted_kwargs = (:scale,), scale = 1.0)
261-
262-
@test_throws ArgumentError FunctionOperator(
263-
f, u, u; p = zero(p), t = zero(t), batch = true,
264-
accepted_kwargs = (:scale,))
265-
266-
@test size(L) == (N, N)
267-
268-
ans = @. u * p * t * scale
269-
@test L(u, p, t; scale) ans
270-
v = copy(u)
271-
@test L(v, u, p, t; scale) ans
272-
273-
# test that output isn't accidentally mutated by passing an internal cache.
274-
275-
A = Diagonal(p * t * scale)
276-
u1 = rand(N, K)
277-
u2 = rand(N, K)
278-
279-
v1 = L * u1
280-
@test v1 A * u1
281-
v2 = L * u2
282-
@test v2 A * u2
283-
@test v1 A * u1
284-
@test v1 + v2 A * (u1 + u2)
285-
286-
v1 .= 0.0
287-
v2 .= 0.0
288-
289-
mul!(v1, L, u1)
290-
@test v1 A * u1
291-
mul!(v2, L, u2)
292-
@test v2 A * u2
293-
@test v1 A * u1
294-
@test v1 + v2 A * (u1 + u2)
295-
296-
v1 = rand(N, K)
297-
w1 = copy(v1)
298-
v2 = rand(N, K)
299-
w2 = copy(v2)
300-
a1, a2, b1, b2 = rand(4)
301-
302-
mul!(v1, L, u1, a1, b1)
303-
@test v1 a1 * A * u1 + b1 * w1
304-
mul!(v2, L, u2, a2, b2)
305-
@test v2 a2 * A * u2 + b2 * w2
306-
@test v1 a1 * A * u1 + b1 * w1
307-
@test v1 + v2 (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
259+
for acc_kw in ((:scale,), Val((:scale,)))
260+
L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true,
261+
accepted_kwargs = acc_kw, scale = 1.0)
262+
263+
@test_throws ArgumentError FunctionOperator(
264+
f, u, u; p = zero(p), t = zero(t), batch = true,
265+
accepted_kwargs = acc_kw)
266+
267+
@test size(L) == (N, N)
268+
269+
ans = @. u * p * t * scale
270+
@test L(u, p, t; scale) ans
271+
v = copy(u)
272+
@test L(v, u, p, t; scale) ans
273+
274+
# test that output isn't accidentally mutated by passing an internal cache.
275+
276+
A = Diagonal(p * t * scale)
277+
u1 = rand(N, K)
278+
u2 = rand(N, K)
279+
280+
v1 = L * u1
281+
@test v1 A * u1
282+
v2 = L * u2
283+
@test v2 A * u2
284+
@test v1 A * u1
285+
@test v1 + v2 A * (u1 + u2)
286+
287+
v1 .= 0.0
288+
v2 .= 0.0
289+
290+
mul!(v1, L, u1)
291+
@test v1 A * u1
292+
mul!(v2, L, u2)
293+
@test v2 A * u2
294+
@test v1 A * u1
295+
@test v1 + v2 A * (u1 + u2)
296+
297+
v1 = rand(N, K)
298+
w1 = copy(v1)
299+
v2 = rand(N, K)
300+
w2 = copy(v2)
301+
a1, a2, b1, b2 = rand(4)
302+
303+
mul!(v1, L, u1, a1, b1)
304+
@test v1 a1 * A * u1 + b1 * w1
305+
mul!(v2, L, u2, a2, b2)
306+
@test v2 a2 * A * u2 + b2 * w2
307+
@test v1 a1 * A * u1 + b1 * w1
308+
@test v1 + v2 (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
309+
end
308310

309311
## Do the same with Val((:scale,))
310312

0 commit comments

Comments
 (0)