@@ -33,6 +33,14 @@ function _select_truncation(f, ::AbstractTensorMap,
33
33
trunc:: MatrixAlgebraKit.TruncationStrategy )
34
34
return trunc
35
35
end
36
+ function _select_truncation (:: typeof (left_null!), :: AbstractTensorMap , trunc:: NamedTuple )
37
+ return MatrixAlgebraKit. null_truncation_strategy (; trunc... )
38
+ end
39
+
40
+ function MatrixAlgebraKit. diagview (t:: AbstractTensorMap )
41
+ return SectorDict (c => MatrixAlgebraKit. diagview (b) for (c, b) in blocks (t))
42
+ end
43
+
36
44
# function factorisation_scalartype(::typeof(MAK.eig_full!), t::AbstractTensorMap)
37
45
# T = scalartype(t)
38
46
# return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
@@ -103,6 +111,11 @@ function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), t::AbstractT
103
111
return U, S, Vᴴ
104
112
end
105
113
114
+ function MatrixAlgebraKit. initialize_output (:: typeof (svd_trunc!), t:: AbstractTensorMap ,
115
+ alg:: MatrixAlgebraKit.AbstractAlgorithm )
116
+ return MatrixAlgebraKit. initialize_output (svd_compact!, t, alg)
117
+ end
118
+
106
119
# TODO : svd_vals
107
120
108
121
function MatrixAlgebraKit. svd_full! (t:: AbstractTensorMap , (U, S, Vᴴ),
@@ -613,8 +626,69 @@ function MatrixAlgebraKit.left_orth!(t::AbstractTensorMap, VC;
613
626
throw (ArgumentError (" `left_orth!` received unknown value `kind = $kind `" ))
614
627
end
615
628
629
+ # Nullspace
630
+ # ---------
631
+ function MatrixAlgebraKit. check_input (:: typeof (left_null!), t:: AbstractTensorMap , N)
632
+ # scalartype checks
633
+ @check_eltype N t
634
+
635
+ # space checks
636
+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
637
+ V_N = setdiff (fuse (codomain (t)), V_Q)
638
+ space (N) == (codomain (t) ← V_N) ||
639
+ throw (SpaceMismatch (" `left_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`" ))
640
+
641
+ return nothing
642
+ end
643
+
644
+ function MatrixAlgebraKit. initialize_output (:: typeof (left_null!), t:: AbstractTensorMap )
645
+ V_Q = infimum (fuse (codomain (t)), fuse (domain (t)))
646
+ V_N = setdiff (fuse (codomain (t)), V_Q)
647
+ N = similar (t, codomain (t) ← V_N)
648
+ return N
649
+ end
650
+
651
+ # TODO : the following functions shouldn't be necessary if the AbstractArray restrictions are
652
+ # removed
653
+ function MatrixAlgebraKit. left_null (t:: AbstractTensorMap ; kwargs... )
654
+ return left_null! (MatrixAlgebraKit. copy_input (left_null, t); kwargs... )
655
+ end
656
+ function MatrixAlgebraKit. left_null! (t:: AbstractTensorMap ; kwargs... )
657
+ N = MatrixAlgebraKit. initialize_output (left_null!, t)
658
+ return left_null! (t, N; kwargs... )
659
+ end
660
+
661
+ function MatrixAlgebraKit. left_null! (t:: AbstractTensorMap , N;
662
+ trunc= nothing ,
663
+ kind= isnothing (trunc) ? :qr : :svd ,
664
+ alg_qr= (; positive= true ),
665
+ alg_svd= (;))
666
+ MatrixAlgebraKit. check_input (left_null!, t, N)
667
+
668
+ if ! isnothing (trunc) && kind != :svd
669
+ throw (ArgumentError (" truncation not supported for left_null with kind=$kind " ))
670
+ end
671
+
672
+ if kind == :qr
673
+ alg_qr′ = MatrixAlgebraKit. _select_algorithm (qr_null!, t, alg_qr)
674
+ return qr_null! (t, N, alg_qr′)
675
+ elseif kind == :svd && isnothing (trunc)
676
+ alg_svd′ = MatrixAlgebraKit. _select_algorithm (svd_full!, t, alg_svd)
677
+ # TODO : refactor into separate function
678
+ U, _, _ = svd_full! (t, alg_svd′)
679
+ for (c, b) in blocks (N)
680
+ bU = block (U, c)
681
+ m, n = size (bU)
682
+ copy! (b, @view (bU[1 : m, (n + 1 ): m]))
683
+ end
684
+ return N
685
+ elseif kind == :svd
686
+ alg_svd′ = MatrixAlgebraKit. _select_algorithm (svd_full!, t, alg_svd)
687
+ U, S, _ = svd_full! (t, alg_svd′)
688
+ trunc′ = _select_truncation (left_null!, t, trunc)
689
+ return MatrixAlgebraKit. truncate! (left_null!, (U, S), trunc′)
616
690
else
617
- throw (ArgumentError (" `left_orth !` received unknown value `kind = $kind `" ))
691
+ throw (ArgumentError (" `left_null !` received unknown value `kind = $kind `" ))
618
692
end
619
693
end
620
694
@@ -643,3 +717,57 @@ function MatrixAlgebraKit.truncate!(::typeof(svd_trunc!), (U, S, Vᴴ),
643
717
644
718
return Ũ, S̃, Ṽᴴ
645
719
end
720
+
721
+ function MatrixAlgebraKit. truncate! (:: typeof (left_null!),
722
+ (U, S):: Tuple {<: AbstractTensorMap ,
723
+ <: AbstractTensorMap },
724
+ strategy:: MatrixAlgebraKit.TruncationStrategy )
725
+ extended_S = SectorDict (c => vcat (MatrixAlgebraKit. diagview (b),
726
+ zeros (eltype (b), max (0 , size (b, 2 ) - size (b, 1 ))))
727
+ for (c, b) in blocks (S))
728
+ ind = MatrixAlgebraKit. findtruncated (extended_S, strategy)
729
+ V_truncated = spacetype (S)(c => length (axes (b, 1 )[ind[c]]) for (c, b) in blocks (S))
730
+ Ũ = similar (U, codomain (U) ← V_truncated)
731
+ for (c, b) in blocks (Ũ)
732
+ copy! (b, @view (block (U, c)[:, ind[c]]))
733
+ end
734
+ return Ũ
735
+ end
736
+
737
+ const BlockWiseTruncations = Union{MatrixAlgebraKit. TruncationKeepAbove,
738
+ MatrixAlgebraKit. TruncationKeepBelow,
739
+ MatrixAlgebraKit. TruncationKeepFiltered}
740
+
741
+ # TODO : relative tolerances should be global
742
+ function MatrixAlgebraKit. findtruncated (values:: SectorDict , strategy:: BlockWiseTruncations )
743
+ return SectorDict (c => MatrixAlgebraKit. findtruncated (v, strategy) for (c, v) in values)
744
+ end
745
+ function MatrixAlgebraKit. findtruncated (vals:: SectorDict ,
746
+ strategy:: MatrixAlgebraKit.TruncationKeepSorted )
747
+ allpairs = mapreduce (vcat, vals) do (c, v)
748
+ return map (Base. Fix1 (=> , c), axes (v, 1 ))
749
+ end
750
+ by ((c, i)) = strategy. sortby (vals[c][i])
751
+ sort! (allpairs; by, strategy. rev)
752
+
753
+ howmany = zero (Base. promote_op (dim, valtype (values)))
754
+ i = 1
755
+ while i ≤ length (allpairs)
756
+ howmany += dim (first (allpairs[i]))
757
+
758
+ howmany == strategy. howmany && break
759
+
760
+ if howmany > strategy. howmany
761
+ i -= 1
762
+ break
763
+ end
764
+
765
+ i += 1
766
+ end
767
+
768
+ ind = SectorDict (c => allpairs[findall (== (c) ∘ first, view (allpairs, 1 : i))]
769
+ for c in keys (vals))
770
+ filter! (! isempty ∘ last, ind) # TODO : this might not be necessary
771
+
772
+ return ind
773
+ end
0 commit comments