Skip to content

Commit d442a80

Browse files
committed
Update clenshaw.jl
1 parent 5e200d9 commit d442a80

File tree

1 file changed

+1
-249
lines changed

1 file changed

+1
-249
lines changed

src/clenshaw.jl

Lines changed: 1 addition & 249 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,7 @@
11

2-
##
3-
# For Chebyshev T. Note the shift in indexing is fine due to the AbstractFill
4-
##
5-
Base.@propagate_inbounds _forwardrecurrence_next(n, A::Vcat{<:Any,1,<:Tuple{<:Number,<:AbstractFill}}, B::Zeros, C::Ones, x, p0, p1) =
6-
_forwardrecurrence_next(n, A.args[2], B, C, x, p0, p1)
7-
8-
Base.@propagate_inbounds _clenshaw_next(n, A::Vcat{<:Any,1,<:Tuple{<:Number,<:AbstractFill}}, B::Zeros, C::Ones, x, c, bn1, bn2) =
9-
_clenshaw_next(n, A.args[2], B, C, x, c, bn1, bn2)
10-
112
# Assume 1 normalization
123
_p0(A) = one(eltype(A))
134

14-
function initiateforwardrecurrence(N, A, B, C, x, μ)
15-
T = promote_type(eltype(A), eltype(B), eltype(C), typeof(x))
16-
p0 = convert(T, μ)
17-
N == 0 && return zero(T), p0
18-
p1 = convert(T, muladd(A[1],x,B[1])*p0)
19-
@inbounds for n = 2:N
20-
p1,p0 = _forwardrecurrence_next(n, A, B, C, x, p0, p1),p1
21-
end
22-
p0,p1
23-
end
245

256
for (get, vie) in ((:getindex, :view), (:(Base.unsafe_getindex), :(Base.unsafe_view)))
267
@eval begin
@@ -118,228 +99,6 @@ end
11899
Base.@propagate_inbounds getindex(f::Mul{<:WeightedOPLayout,<:AbstractPaddedLayout}, x::Number, j...) =
119100
weight(f.A)[x] * (unweighted(f.A) * f.B)[x, j...]
120101

121-
###
122-
# Operator clenshaw
123-
###
124-
125-
126-
Base.@propagate_inbounds function _clenshaw_next!(n, A::AbstractFill, ::Zeros, C::Ones, x::AbstractMatrix, c, bn1::AbstractMatrix{T}, bn2::AbstractMatrix{T}) where T
127-
muladd!(getindex_value(A), x, bn1, -one(T), bn2)
128-
view(bn2,band(0)) .+= c[n]
129-
bn2
130-
end
131-
132-
Base.@propagate_inbounds function _clenshaw_next!(n, A::AbstractVector, ::Zeros, C::AbstractVector, x::AbstractMatrix, c, bn1::AbstractMatrix{T}, bn2::AbstractMatrix{T}) where T
133-
muladd!(A[n], x, bn1, -C[n+1], bn2)
134-
view(bn2,band(0)) .+= c[n]
135-
bn2
136-
end
137-
138-
Base.@propagate_inbounds function _clenshaw_next!(n, A::AbstractVector, B::AbstractVector, C::AbstractVector, x::AbstractMatrix, c, bn1::AbstractMatrix{T}, bn2::AbstractMatrix{T}) where T
139-
# bn2 .= B[n] .* bn1 .- C[n+1] .* bn2
140-
lmul!(-C[n+1], bn2)
141-
LinearAlgebra.axpy!(B[n], bn1, bn2)
142-
muladd!(A[n], x, bn1, one(T), bn2)
143-
view(bn2,band(0)) .+= c[n]
144-
bn2
145-
end
146-
147-
# Operator * f Clenshaw
148-
Base.@propagate_inbounds function _clenshaw_next!(n, A::AbstractFill, ::Zeros, C::Ones, X::AbstractMatrix, c, f::AbstractVector, bn1::AbstractVector{T}, bn2::AbstractVector{T}) where T
149-
muladd!(getindex_value(A), X, bn1, -one(T), bn2)
150-
bn2 .+= c[n] .* f
151-
bn2
152-
end
153-
154-
Base.@propagate_inbounds function _clenshaw_next!(n, A, ::Zeros, C, X::AbstractMatrix, c, f::AbstractVector, bn1::AbstractVector{T}, bn2::AbstractVector{T}) where T
155-
muladd!(A[n], X, bn1, -C[n+1], bn2)
156-
bn2 .+= c[n] .* f
157-
bn2
158-
end
159-
160-
Base.@propagate_inbounds function _clenshaw_next!(n, A, B, C, X::AbstractMatrix, c, f::AbstractVector, bn1::AbstractVector{T}, bn2::AbstractVector{T}) where T
161-
bn2 .= B[n] .* bn1 .- C[n+1] .* bn2 .+ c[n] .* f
162-
muladd!(A[n], X, bn1, one(T), bn2)
163-
bn2
164-
end
165-
166-
# allow special casing first arg, for ChebyshevT in ClassicalOrthogonalPolynomials
167-
Base.@propagate_inbounds function _clenshaw_first!(A, ::Zeros, C, X, c, bn1, bn2)
168-
muladd!(A[1], X, bn1, -C[2], bn2)
169-
view(bn2,band(0)) .+= c[1]
170-
bn2
171-
end
172-
173-
Base.@propagate_inbounds function _clenshaw_first!(A, B, C, X, c, bn1, bn2)
174-
lmul!(-C[2], bn2)
175-
LinearAlgebra.axpy!(B[1], bn1, bn2)
176-
muladd!(A[1], X, bn1, one(eltype(bn2)), bn2)
177-
view(bn2,band(0)) .+= c[1]
178-
bn2
179-
end
180-
181-
Base.@propagate_inbounds function _clenshaw_first!(A, ::Zeros, C, X, c, f::AbstractVector, bn1, bn2)
182-
muladd!(A[1], X, bn1, -C[2], bn2)
183-
bn2 .+= c[1] .* f
184-
bn2
185-
end
186-
187-
Base.@propagate_inbounds function _clenshaw_first!(A, B, C, X, c, f::AbstractVector, bn1, bn2)
188-
bn2 .= B[1] .* bn1 .- C[2] .* bn2 .+ c[1] .* f
189-
muladd!(A[1], X, bn1, one(eltype(bn2)), bn2)
190-
bn2
191-
end
192-
193-
_clenshaw_op(::AbstractBandedLayout, Z, N) = BandedMatrix(Z, (N-1,N-1))
194-
195-
function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, X::AbstractMatrix)
196-
N = length(c)
197-
T = promote_type(eltype(c),eltype(A),eltype(B),eltype(C),eltype(X))
198-
@boundscheck check_clenshaw_recurrences(N, A, B, C)
199-
m = size(X,1)
200-
m == size(X,2) || throw(DimensionMismatch("X must be square"))
201-
N == 0 && return zero(T)
202-
bn2 = _clenshaw_op(MemoryLayout(X), Zeros{T}(m, m), N)
203-
bn1 = _clenshaw_op(MemoryLayout(X), c[N]*Eye{T}(m), N)
204-
_clenshaw_op!(c, A, B, C, X, bn1, bn2)
205-
end
206-
207-
function clenshaw(c::AbstractVector, A::AbstractVector, B::AbstractVector, C::AbstractVector, X::AbstractMatrix, f::AbstractVector)
208-
N = length(c)
209-
T = promote_type(eltype(c),eltype(A),eltype(B),eltype(C),eltype(X))
210-
@boundscheck check_clenshaw_recurrences(N, A, B, C)
211-
m = size(X,1)
212-
m == size(X,2) || throw(DimensionMismatch("X must be square"))
213-
m == length(f) || throw(DimensionMismatch("Dimensions must match"))
214-
N == 0 && return [zero(T)]
215-
bn2 = zeros(T,m)
216-
bn1 = Vector{T}(undef,m)
217-
bn1 .= c[N] .* f
218-
_clenshaw_op!(c, A, B, C, X, f, bn1, bn2)
219-
end
220-
221-
function _clenshaw_op!(c, A, B, C, X, bn1, bn2)
222-
N = length(c)
223-
N == 1 && return bn1
224-
@inbounds begin
225-
for n = N-1:-1:2
226-
bn1,bn2 = _clenshaw_next!(n, A, B, C, X, c, bn1, bn2),bn1
227-
end
228-
bn1 = _clenshaw_first!(A, B, C, X, c, bn1, bn2)
229-
end
230-
bn1
231-
end
232-
233-
function _clenshaw_op!(c, A, B, C, X, f::AbstractVector, bn1, bn2)
234-
N = length(c)
235-
N == 1 && return bn1
236-
@inbounds begin
237-
for n = N-1:-1:2
238-
bn1,bn2 = _clenshaw_next!(n, A, B, C, X, c, f, bn1, bn2),bn1
239-
end
240-
bn1 = _clenshaw_first!(A, B, C, X, c, f, bn1, bn2)
241-
end
242-
bn1
243-
end
244-
245-
246-
247-
"""
248-
Clenshaw(a, X)
249-
250-
represents the operator `a(X)` where a is a polynomial.
251-
Here `a` is to stored as a quasi-vector.
252-
"""
253-
struct Clenshaw{T, Coefs<:AbstractVector, AA<:AbstractVector, BB<:AbstractVector, CC<:AbstractVector, Jac<:AbstractMatrix} <: AbstractBandedMatrix{T}
254-
c::Coefs
255-
A::AA
256-
B::BB
257-
C::CC
258-
X::Jac
259-
p0::T
260-
end
261-
262-
Clenshaw(c::AbstractVector{T}, A::AbstractVector, B::AbstractVector, C::AbstractVector, X::AbstractMatrix{T}, p0::T) where T =
263-
Clenshaw{T,typeof(c),typeof(A),typeof(B),typeof(C),typeof(X)}(c, A, B, C, X, p0)
264-
265-
Clenshaw(c::Number, A, B, C, X, p) = Clenshaw([c], A, B, C, X, p)
266-
267-
function Clenshaw(a::AbstractQuasiVector, X::AbstractQuasiMatrix)
268-
P,c = arguments(a)
269-
Clenshaw(paddeddata(c), recurrencecoefficients(P)..., jacobimatrix(X), _p0(P))
270-
end
271-
272-
copy(M::Clenshaw) = M
273-
size(M::Clenshaw) = size(M.X)
274-
axes(M::Clenshaw) = axes(M.X)
275-
bandwidths(M::Clenshaw) = (length(M.c)-1,length(M.c)-1)
276-
277-
Base.array_summary(io::IO, C::Clenshaw{T}, inds::Tuple{Vararg{OneToInf{Int}}}) where T =
278-
print(io, Base.dims2string(length.(inds)), " Clenshaw{$T} with $(length(C.c)) degree polynomial")
279-
280-
struct ClenshawLayout <: AbstractLazyBandedLayout end
281-
MemoryLayout(::Type{<:Clenshaw}) = ClenshawLayout()
282-
sublayout(::ClenshawLayout, ::Type{<:NTuple{2,AbstractUnitRange{Int}}}) = ClenshawLayout()
283-
sublayout(::ClenshawLayout, ::Type{<:Tuple{AbstractUnitRange{Int},Union{Slice,AbstractInfUnitRange{Int}}}}) = LazyBandedLayout()
284-
sublayout(::ClenshawLayout, ::Type{<:Tuple{Union{Slice,AbstractInfUnitRange{Int}},AbstractUnitRange{Int}}}) = LazyBandedLayout()
285-
sublayout(::ClenshawLayout, ::Type{<:Tuple{Union{Slice,AbstractInfUnitRange{Int}},Union{Slice,AbstractInfUnitRange{Int}}}}) = LazyBandedLayout()
286-
sub_materialize(::ClenshawLayout, V) = BandedMatrix(V)
287-
288-
function _BandedMatrix(::ClenshawLayout, V::SubArray{<:Any,2})
289-
M = parent(V)
290-
kr,jr = parentindices(V)
291-
b = bandwidth(M,1)
292-
jkr = max(1,min(first(jr),first(kr))-b÷2):max(last(jr),last(kr))+b÷2
293-
# relationship between jkr and kr, jr
294-
kr2,jr2 = kr.-first(jkr).+1,jr.-first(jkr).+1
295-
lmul!(M.p0, clenshaw(M.c, M.A, M.B, M.C, M.X[jkr, jkr])[kr2,jr2])
296-
end
297-
298-
function getindex(M::Clenshaw{T}, kr::AbstractUnitRange, j::Integer) where T
299-
b = bandwidth(M,1)
300-
jkr = max(1,min(j,first(kr))-b÷2):max(j,last(kr))+b÷2
301-
# relationship between jkr and kr, jr
302-
kr2,j2 = kr.-first(jkr).+1,j-first(jkr)+1
303-
f = [Zeros{T}(j2-1); one(T); Zeros{T}(length(jkr)-j2)]
304-
lmul!(M.p0, clenshaw(M.c, M.A, M.B, M.C, M.X[jkr, jkr], f)[kr2])
305-
end
306-
307-
getindex(M::Clenshaw, k::Int, j::Int) = M[k:k,j][1]
308-
309-
function getindex(S::Symmetric{T,<:Clenshaw}, k::Integer, jr::AbstractUnitRange) where T
310-
m = max(jr.start,jr.stop,k)
311-
return Symmetric(getindex(S.data,1:m,1:m),Symbol(S.uplo))[k,jr]
312-
end
313-
function getindex(S::Symmetric{T,<:Clenshaw}, kr::AbstractUnitRange, j::Integer) where T
314-
m = max(kr.start,kr.stop,j)
315-
return Symmetric(getindex(S.data,1:m,1:m),Symbol(S.uplo))[kr,j]
316-
end
317-
function getindex(S::Symmetric{T,<:Clenshaw}, kr::AbstractUnitRange, jr::AbstractUnitRange) where T
318-
m = max(kr.start,jr.start,kr.stop,jr.stop)
319-
return Symmetric(getindex(S.data,1:m,1:m),Symbol(S.uplo))[kr,jr]
320-
end
321-
322-
transposelayout(M::ClenshawLayout) = LazyBandedMatrices.LazyBandedLayout()
323-
# TODO: generalise for layout, use Base.PermutedDimsArray
324-
Base.permutedims(M::Clenshaw{<:Number}) = transpose(M)
325-
326-
327-
function materialize!(M::MatMulVecAdd{<:ClenshawLayout,<:AbstractPaddedLayout,<:AbstractPaddedLayout})
328-
α,A,x,β,y = M.α,M.A,M.B,M.β,M.C
329-
length(y) == size(A,1) || throw(DimensionMismatch("Dimensions must match"))
330-
length(x) == size(A,2) || throw(DimensionMismatch("Dimensions must match"))
331-
x̃ = paddeddata(x);
332-
m = length(x̃)
333-
b = bandwidth(A,1)
334-
jkr=1:m+b
335-
p = [x̃; zeros(eltype(x̃),length(jkr)-m)];
336-
Ax = lmul!(A.p0, clenshaw(A.c, A.A, A.B, A.C, A.X[jkr, jkr], p))
337-
_fill_lmul!(β,y)
338-
resizedata!(y, last(jkr))
339-
v = view(paddeddata(y),jkr)
340-
LinearAlgebra.axpy!(α, Ax, v)
341-
y
342-
end
343102

344103
# TODO: generalise this to be trait based
345104
function layout_broadcasted(::Tuple{ExpansionLayout{<:AbstractOPLayout},AbstractOPLayout}, ::typeof(*), a, P)
@@ -364,11 +123,4 @@ function _broadcasted_layout_broadcasted_mul(::Tuple{AbstractWeightLayout,Polyno
364123
Q = OrthogonalPolynomial(w)
365124
a = (w .* Q) * (Q \ v)
366125
a .* P
367-
end
368-
369-
370-
##
371-
# Banded dot is slow
372-
###
373-
374-
LinearAlgebra.dot(x::AbstractVector, A::Clenshaw, y::AbstractVector) = dot(x, mul(A, y))
126+
end

0 commit comments

Comments
 (0)