Skip to content

Commit 71904f9

Browse files
committed
Change Base.setdiff for ominus
1 parent bcbbd33 commit 71904f9

File tree

6 files changed

+40
-33
lines changed

6 files changed

+40
-33
lines changed

docs/src/lib/spaces.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ dual
9090
conj
9191
flip
9292
93+
9394
zero(::ElementarySpace)
9495
oneunit
9596
supremum

src/spaces/cartesianspace.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,17 @@ sectortype(::Type{CartesianSpace}) = Trivial
4949

5050
Base.oneunit(::Type{CartesianSpace}) = CartesianSpace(1)
5151
Base.zero(::Type{CartesianSpace}) = CartesianSpace(0)
52+
5253
(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d + V₂.d)
54+
function (V::CartesianSpace, W::CartesianSpace)
55+
V W || throw(ArgumentError("$(W) is not a subspace of $(V)"))
56+
return CartesianSpace(dim(V) - dim(W))
57+
end
58+
5359
fuse(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(V₁.d * V₂.d)
5460
flip(V::CartesianSpace) = V
5561

5662
infimum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(min(V₁.d, V₂.d))
5763
supremum(V₁::CartesianSpace, V₂::CartesianSpace) = CartesianSpace(max(V₁.d, V₂.d))
5864

59-
function Base.setdiff(V::CartesianSpace, W::CartesianSpace)
60-
V W || throw(ArgumentError("$(W) is not a subspace of $(V)"))
61-
return CartesianSpace(dim(V) - dim(W))
62-
end
63-
6465
Base.show(io::IO, V::CartesianSpace) = print(io, "ℝ^$(V.d)")

src/spaces/complexspace.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,18 @@ Base.conj(V::ComplexSpace) = ComplexSpace(dim(V), !isdual(V))
5050

5151
Base.oneunit(::Type{ComplexSpace}) = ComplexSpace(1)
5252
Base.zero(::Type{ComplexSpace}) = ComplexSpace(0)
53+
5354
function (V₁::ComplexSpace, V₂::ComplexSpace)
5455
return isdual(V₁) == isdual(V₂) ?
5556
ComplexSpace(dim(V₁) + dim(V₂), isdual(V₁)) :
5657
throw(SpaceMismatch("Direct sum of a vector space and its dual does not exist"))
5758
end
59+
function (V::ComplexSpace, W::ComplexSpace)
60+
(V W && isdual(V) == isdual(W)) ||
61+
throw(ArgumentError("$(W) is not a subspace of $(V)"))
62+
return ComplexSpace(dim(V) - dim(W), isdual(V))
63+
end
64+
5865
fuse(V₁::ComplexSpace, V₂::ComplexSpace) = ComplexSpace(V₁.d * V₂.d)
5966
flip(V::ComplexSpace) = dual(V)
6067

@@ -69,10 +76,4 @@ function supremum(V₁::ComplexSpace, V₂::ComplexSpace)
6976
throw(SpaceMismatch("Supremum of space and dual space does not exist"))
7077
end
7178

72-
function Base.setdiff(V::ComplexSpace, W::ComplexSpace)
73-
(V W && isdual(V) == isdual(W)) ||
74-
throw(ArgumentError("$(W) is not a subspace of $(V)"))
75-
return ComplexSpace(dim(V) - dim(W), isdual(V))
76-
end
77-
7879
Base.show(io::IO, V::ComplexSpace) = print(io, isdual(V) ? "(ℂ^$(V.d))'" : "ℂ^$(V.d)")

src/spaces/gradedspace.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ function ⊕(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector}
149149
return typeof(V₁)(dims; dual=dual1)
150150
end
151151

152+
function (V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector}
153+
V W && isdual(V) == isdual(W) ||
154+
throw(SpaceMismatch("$(W) is not a subspace of $(V)"))
155+
return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V))
156+
end
157+
152158
function flip(V::GradedSpace{I}) where {I<:Sector}
153159
if isdual(V)
154160
typeof(V)(c => dim(V, c) for c in sectors(V))
@@ -187,12 +193,6 @@ function supremum(V₁::GradedSpace{I}, V₂::GradedSpace{I}) where {I<:Sector}
187193
end
188194
end
189195

190-
function Base.setdiff(V::GradedSpace{I}, W::GradedSpace{I}) where {I<:Sector}
191-
V W && isdual(V) == isdual(W) ||
192-
throw(SpaceMismatch("$(W) is not a subspace of $(V)"))
193-
return typeof(V)(c => dim(V, c) - dim(W, c) for c in sectors(V))
194-
end
195-
196196
function Base.show(io::IO, V::GradedSpace{I}) where {I<:Sector}
197197
print(io, type_repr(typeof(V)), "(")
198198
seperator = ""

src/spaces/vectorspaces.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ function ⊕ end
149149
(V::Vararg{VectorSpace}) = foldl(, V)
150150
const oplus =
151151

152+
"""
153+
⊖(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace
154+
ominus(V::ElementarySpace, W::ElementarySpace) -> X::ElementarySpace
155+
156+
Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace`
157+
such that `V = W ⊕ X`.
158+
"""
159+
(V₁::S, V₂::S) where {S<:ElementarySpace}
160+
(V₁::VectorSpace, V₂::VectorSpace) = (promote(V₁, V₂)...)
161+
const ominus =
162+
152163
"""
153164
⊗(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace} -> S
154165
@@ -396,10 +407,3 @@ function supremum(V₁::S, V₂::S, V₃::S...) where {S<:ElementarySpace}
396407
return supremum(supremum(V₁, V₂), V₃...)
397408
end
398409

399-
"""
400-
setdiff(V::ElementarySpace, W::ElementarySpace)
401-
402-
Return the set difference of two elementary spaces, i.e. an instance `X::ElementarySpace`
403-
such that `V = W ⊕ X`.
404-
"""
405-
Base.setdiff(V₁::S, V₂::S) where {S<:ElementarySpace}

src/tensors/matrixalgebrakit.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -296,9 +296,9 @@ function MatrixAlgebraKit.check_input(::typeof(qr_null!), t::AbstractTensorMap,
296296

297297
# space checks
298298
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
299-
V_N = setdiff(fuse(codomain(t)), V_Q)
299+
V_N = (fuse(codomain(t)), V_Q)
300300
space(N) == (codomain(t) V_N) ||
301-
throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
301+
throw(SpaceMismatch("`qr_null!(t, N)` requires `space(N) == (codomain(t) ← (fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
302302

303303
return nothing
304304
end
@@ -322,7 +322,7 @@ end
322322
function MatrixAlgebraKit.initialize_output(::typeof(qr_null!), t::AbstractTensorMap,
323323
::MatrixAlgebraKit.AbstractAlgorithm)
324324
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
325-
V_N = setdiff(fuse(codomain(t)), V_Q)
325+
V_N = (fuse(codomain(t)), V_Q)
326326
N = similar(t, codomain(t) V_N)
327327
return N
328328
end
@@ -414,9 +414,9 @@ function MatrixAlgebraKit.check_input(::typeof(lq_null!), t::AbstractTensorMap,
414414

415415
# space checks
416416
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
417-
V_N = setdiff(fuse(domain(t)), V_Q)
417+
V_N = (fuse(domain(t)), V_Q)
418418
space(N) == (V_N domain(t)) ||
419-
throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == setdiff(fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`"))
419+
throw(SpaceMismatch("`lq_null!(t, N)` requires `space(N) == (fuse(domain(t)), infimum(fuse(codomain(t)), fuse(domain(t)))`"))
420420

421421
return nothing
422422
end
@@ -440,7 +440,7 @@ end
440440
function MatrixAlgebraKit.initialize_output(::typeof(lq_null!), t::AbstractTensorMap,
441441
::MatrixAlgebraKit.AbstractAlgorithm)
442442
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
443-
V_N = setdiff(fuse(domain(t)), V_Q)
443+
V_N = (fuse(domain(t)), V_Q)
444444
N = similar(t, V_N domain(t))
445445
return N
446446
end
@@ -634,16 +634,16 @@ function MatrixAlgebraKit.check_input(::typeof(left_null!), t::AbstractTensorMap
634634

635635
# space checks
636636
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
637-
V_N = setdiff(fuse(codomain(t)), V_Q)
637+
V_N = (fuse(codomain(t)), V_Q)
638638
space(N) == (codomain(t) V_N) ||
639-
throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← setdiff(fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
639+
throw(SpaceMismatch("`left_null!(t, N)` requires `space(N) == (codomain(t) ← (fuse(codomain(t)), infimum(fuse(codomain(t)), fuse(domain(t))))`"))
640640

641641
return nothing
642642
end
643643

644644
function MatrixAlgebraKit.initialize_output(::typeof(left_null!), t::AbstractTensorMap)
645645
V_Q = infimum(fuse(codomain(t)), fuse(domain(t)))
646-
V_N = setdiff(fuse(codomain(t)), V_Q)
646+
V_N = (fuse(codomain(t)), V_Q)
647647
N = similar(t, codomain(t) V_N)
648648
return N
649649
end

0 commit comments

Comments
 (0)