@@ -256,55 +256,57 @@ end
256256 f (du, u, p, t; scale = 1.0 ) = mul! (du, Diagonal (p * t * scale), u)
257257 f (u, p, t; scale = 1.0 ) = Diagonal (p * t * scale) * u
258258
259- L = FunctionOperator (f, u, u; p = zero (p), t = zero (t), batch = true ,
260- accepted_kwargs = (:scale ,), scale = 1.0 )
261-
262- @test_throws ArgumentError FunctionOperator (
263- f, u, u; p = zero (p), t = zero (t), batch = true ,
264- accepted_kwargs = (:scale ,))
265-
266- @test size (L) == (N, N)
267-
268- ans = @. u * p * t * scale
269- @test L (u, p, t; scale) ≈ ans
270- v = copy (u)
271- @test L (v, u, p, t; scale) ≈ ans
272-
273- # test that output isn't accidentally mutated by passing an internal cache.
274-
275- A = Diagonal (p * t * scale)
276- u1 = rand (N, K)
277- u2 = rand (N, K)
278-
279- v1 = L * u1
280- @test v1 ≈ A * u1
281- v2 = L * u2
282- @test v2 ≈ A * u2
283- @test v1 ≈ A * u1
284- @test v1 + v2 ≈ A * (u1 + u2)
285-
286- v1 .= 0.0
287- v2 .= 0.0
288-
289- mul! (v1, L, u1)
290- @test v1 ≈ A * u1
291- mul! (v2, L, u2)
292- @test v2 ≈ A * u2
293- @test v1 ≈ A * u1
294- @test v1 + v2 ≈ A * (u1 + u2)
295-
296- v1 = rand (N, K)
297- w1 = copy (v1)
298- v2 = rand (N, K)
299- w2 = copy (v2)
300- a1, a2, b1, b2 = rand (4 )
301-
302- mul! (v1, L, u1, a1, b1)
303- @test v1 ≈ a1 * A * u1 + b1 * w1
304- mul! (v2, L, u2, a2, b2)
305- @test v2 ≈ a2 * A * u2 + b2 * w2
306- @test v1 ≈ a1 * A * u1 + b1 * w1
307- @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
259+ for acc_kw in ((:scale ,), Val ((:scale ,)))
260+ L = FunctionOperator (f, u, u; p = zero (p), t = zero (t), batch = true ,
261+ accepted_kwargs = acc_kw, scale = 1.0 )
262+
263+ @test_throws ArgumentError FunctionOperator (
264+ f, u, u; p = zero (p), t = zero (t), batch = true ,
265+ accepted_kwargs = acc_kw)
266+
267+ @test size (L) == (N, N)
268+
269+ ans = @. u * p * t * scale
270+ @test L (u, p, t; scale) ≈ ans
271+ v = copy (u)
272+ @test L (v, u, p, t; scale) ≈ ans
273+
274+ # test that output isn't accidentally mutated by passing an internal cache.
275+
276+ A = Diagonal (p * t * scale)
277+ u1 = rand (N, K)
278+ u2 = rand (N, K)
279+
280+ v1 = L * u1
281+ @test v1 ≈ A * u1
282+ v2 = L * u2
283+ @test v2 ≈ A * u2
284+ @test v1 ≈ A * u1
285+ @test v1 + v2 ≈ A * (u1 + u2)
286+
287+ v1 .= 0.0
288+ v2 .= 0.0
289+
290+ mul! (v1, L, u1)
291+ @test v1 ≈ A * u1
292+ mul! (v2, L, u2)
293+ @test v2 ≈ A * u2
294+ @test v1 ≈ A * u1
295+ @test v1 + v2 ≈ A * (u1 + u2)
296+
297+ v1 = rand (N, K)
298+ w1 = copy (v1)
299+ v2 = rand (N, K)
300+ w2 = copy (v2)
301+ a1, a2, b1, b2 = rand (4 )
302+
303+ mul! (v1, L, u1, a1, b1)
304+ @test v1 ≈ a1 * A * u1 + b1 * w1
305+ mul! (v2, L, u2, a2, b2)
306+ @test v2 ≈ a2 * A * u2 + b2 * w2
307+ @test v1 ≈ a1 * A * u1 + b1 * w1
308+ @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2)
309+ end
308310
309311 # # Do the same with Val((:scale,))
310312
0 commit comments