|
1 | 1 | using LinearAlgebra: Diagonal, I, diag, isposdef, norm
|
2 |
| -using MatrixAlgebraKit: qr_compact, svd_trunc |
| 2 | +using MatrixAlgebraKit: qr_compact, svd_trunc, truncrank |
3 | 3 | using StableRNGs: StableRNG
|
4 |
| -using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncerr |
| 4 | +using TensorAlgebra.MatrixAlgebra: MatrixAlgebra, truncdegen, truncerr |
5 | 5 | using Test: @test, @testset
|
6 | 6 |
|
7 | 7 | elts = (Float32, Float64, ComplexF32, ComplexF64)
|
@@ -304,4 +304,140 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
|
304 | 304 | @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.001])
|
305 | 305 | @test ũ * s̃ * ṽ ≈ a atol = 0.002 rtol = 0.002
|
306 | 306 | end
|
| 307 | + @testset "Truncate degenerate" begin |
| 308 | + s = Diagonal(real(elt)[2.0, 0.32, 0.3, 0.29, 0.01, 0.01]) |
| 309 | + n = length(diag(s)) |
| 310 | + rng = StableRNG(123) |
| 311 | + u, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
| 312 | + v, _ = qr_compact(randn(rng, elt, n, n); positive=true) |
| 313 | + a = u * s * v |
| 314 | + |
| 315 | + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(n); atol=0.1)) |
| 316 | + @test size(ũ) == (n, n) |
| 317 | + @test size(s̃) == (n, n) |
| 318 | + @test size(ṽ) == (n, n) |
| 319 | + @test ũ * s̃ * ṽ ≈ a |
| 320 | + |
| 321 | + for kwargs in ( |
| 322 | + (; atol=eps(real(elt))), |
| 323 | + (; rtol=(√eps(real(elt)))), |
| 324 | + (; atol=eps(real(elt)), rtol=(√eps(real(elt)))), |
| 325 | + ) |
| 326 | + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(5); kwargs...)) |
| 327 | + @test size(ũ) == (n, 4) |
| 328 | + @test size(s̃) == (4, 4) |
| 329 | + @test size(ṽ) == (4, n) |
| 330 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) |
| 331 | + end |
| 332 | + |
| 333 | + for kwargs in ( |
| 334 | + (; atol=eps(real(elt))), |
| 335 | + (; rtol=eps(real(elt))), |
| 336 | + (; atol=eps(real(elt)), rtol=eps(real(elt))), |
| 337 | + ) |
| 338 | + ũ, s̃, ṽ = svd_trunc(a; trunc=truncdegen(truncrank(4); kwargs...)) |
| 339 | + @test size(ũ) == (n, 4) |
| 340 | + @test size(s̃) == (4, 4) |
| 341 | + @test size(ṽ) == (4, n) |
| 342 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.01, 0.01]) |
| 343 | + end |
| 344 | + |
| 345 | + trunc = truncdegen(truncrank(3); atol=0.01 - √eps(real(elt))) |
| 346 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 347 | + @test size(ũ) == (n, 3) |
| 348 | + @test size(s̃) == (3, 3) |
| 349 | + @test size(ṽ) == (3, n) |
| 350 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) |
| 351 | + |
| 352 | + trunc = truncdegen(truncrank(3); rtol=0.01/0.3 - √eps(real(elt))) |
| 353 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 354 | + @test size(ũ) == (n, 3) |
| 355 | + @test size(s̃) == (3, 3) |
| 356 | + @test size(ṽ) == (3, n) |
| 357 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.29, 0.01, 0.01]) |
| 358 | + |
| 359 | + trunc = truncdegen(truncrank(3); atol=0.01 + √eps(real(elt))) |
| 360 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 361 | + @test size(ũ) == (n, 2) |
| 362 | + @test size(s̃) == (2, 2) |
| 363 | + @test size(ṽ) == (2, n) |
| 364 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) |
| 365 | + |
| 366 | + trunc = truncdegen(truncrank(3); rtol=0.01/0.29 + √eps(real(elt))) |
| 367 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 368 | + @test size(ũ) == (n, 2) |
| 369 | + @test size(s̃) == (2, 2) |
| 370 | + @test size(ṽ) == (2, n) |
| 371 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) |
| 372 | + |
| 373 | + trunc = truncdegen(truncrank(3); atol=0.02 - √eps(real(elt))) |
| 374 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 375 | + @test size(ũ) == (n, 2) |
| 376 | + @test size(s̃) == (2, 2) |
| 377 | + @test size(ṽ) == (2, n) |
| 378 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) |
| 379 | + |
| 380 | + trunc = truncdegen(truncrank(3); rtol=0.02/0.29 - √eps(real(elt))) |
| 381 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 382 | + @test size(ũ) == (n, 2) |
| 383 | + @test size(s̃) == (2, 2) |
| 384 | + @test size(ṽ) == (2, n) |
| 385 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.3, 0.29, 0.01, 0.01]) |
| 386 | + |
| 387 | + trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt))) |
| 388 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 389 | + @test size(ũ) == (n, 1) |
| 390 | + @test size(s̃) == (1, 1) |
| 391 | + @test size(ṽ) == (1, n) |
| 392 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 393 | + |
| 394 | + trunc = truncdegen(truncrank(3); rtol=0.03/0.29 + √eps(real(elt))) |
| 395 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 396 | + @test size(ũ) == (n, 1) |
| 397 | + @test size(s̃) == (1, 1) |
| 398 | + @test size(ṽ) == (1, n) |
| 399 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 400 | + |
| 401 | + trunc = truncdegen(truncrank(3); atol=0.01, rtol=0.03/0.29 + √eps(real(elt))) |
| 402 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 403 | + @test size(ũ) == (n, 1) |
| 404 | + @test size(s̃) == (1, 1) |
| 405 | + @test size(ṽ) == (1, n) |
| 406 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 407 | + |
| 408 | + trunc = truncdegen(truncrank(3); atol=0.03 + √eps(real(elt)), rtol=0.01/0.29) |
| 409 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 410 | + @test size(ũ) == (n, 1) |
| 411 | + @test size(s̃) == (1, 1) |
| 412 | + @test size(ṽ) == (1, n) |
| 413 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 414 | + |
| 415 | + trunc = truncdegen(truncrank(3); atol=(2 - 0.29) - √(eps(real(elt)))) |
| 416 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 417 | + @test size(ũ) == (n, 1) |
| 418 | + @test size(s̃) == (1, 1) |
| 419 | + @test size(ṽ) == (1, n) |
| 420 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 421 | + |
| 422 | + trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 - √(eps(real(elt)))) |
| 423 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 424 | + @test size(ũ) == (n, 1) |
| 425 | + @test size(s̃) == (1, 1) |
| 426 | + @test size(ṽ) == (1, n) |
| 427 | + @test norm(ũ * s̃ * ṽ - a) ≈ norm([0.32, 0.3, 0.29, 0.01, 0.01]) |
| 428 | + |
| 429 | + trunc = truncdegen(truncrank(3); atol=(2 - 0.29) + √(eps(real(elt)))) |
| 430 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 431 | + @test size(ũ) == (n, 0) |
| 432 | + @test size(s̃) == (0, 0) |
| 433 | + @test size(ṽ) == (0, n) |
| 434 | + @test norm(ũ * s̃ * ṽ) ≈ 0 |
| 435 | + |
| 436 | + trunc = truncdegen(truncrank(3); rtol=(2 - 0.29)/0.29 + √(eps(real(elt)))) |
| 437 | + ũ, s̃, ṽ = svd_trunc(a; trunc) |
| 438 | + @test size(ũ) == (n, 0) |
| 439 | + @test size(s̃) == (0, 0) |
| 440 | + @test size(ṽ) == (0, n) |
| 441 | + @test norm(ũ * s̃ * ṽ) ≈ 0 |
| 442 | + end |
307 | 443 | end
|
0 commit comments