Skip to content

Commit f9928a1

Browse files
committed
Rework left_null
1 parent 4bac6c8 commit f9928a1

File tree

2 files changed

+144
-26
lines changed

2 files changed

+144
-26
lines changed

src/tensors/factorizations.jl

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -383,36 +383,26 @@ function leftnull!(t::TensorMap{<:RealOrComplexFloat};
383383
eps(real(float(one(scalartype(t))))) * iszero(atol))
384384
InnerProductStyle(t) === EuclideanInnerProduct() ||
385385
throw_invalid_innerproduct(:leftnull!)
386-
if !iszero(rtol)
387-
atol = max(atol, rtol * norm(t))
388-
end
389-
I = sectortype(t)
390-
dims = SectorDict{I,Int}()
391386

392-
# compute QR factorization for each block
393-
V = codomain(t)
394-
if !isempty(blocksectors(V))
395-
generator = Base.Iterators.map(blocksectors(V)) do c
396-
Nc = MatrixAlgebra.leftnull!(block(t, c), alg, atol)
397-
dims[c] = size(Nc, 2)
398-
return c => Nc
387+
if alg == SVD() || alg == SDD()
388+
kind = :svd
389+
alg_svd = BlockAlgorithm(alg == SVD() ? MatrixAlgebraKit.LAPACK_QRIteration() :
390+
MatrixAlgebraKit.LAPACK_DivideAndConquer(),
391+
default_blockscheduler(t))
392+
trunc = if iszero(atol) && iszero(rtol)
393+
nothing
394+
else
395+
(; atol, rtol)
399396
end
400-
Ndata = SectorDict(generator)
397+
return left_null!(t; kind, alg_svd, trunc)
401398
end
402399

403-
# construct new space
404-
S = spacetype(t)
405-
W = S(dims)
400+
(iszero(atol) && iszero(rtol)) ||
401+
throw(ArgumentError("`leftnull!` with nonzero atol or rtol requires SVD or SDD algorithm"))
406402

407-
# construct output tensor
408-
T = float(scalartype(t))
409-
N = similar(t, T, V W)
410-
if !isempty(blocksectors(V))
411-
for (c, Nc) in Ndata
412-
copy!(block(N, c), Nc)
413-
end
414-
end
415-
return N
403+
kind = :qr
404+
alg_qr = (; positive=alg == QRpos())
405+
return left_null!(t; kind, alg_qr)
416406
end
417407

418408
function rightorth!(t::TensorMap{<:RealOrComplexFloat};

src/tensors/matrixalgebrakit.jl

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ function _select_truncation(f, ::AbstractTensorMap,
3333
trunc::MatrixAlgebraKit.TruncationStrategy)
3434
return trunc
3535
end
36+
function _select_truncation(::typeof(left_null!), ::AbstractTensorMap, trunc::NamedTuple)
37+
return MatrixAlgebraKit.null_truncation_strategy(; trunc...)
38+
end
39+
40+
function MatrixAlgebraKit.diagview(t::AbstractTensorMap)
41+
return SectorDict(c => MatrixAlgebraKit.diagview(b) for (c, b) in blocks(t))
42+
end
43+
3644
# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
3745
# T = scalartype(t)
3846
# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
@@ -103,6 +111,11 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractT
103111
return U, S, Vᴴ
104112
end
105113

114+
function MatrixAlgebraKit.initialize_output(::typeof(svd_trunc!), t::AbstractTensorMap,
115+
alg::MatrixAlgebraKit.AbstractAlgorithm)
116+
return MatrixAlgebraKit.initialize_output(svd_compact!, t, alg)
117+
end
118+
106119
# TODO: svd_vals
107120

108121
function MatrixAlgebraKit.svd_full!(t::AbstractTensorMap, (U, S, Vᴴ),
@@ -613,8 +626,69 @@ function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC;
613626
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
614627
end
615628

629+
# Nullspace
630+
# ---------
631+
function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap, N)
632+
# scalartype checks
633+
@check_eltype N t
634+
635+
# space checks
636+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
637+
V_N = setdiff(fuse(codomain(t)), V_Q)
638+
space(N) == (codomain(t) V_N) ||
639+
throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
640+
641+
return nothing
642+
end
643+
644+
function MatrixAlgebraKit.initialize_output(::typeof(left_null!), t::AbstractTensorMap)
645+
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
646+
V_N = setdiff(fuse(codomain(t)), V_Q)
647+
N = similar(t, codomain(t) V_N)
648+
return N
649+
end
650+
651+
# TODO: the following functions shouldn't be necessary if the AbstractArray restrictions are
652+
# removed
653+
function MatrixAlgebraKit.left_null(t::AbstractTensorMap; kwargs...)
654+
return left_null!(MatrixAlgebraKit.copy_input(left_null, t); kwargs...)
655+
end
656+
function MatrixAlgebraKit.left_null!(t::AbstractTensorMap; kwargs...)
657+
N = MatrixAlgebraKit.initialize_output(left_null!, t)
658+
return left_null!(t, N; kwargs...)
659+
end
660+
661+
function MatrixAlgebraKit.left_null!(t::AbstractTensorMap, N;
662+
trunc=nothing,
663+
kind=isnothing(trunc) ? :qr : :svd,
664+
alg_qr=(; positive=true),
665+
alg_svd=(;))
666+
MatrixAlgebraKit.check_input(left_null!, t, N)
667+
668+
if !isnothing(trunc) && kind != :svd
669+
throw(ArgumentError("truncation not supported for left_null with kind=$kind"))
670+
end
671+
672+
if kind == :qr
673+
alg_qr′ = MatrixAlgebraKit._select_algorithm(qr_null!, t, alg_qr)
674+
return qr_null!(t, N, alg_qr′)
675+
elseif kind == :svd && isnothing(trunc)
676+
alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd)
677+
# TODO: refactor into separate function
678+
U, _, _ = svd_full!(t, alg_svd′)
679+
for (c, b) in blocks(N)
680+
bU = block(U, c)
681+
m, n = size(bU)
682+
copy!(b, @view(bU[1:m, (n + 1):m]))
683+
end
684+
return N
685+
elseif kind == :svd
686+
alg_svd′ = MatrixAlgebraKit._select_algorithm(svd_full!, t, alg_svd)
687+
U, S, _ = svd_full!(t, alg_svd′)
688+
trunc′ = _select_truncation(left_null!, t, trunc)
689+
return MatrixAlgebraKit.truncate!(left_null!, (U, S), trunc′)
616690
else
617-
throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`"))
691+
throw(ArgumentError("`left_null!` received unknown value `kind = $kind`"))
618692
end
619693
end
620694

@@ -643,3 +717,57 @@ function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ),
643717

644718
return Ũ, S̃, Ṽᴴ
645719
end
720+
721+
function MatrixAlgebraKit.truncate!(::typeof(left_null!),
722+
(U, S)::Tuple{<:AbstractTensorMap,
723+
<:AbstractTensorMap},
724+
strategy::MatrixAlgebraKit.TruncationStrategy)
725+
extended_S = SectorDict(c => vcat(MatrixAlgebraKit.diagview(b),
726+
zeros(eltype(b), max(0, size(b, 2) - size(b, 1))))
727+
for (c, b) in blocks(S))
728+
ind = MatrixAlgebraKit.findtruncated(extended_S, strategy)
729+
V_truncated = spacetype(S)(c => length(axes(b, 1)[ind[c]]) for (c, b) in blocks(S))
730+
= similar(U, codomain(U) V_truncated)
731+
for (c, b) in blocks(Ũ)
732+
copy!(b, @view(block(U, c)[:, ind[c]]))
733+
end
734+
return
735+
end
736+
737+
const BlockWiseTruncations = Union{MatrixAlgebraKit.TruncationKeepAbove,
738+
MatrixAlgebraKit.TruncationKeepBelow,
739+
MatrixAlgebraKit.TruncationKeepFiltered}
740+
741+
# TODO: relative tolerances should be global
742+
function MatrixAlgebraKit.findtruncated(values::SectorDict, strategy::BlockWiseTruncations)
743+
return SectorDict(c => MatrixAlgebraKit.findtruncated(v, strategy) for (c, v) in values)
744+
end
745+
function MatrixAlgebraKit.findtruncated(vals::SectorDict,
746+
strategy::MatrixAlgebraKit.TruncationKeepSorted)
747+
allpairs = mapreduce(vcat, vals) do (c, v)
748+
return map(Base.Fix1(=>, c), axes(v, 1))
749+
end
750+
by((c, i)) = strategy.sortby(vals[c][i])
751+
sort!(allpairs; by, strategy.rev)
752+
753+
howmany = zero(Base.promote_op(dim, valtype(values)))
754+
i = 1
755+
while i length(allpairs)
756+
howmany += dim(first(allpairs[i]))
757+
758+
howmany == strategy.howmany && break
759+
760+
if howmany > strategy.howmany
761+
i -= 1
762+
break
763+
end
764+
765+
i += 1
766+
end
767+
768+
ind = SectorDict(c => allpairs[findall(==(c) first, view(allpairs, 1:i))]
769+
for c in keys(vals))
770+
filter!(!isempty last, ind) # TODO: this might not be necessary
771+
772+
return ind
773+
end

0 commit comments

Comments
 (0)