Skip to content

Commit 0c37d6f

Browse files
committed
change parametric types of structs
1 parent f776780 commit 0c37d6f

File tree

3 files changed

+79
-41
lines changed

3 files changed

+79
-41
lines changed

src/lie-groups/lie-algebras/algebras.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ dim(alg::ScalingLieAlgebra) = size(alg.exps, 1)
1717
exponents(alg::ScalingLieAlgebra) = alg.exps
1818
rank(alg::ScalingLieAlgebra) = dim(alg)
1919
Base.size(alg::ScalingLieAlgebra) = size(alg.exps, 2)
20-
weight_type(::ScalingLieAlgebra) = Int
20+
weight_type(::ScalingLieAlgebra) = Weight{Int}
2121

2222
function show_basis(io::IO, alg::ScalingLieAlgebra; offset::Int=0)
2323
for i in 1:dim(alg)
@@ -92,8 +92,8 @@ Base.convert(
9292
struct LieAlgebra{F, W<:Weight} <: AbstractLieAlgebra{F}
9393
name::String
9494
basis::ChevalleyBasis{F}
95-
weight_structure::WeightStructure{F, MatrixVectorSpace{F}, W}
96-
hw_spaces::Vector{WeightSpace{F, MatrixVectorSpace{F}, W}} # TODO: change to WeightStructure?
95+
weight_structure::WeightStructure{MatrixVectorSpace{F}, W}
96+
hw_spaces::Vector{WeightSpace{MatrixVectorSpace{F}, W}} # TODO: change to WeightStructure?
9797
end
9898

9999
function so3(field_type::DataType, weight_type::DataType)
@@ -125,7 +125,7 @@ name(alg::LieAlgebra) = alg.name
125125
dim(alg::LieAlgebra) = length(alg.basis.std_basis)
126126
rank(alg::LieAlgebra) = length(alg.basis.cartan)
127127
Base.size(alg::LieAlgebra) = size(alg.basis.std_basis[1], 1)
128-
weight_type(::LieAlgebra{F, Weight{W}}) where {F, W} = W
128+
weight_type(::LieAlgebra{F, W}) where {F, W} = W
129129

130130
function Base.show(io::IO, alg::LieAlgebra{F, Weight{W}}; offset::Int=0) where {F, W}
131131
println(io, " "^offset, "LieAlgebra $(name(alg))")
@@ -221,6 +221,6 @@ cartan_subalgebra(alg::SumLieAlgebra) = get_elements(alg, :cartan_subalgebra)
221221
positive_root_elements(alg::SumLieAlgebra) = get_elements(alg, :positive_root_elements)
222222
negative_root_elements(alg::SumLieAlgebra) = get_elements(alg, :negative_root_elements)
223223

224-
zero_weight(alg::AbstractLieAlgebra) = Weight(zeros(weight_type(alg), rank(alg)))
224+
zero_weight(alg::AbstractLieAlgebra) = zero(weight_type(alg), rank(alg))
225225
positive_roots(alg::AbstractLieAlgebra) = [root(pre) for pre in positive_root_elements(alg)]
226226
negative_roots(alg::AbstractLieAlgebra) = [root(nre) for nre in negative_root_elements(alg)]

src/lie-groups/lie-algebras/weights.jl

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ end
1515

1616
Base.show(io::IO, ::MIME"text/plain", w::Weight) = print(io, "Weight: ", w.weight)
1717
Base.show(io::IO, w::Weight) = print(io, w.weight)
18-
Base.zero(::Type{Weight{T}}) where T = Weight(zero(T))
18+
Base.zero(::Type{Weight{T}}, n) where T = Weight(zeros(T, n))
1919
Base.zero(w::Weight) = Weight(zero(w.weight))
2020
Base.vcat(w₁::Weight{T}, w₂::Weight{T}) where T = Weight(vcat(w₁.weight, w₂.weight))
2121
Base.hash(w::Weight, h::UInt) = hash(w.weight, h)
2222
Base.:(==)(w₁::Weight, w₂::Weight) = w₁.weight == w₂.weight
2323
Base.:*(n::Number, w::Weight) = Weight(n*w.weight)
2424
Base.:+(w₁::Weight{T}, w₂::Weight{T}) where {T} = Weight{T}(w₁.weight + w₂.weight)
25+
Base.promote_rule(::Type{Weight{T}}, ::Type{Weight{S}}) where {T,S} = Weight{promote_type(T, S)}
2526

2627
struct WeightVector{T, W<:Weight}
2728
weight::W
@@ -40,25 +41,26 @@ function Base.show(io::IO, wv::WeightVector)
4041
print(io, "WeightVector with weight $(weight(wv))")
4142
end
4243

43-
struct WeightSpace{F, T <: AbstractVectorSpace{F}, W<:Weight} # TODO: remove F?
44+
struct WeightSpace{T <: AbstractVectorSpace, W<:Weight}
4445
weight::W
4546
space::T
4647
end
4748

4849
WeightSpace(weight::Vector, space::Vector) = WeightSpace(Weight(weight), MatrixVectorSpace(space))
49-
WeightSpace{F,T,W}(wv::WeightVector) where {F,T,W} = WeightSpace{F,T,W}(weight(wv), T(vector(wv)))
50+
WeightSpace{T,W}(wv::WeightVector) where {T,W} = WeightSpace{T,W}(weight(wv), T(vector(wv)))
5051

5152
Base.convert(
52-
::Type{WeightSpace{F, T, W}},
53+
::Type{WeightSpace{T, W}},
5354
ws::WeightSpace
54-
) where {F, T<:AbstractVectorSpace{F}, W<:Weight} = WeightSpace(
55+
) where {T<:AbstractVectorSpace, W<:Weight} = WeightSpace(
5556
convert(W, weight(ws)),
5657
convert(T, space(ws))
5758
)
5859

5960
weight(ws::WeightSpace) = ws.weight
6061
space(ws::WeightSpace) = ws.space
6162
dim(ws::WeightSpace) = dim(space(ws))
63+
field_type(::WeightSpace{T}) where {T} = field_type(T)
6264

6365
function Base.show(io::IO, ::MIME"text/plain", ws::WeightSpace)
6466
println(io, "WeightSpace of dimension $(dim(ws))")
@@ -74,12 +76,12 @@ function Base.iterate(ws::WeightSpace, state=1)
7476
return (WeightVector(weight(ws), basis(space(ws), state)), state+1)
7577
end
7678

77-
struct WeightStructure{F, T<:AbstractVectorSpace{F}, W<:Weight}
79+
struct WeightStructure{T<:AbstractVectorSpace, W<:Weight}
7880
weights::Vector{W} # Ordering needed for sym_weight_structure
79-
dict::Dict{W, WeightSpace{F, T, W}} # TODO: new type for WeightSpace{F,T,W}? Do all spaces have to be of the same type?
81+
dict::Dict{W, WeightSpace{T, W}} # TODO: new type for WeightSpace{F,T,W}? Do all spaces have to be of the same type?
8082
end
8183

82-
WeightStructure{F,T,W}() where {F,T,W} = WeightStructure{F,T,W}(Weight[], Dict())
84+
WeightStructure{T,W}() where {T,W} = WeightStructure{T,W}(Weight[], Dict())
8385

8486
WeightStructure(
8587
weights::Vector{<:Weight},
@@ -98,27 +100,31 @@ WeightStructure(
98100
)
99101

100102
# TODO: improve
101-
function WeightStructure(w_spaces::Vector{<:WeightSpace{F,T,W}}) where {F,T,W}
102-
ws = WeightStructure{F,T,W}()
103+
function WeightStructure(w_spaces::Vector{<:WeightSpace{T,W}}) where {T,W}
104+
ws = WeightStructure{T,W}()
103105
for w_space in w_spaces
104106
push!(ws, w_space)
105107
end
106108
return ws
107109
end
108110

109111
Base.convert(
110-
::Type{WeightStructure{F, T, W}},
112+
::Type{WeightStructure{T, W}},
111113
ws::WeightStructure
112-
) where {F, T<:AbstractVectorSpace{F}, W<:Weight} = WeightStructure(
114+
) where {T<:AbstractVectorSpace, W<:Weight} = WeightStructure(
113115
convert(Vector{W}, ws.weights),
114-
convert(Dict{W, WeightSpace{F, T, W}}, ws.dict)
116+
convert(Dict{W, WeightSpace{T, W}}, ws.dict)
115117
)
116118

117119
weights(ws::WeightStructure) = ws.weights
118120
weights(ws::WeightStructure, inds...) = getindex(weights(ws), inds...)
119121
nweights(ws::WeightStructure) = length(ws.weights)
120-
weight(ws::WeightStructure, i::Integer) = weights(ws)[i]
121-
weight_space(ws::WeightStructure, i::Integer) = ws[weight(ws, i)]
122+
weight(ws::WeightStructure, i::Integer) = weights(ws)[i] # TODO: remove?
123+
weight_space(
124+
ws::WeightStructure,
125+
i::Integer;
126+
as_space::Bool=false
127+
) = as_space ? space(ws[weight(ws, i)]) : ws[weight(ws, i)]
122128
weight_spaces(
123129
ws::WeightStructure;
124130
as_spaces::Bool=false
@@ -131,12 +137,14 @@ weight_spaces(
131137
dims(ws::WeightStructure) = [dim(space(ws[w])) for w in weights(ws)]
132138
dim(ws::WeightStructure) = sum(dims(ws))
133139
Base.length(ws::WeightStructure) = nweights(ws)
134-
Base.getindex(ws::WeightStructure{F,T,W}, weight::W) where {F,T,W<:Weight} = ws.dict[weight]
140+
Base.isempty(ws::WeightStructure) = isempty(weights(ws))
141+
Base.getindex(ws::WeightStructure{T,W}, weight::W) where {T,W<:Weight} = ws.dict[weight]
135142
Base.getindex(ws::WeightStructure, i::Integer) = ws[weight(ws, i)]
136143
Base.setindex!(ws::WeightStructure, ws_new::WeightSpace, weight::Weight) = ws.dict[weight] = ws_new
137144
Base.haskey(ws::WeightStructure, w::Weight) = haskey(ws.dict, w)
138145
zero_weight(ws::WeightStructure) = zero(first(weights(ws)))
139-
field_space(::WeightStructure{F, T}) where {F, T} = field_space(T)
146+
field_space(::WeightStructure{T}) where {T} = field_space(T)
147+
field_type(::WeightStructure{T}) where {T} = field_type(T)
140148

141149
function Base.show(io::IO, ::MIME"text/plain", ws::WeightStructure)
142150
println(io, "WeightStructure of $(dim(ws))-dimensional vector space")
@@ -165,9 +173,9 @@ function Base.push!(ws::WeightStructure, w_space::WeightSpace)
165173
return ws
166174
end
167175

168-
Base.push!(ws::WeightStructure{F,T,W}, wv::WeightVector) where {F,T,W} = push!(ws, WeightSpace{F,T,W}(wv))
176+
Base.push!(ws::WeightStructure{T,W}, wv::WeightVector) where {T,W} = push!(ws, WeightSpace{T,W}(wv))
169177

170-
function sym(ws::WeightStructure{F,T,W}, d::Int) where {F, T, W}
178+
function sym(ws::WeightStructure{T,W}, d::Int) where {T, W}
171179
d == 0 && return WeightStructure([WeightSpace(zero_weight(ws), field_space(ws))])
172180
d == 1 && return ws
173181
combs = multiexponents(; degree=d, nvars=nweights(ws))
@@ -180,10 +188,31 @@ function sym(ws::WeightStructure{F,T,W}, d::Int) where {F, T, W}
180188
new_weights_dict[w] = [comb]
181189
end
182190
end
183-
new_ws = WeightStructure{F,T,W}()
191+
new_ws = WeightStructure{T,W}()
184192
for (weight, combs) in new_weights_dict
185193
w_sps = [*(weight_spaces(ws, comb.nzind; as_spaces=true), comb.nzval) for comb in combs]
186194
push!(new_ws, WeightSpace(weight, +(w_sps...)))
187195
end
188196
return new_ws
197+
end
198+
199+
function tensor(ws₁::WeightStructure{T,W}, ws₂::WeightStructure{T,W}) where {T, W}
200+
combs = Base.Iterators.product(1:nweights(ws₁), 1:nweights(ws₂))
201+
new_weights_dict = Dict{W, Vector{typeof(first(combs))}}()
202+
for comb in combs
203+
w = weight(ws₁, comb[1]) + weight(ws₂, comb[2])
204+
val = get(new_weights_dict, w, nothing)
205+
if isnothing(val)
206+
new_weights_dict[w] = [comb]
207+
else
208+
push!(new_weights_dict[w], comb)
209+
end
210+
end
211+
new_ws = WeightStructure{T,W}()
212+
for (weight, combs) in new_weights_dict
213+
w_sps = [*([weight_space(ws₁, comb[1]; as_space=true), weight_space(ws₂, comb[2]; as_space=true)]; in_rref=false) for comb in combs]
214+
total_ws = +(w_sps...; in_rref=false)
215+
push!(new_ws, WeightSpace(weight, total_ws))
216+
end
217+
return new_ws
189218
end

src/vector-spaces/basic.jl

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ export MatrixVectorSpace,
33
VariableSpace
44

55

6-
struct MatrixVectorSpace{F} <: AbstractVectorSpace{F}
6+
struct MatrixVectorSpace{F} <: AbstractVectorSpace{Vector{F}, F}
77
matrix::Matrix{F}
88
end
99

@@ -27,15 +27,15 @@ function Base.:∩(V₁::MatrixVectorSpace, V₂::MatrixVectorSpace)
2727
end
2828

2929
# TODO: add gens?
30-
struct VectorSpace{T, F} <: AbstractVectorSpace{F}
30+
struct VectorSpace{T, F} <: AbstractVectorSpace{T, F}
3131
basis::Vector{T}
3232
end
3333

3434
VectorSpace{T, F}() where {T, F} = VectorSpace{T, F}(Vector{T}())
3535
VectorSpace{T, F}(v::T) where {T, F} = VectorSpace{T, F}([v])
3636
VectorSpace(F::DataType, vars::Vector{T}) where T<:Variable = VectorSpace{T, F}(unique(vars))
3737
VectorSpace(F::DataType, vs::Vector{T}) where T = VectorSpace{T, F}(vs)
38-
VectorSpace(F::DataType, polys::Vector{T}) where T<:Polynomial = VectorSpace{T,F}(in_rref(polys))
38+
VectorSpace(F::DataType, polys::Vector{T}) where T<:Polynomial = VectorSpace{T,F}(rref(polys))
3939

4040
basis(V::VectorSpace) = V.basis
4141
basis(V::VectorSpace, i::Integer) = basis(V)[i]
@@ -49,7 +49,10 @@ function Base.show(io::IO, V::VectorSpace{T, F}; indent::Int=0) where {T, F}
4949
end
5050

5151
field_space(::Type{VectorSpace{T, F}}) where {T, F} = VectorSpace{T, F}(one(T))
52-
variables(V::VectorSpace{<:Variable}) = basis(V)
52+
DynamicPolynomials.variables(V::VectorSpace{<:Variable}) = basis(V)
53+
DynamicPolynomials.nvariables(V::VectorSpace{<:Variable}) = dim(V)
54+
DynamicPolynomials.variables(V::VectorSpace{<:Polynomial}) = variables(basis(V))
55+
DynamicPolynomials.nvariables(V::VectorSpace{<:Polynomial}) = nvariables(basis(V))
5356
Base.iszero(V::VectorSpace) = dim(V) == 0
5457
Base.rand(V::VectorSpace{T, F}) where {T, F} = sum(rand(F, dim(V)) .* basis(V))
5558
Base.push!(V::VectorSpace{T}, v::T) where T = push!(V.basis, v)
@@ -58,13 +61,25 @@ Base.convert(::Type{VectorSpace{T₁, F}}, V::VectorSpace{T₂, F}) where {T₁,
5861
Base.:+(
5962
Vs::VectorSpace{T, F}...
6063
) where {T<:Variable, F} = VectorSpace{T,F}(([basis(V) for V in Vs]...))
61-
Base.:+(
62-
Vs::VectorSpace{T, F}...
63-
) where {T<:Polynomial, F} = VectorSpace{T,F}(in_rref(vcat([basis(V) for V in Vs]...)))
6464

65-
Base.:*(
66-
Vs::Vector{VectorSpace{T, F}}
67-
) where {T<:Polynomial, F} = VectorSpace{T, F}(in_rref([prod(fs) for fs in product([basis(V) for V in Vs]...)][:]))
65+
function Base.:+(
66+
Vs::VectorSpace{T, F}...;
67+
in_rref::Bool=true
68+
) where {T<:Polynomial, F}
69+
if in_rref
70+
return VectorSpace{T,F}(rref(vcat([basis(V) for V in Vs]...)))
71+
end
72+
return VectorSpace{T,F}(vcat([basis(V) for V in Vs]...))
73+
end
74+
75+
function Base.:*(
76+
Vs::Vector{VectorSpace{T, F}};
77+
in_rref::Bool=true
78+
) where {T<:Polynomial, F}
79+
in_rref && return VectorSpace{T, F}(rref([prod(fs) for fs in product([basis(V) for V in Vs]...)][:]))
80+
return VectorSpace{T, F}([prod(fs) for fs in product([basis(V) for V in Vs]...)][:])
81+
end
82+
6883
Base.:*(
6984
Vs::Vector{VectorSpace{T, F}},
7085
muls::Vector{Int}
@@ -85,9 +100,3 @@ function Base.:∩(
85100
Vᵢ = M₁*N[1:dim(V₁), :]
86101
return VectorSpace(F, [sum(c .* all_mons) for c in eachcol(Vᵢ)])
87102
end
88-
89-
function zero_combinations(F::Vector{<:AbstractPolynomial}; tol::Real=1e-5)
90-
mons = monomials(F)
91-
M = coeffs_matrix(F, mons)
92-
return eachcol(nullspace(M; atol=tol))
93-
end

0 commit comments

Comments
 (0)