Skip to content

Commit 96c768f

Browse files
authored
Implement taking subbasis of subbasis (#99)
* Implement taking subbasis of subbasis * Use sub_basis * Remove debug * Fix
1 parent e997915 commit 96c768f

File tree

5 files changed

+42
-9
lines changed

5 files changed

+42
-9
lines changed

src/bases_fixed.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,36 @@ struct SubBasis{T,I,K,B<:AbstractBasis{T,K},V<:AbstractVector{K}} <:
113113
end
114114
end
115115

116+
"""
117+
sub_basis(basis::SubBasis, indices::Vector)
118+
119+
Return a `SubBasis` of the same `parent` but the subset of keys
120+
`basis.keys[indices]`.
121+
"""
122+
function sub_basis(basis::SubBasis, indices::AbstractVector)
123+
# If `basis.is_sorted` is `true` and `indices` is not sorted,
124+
# maybe we should sort it automatically ?
125+
# Or maybe it's best to just error so that the user has to sort it explicitly
126+
@assert issorted(indices)
127+
return SubBasis(parent(basis), basis.keys[indices])
128+
end
129+
130+
"""
131+
sub_basis(basis::ImplicitBasis, keys::AbstractVector)
132+
133+
Return a `SubBasis` of `basis` but the subset of keys `keys`.
134+
135+
## Note
136+
137+
Even though the constructor `SubBasis(basis, keys)` allows `basis`
138+
to be a `ExplicitBasis`, this function restricts it to be
139+
an `ImplicitBasis` in order to provide an interface that would
140+
help catch common mistakes.
141+
"""
142+
function sub_basis(basis::ImplicitBasis, keys::AbstractVector)
143+
return SubBasis(basis, keys)
144+
end
145+
116146
Base.parent(sub::SubBasis) = sub.parent_basis
117147

118148
function Base.:(==)(a::SubBasis, b::SubBasis)

test/basic.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@ Base.iterate(b::DummyBasis, args...) = iterate(b.elements, args...)
4646
@test A == StarAlgebra(1.0, SA.MappedBasis(b, float, error))
4747
@test A != StarAlgebra(1.0, m2)
4848

49-
sub = SA.SubBasis(m, Irrational[π])
49+
sub = SA.sub_basis(m, Irrational[π])
5050
B = StarAlgebra(1.0, sub)
5151
@test B == B
52-
@test B == StarAlgebra(1.0, SA.SubBasis(m, Irrational[π]))
53-
@test B != StarAlgebra(1.0, SA.SubBasis(m, Irrational[ℯ]))
54-
@test B != StarAlgebra(1.0, SA.SubBasis(m2, Irrational[π]))
52+
@test B == StarAlgebra(1.0, SA.sub_basis(m, Irrational[π]))
53+
@test B != StarAlgebra(1.0, SA.sub_basis(m, Irrational[ℯ]))
54+
@test B != StarAlgebra(1.0, SA.sub_basis(m2, Irrational[π]))
55+
@test B == StarAlgebra(1.0, SA.sub_basis(SA.sub_basis(m, Irrational[π, ℯ]), 1:1))
56+
@test B != StarAlgebra(1.0, SA.sub_basis(SA.sub_basis(m, Irrational[π, ℯ]), 1:2))
57+
@test B != StarAlgebra(1.0, SA.sub_basis(SA.sub_basis(m, Irrational[π, ℯ]), 2:2))
5558

5659
el = SA.AlgebraElement(
5760
[Variable()],

test/graded_lex.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import StarAlgebras as SA
2222
@test c.coeffs.values == [4, -4, 12, 1, -6, 9]
2323
c_keys = [(0, 0), (0, 1), (1, 0), (0, 2), (1, 1), (2, 0)]
2424
@test c.coeffs.basis_elements == c_keys
25-
sub = SA.SubBasis(SA.basis(alg), c_keys)
25+
sub = SA.sub_basis(SA.basis(alg), c_keys)
2626
@test sub.is_sorted
2727
for (i, key) in enumerate(c_keys)
2828
el = SA.basis(alg)[key]

test/perm_grp_algebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ import StarAlgebras as SA
124124
Random.seed!(0)
125125
S1 = unique!(rand(G, 7))
126126
S = unique!([S1; [a * b for a in S1 for b in S1]])
127-
subb = SA.SubBasis(db, S)
127+
subb = SA.sub_basis(db, S)
128128
a = S1[1]
129129
@test subb[a] == 1
130130
@test a in subb
@@ -160,7 +160,7 @@ import StarAlgebras as SA
160160
end
161161

162162
S2 = unique([S; one(G)])
163-
subb2 = SA.SubBasis(db, S2)
163+
subb2 = SA.sub_basis(db, S2)
164164
let sRG = SA.StarAlgebra(G, subb2)
165165
x = let z = spzeros(Int, length(SA.basis(sRG)))
166166
z[rand(1:length(S2), 10)] += rand(-1:1, 10)

test/quadratic_form.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ end
8282
implicit = SA.MappedBasis(NaturalNumbers(), float, Int)
8383
@test 3.0 in implicit
8484
@test haskey(implicit, 3)
85-
explicit = SA.SubBasis(implicit, 1:3)
85+
explicit = SA.sub_basis(implicit, 1:3)
8686
@test 3.0 in explicit
8787
@test collect(explicit) == [1.0, 2.0, 3.0]
8888
@test haskey(explicit, 3)
@@ -109,7 +109,7 @@ end
109109
implicit = cheby_basis()
110110
mstr = ChebyMStruct(implicit)
111111
mt = SA.MTable(mstr, (0, 0))
112-
sub = SA.SubBasis(implicit, 1:3)
112+
sub = SA.sub_basis(implicit, 1:3)
113113
test_vector_interface(sub)
114114
fixed = SA.FixedBasis(implicit; n = 3)
115115
test_vector_interface(fixed)

0 commit comments

Comments
 (0)