|
38 | 38 |
|
39 | 39 | using LinearAlgebra: qr, svd |
40 | 40 | # TODO: Define this in `MatrixAlgebra.jl`/`TensorAlgebra.jl`. |
41 | | -function factorize(a::AbstractITensor, args...; maxdim=nothing, cutoff=nothing, kwargs...) |
| 41 | +function factorize( |
| 42 | + a::AbstractITensor, codomain_inds; maxdim=nothing, cutoff=nothing, kwargs... |
| 43 | +) |
| 44 | + # TODO: Perform this intersection in `TensorAlgebra.qr`/`TensorAlgebra.svd`? |
| 45 | + # See https://github.com/ITensor/NamedDimsArrays.jl/issues/22. |
| 46 | + codomain_inds′ = intersect(inds(a), codomain_inds) |
42 | 47 | if isnothing(maxdim) && isnothing(cutoff) |
43 | | - Q, R = qr(a, args...) |
44 | | - return Q, R |
| 48 | + Q, R = qr(a, codomain_inds′) |
| 49 | + return Q, R, (; truncerr=zero(Bool),) |
45 | 50 | else |
46 | | - error("Truncation in `factorize` not implemented yet.") |
47 | | - U, S, V = svd(a, args...; kwargs...) |
48 | | - return U, S * V |
| 51 | + U, S, V = svd(a, codomain_inds′) |
| 52 | + # TODO: This is just a stand-in for truncated SVD |
| 53 | + # that only makes use of `maxdim`, just to get some |
| 54 | + # functionality running in `ITensorMPS.jl`. |
| 55 | + # Define a proper truncated SVD in |
| 56 | + # `MatrixAlgebra.jl`/`TensorAlgebra.jl`. |
| 57 | + r = Base.OneTo(min(maxdim, minimum(Int.(size(S))))) |
| 58 | + u = commonind(U, S) |
| 59 | + v = commonind(V, S) |
| 60 | + us = uniqueinds(U, S) |
| 61 | + vs = uniqueinds(V, S) |
| 62 | + U′ = U[(us .=> :)..., u => r] |
| 63 | + S′ = S[u => r, v => r] |
| 64 | + V′ = V[v => r, (vs .=> :)...] |
| 65 | + return U′, S′ * V′, (; truncerr=zero(Bool),) |
49 | 66 | end |
50 | 67 | end |
51 | 68 |
|
|
0 commit comments