|
305 | 305 | @test v2 ≈ a2 * A * u2 + b2 * w2 |
306 | 306 | @test v1 ≈ a1 * A * u1 + b1 * w1 |
307 | 307 | @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) |
| 308 | + |
| 309 | + ## Do the same with Val((:scale,)) |
| 310 | + |
| 311 | + L = FunctionOperator(f, u, u; p = zero(p), t = zero(t), batch = true, |
| 312 | + accepted_kwargs = Val((:scale,)), scale = 1.0) |
| 313 | + |
| 314 | + @test_throws ArgumentError FunctionOperator( |
| 315 | + f, u, u; p = zero(p), t = zero(t), batch = true, |
| 316 | + accepted_kwargs = Val((:scale,))) |
| 317 | + |
| 318 | + @test size(L) == (N, N) |
| 319 | + |
| 320 | + ans = @. u * p * t * scale |
| 321 | + @test L(u, p, t; scale) ≈ ans |
| 322 | + v = copy(u) |
| 323 | + @test L(v, u, p, t; scale) ≈ ans |
| 324 | + |
| 325 | + # test that output isn't accidentally mutated by passing an internal cache. |
| 326 | + |
| 327 | + A = Diagonal(p * t * scale) |
| 328 | + u1 = rand(N, K) |
| 329 | + u2 = rand(N, K) |
| 330 | + |
| 331 | + v1 = L * u1 |
| 332 | + @test v1 ≈ A * u1 |
| 333 | + v2 = L * u2 |
| 334 | + @test v2 ≈ A * u2 |
| 335 | + @test v1 ≈ A * u1 |
| 336 | + @test v1 + v2 ≈ A * (u1 + u2) |
| 337 | + |
| 338 | + v1 .= 0.0 |
| 339 | + v2 .= 0.0 |
| 340 | + |
| 341 | + mul!(v1, L, u1) |
| 342 | + @test v1 ≈ A * u1 |
| 343 | + mul!(v2, L, u2) |
| 344 | + @test v2 ≈ A * u2 |
| 345 | + @test v1 ≈ A * u1 |
| 346 | + @test v1 + v2 ≈ A * (u1 + u2) |
| 347 | + |
| 348 | + v1 = rand(N, K) |
| 349 | + w1 = copy(v1) |
| 350 | + v2 = rand(N, K) |
| 351 | + w2 = copy(v2) |
| 352 | + a1, a2, b1, b2 = rand(4) |
| 353 | + |
| 354 | + mul!(v1, L, u1, a1, b1) |
| 355 | + @test v1 ≈ a1 * A * u1 + b1 * w1 |
| 356 | + mul!(v2, L, u2, a2, b2) |
| 357 | + @test v2 ≈ a2 * A * u2 + b2 * w2 |
| 358 | + @test v1 ≈ a1 * A * u1 + b1 * w1 |
| 359 | + @test v1 + v2 ≈ (a1 * A * u1 + b1 * w1) + (a2 * A * u2 + b2 * w2) |
308 | 360 | end |
309 | 361 | # |
0 commit comments