Skip to content

Commit 865496d

Browse files
benjionejishnub
andauthored
Multivariate functions (#185)
* first version for multivariate functions of type trivialtensorizer * performance improvements * added simple test cases * fix test and extension * comment broken test out for old Julia versions * fixed ordering of tensorization * simd change and vector type * small changes * enabled broken test * fixed unbound parameters * type fix * type fixes * add Polynomials for testing in Project toml * invoked 2D case * move multivariate tests to orthogonal polynomial repository * delete multivariate * fixed block type and rewritten trivialtensor iterator * remove Orthogonal polynomials from Project toml * adding TensorIteratorFun * changed project toml * rename TrivialTensorFun and add check if all spaces in Tensorspace are trivial * performance optimization * added ProductFun which is not working as a comment Co-authored-by: Benjamin Zanger <[email protected]> Co-authored-by: Jishnu Bhattacharya <[email protected]>
1 parent 62be723 commit 865496d

File tree

6 files changed

+146
-10
lines changed

6 files changed

+146
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
88
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
99
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1010
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
11+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1112
DSP = "717857b8-e6f2-59f4-9121-6e50c889abd2"
1213
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
1314
DualNumbers = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74"
@@ -32,6 +33,7 @@ BandedMatrices = "0.16, 0.17"
3233
BlockArrays = "0.14, 0.15, 0.16"
3334
BlockBandedMatrices = "0.10, 0.11"
3435
Calculus = "0.5"
36+
Combinatorics = "1.0.2"
3537
DSP = "0.6, 0.7"
3638
DomainSets = "0.5"
3739
DualNumbers = "0.6.2"

src/ApproxFunBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ import Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle,
3535

3636
import Statistics: mean
3737

38+
import Combinatorics: multiexponents
39+
3840
import LinearAlgebra: BlasInt, BlasFloat, norm, ldiv!, mul!, det, cross,
3941
qr, qr!, rank, isdiag, istril, istriu, issymmetric,
4042
Tridiagonal, diagm, diagm_container, factorize,

src/Multivariate/Multivariate.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include("VectorFun.jl")
3030
include("TensorSpace.jl")
3131
include("LowRankFun.jl")
3232
include("ProductFun.jl")
33+
include("TrivialTensorFun.jl")
3334

3435

3536
arglength(f)=length(Base.uncompressed_ast(f.code.def).args[1])

src/Multivariate/ProductFun.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
export ProductFun
77

8+
89
"""
910
ProductFun(f, space::TensorSpace; [tol=eps()])
1011
@@ -28,6 +29,7 @@ julia> coefficients(P) # power only at the (1,1) Chebyshev mode
2829
0.0 1.0
2930
```
3031
"""
32+
3133
struct ProductFun{S<:UnivariateSpace,V<:UnivariateSpace,SS<:AbstractProductSpace,T} <: BivariateFun{T}
3234
coefficients::Vector{VFun{S,T}} # coefficients are in x
3335
space::SS
@@ -268,7 +270,8 @@ function coefficients(f::ProductFun,ox::Space,oy::Space)
268270
end
269271

270272
(f::ProductFun)(x,y) = evaluate(f,x,y)
271-
(f::ProductFun)(x,y,z) = evaluate(f,x,y,z)
273+
# ProductFun does only support BivariateFunctions, this function below just does not work
274+
# (f::ProductFun)(x,y,z) = evaluate(f,x,y,z)
272275

273276
coefficients(f::ProductFun,ox::TensorSpace) = coefficients(f,ox[1],ox[2])
274277

@@ -303,7 +306,6 @@ canonicalevaluate(f::ProductFun,xx::AbstractVector,yy::AbstractVector) =
303306

304307

305308
evaluate(f::ProductFun,x,y) = canonicalevaluate(f,tocanonical(f,x,y)...)
306-
evaluate(f::ProductFun,x,y,z) = canonicalevaluate(f,tocanonical(f,x,y,z)...)
307309

308310
# TensorSpace does not use map
309311
evaluate(f::ProductFun{S,V,SS,T},x::Number,::Colon) where {S<:UnivariateSpace,V<:UnivariateSpace,SS<:TensorSpace,T} =
@@ -313,6 +315,7 @@ evaluate(f::ProductFun{S,V,SS,T},x::Number,y::Number) where {S<:UnivariateSpace,
313315
evaluate(f,x,:)(y)
314316

315317

318+
316319
evaluate(f::ProductFun,x) = evaluate(f,x...)
317320

318321
*(c::Number,f::F) where {F<:ProductFun} = F(c*f.coefficients,f.space)

src/Multivariate/TensorSpace.jl

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,76 @@ struct Tensorizer{DMS<:Tuple}
2424
blocks::DMS
2525
end
2626

27+
const Tensorizer2D{AA, BB} = Tensorizer{Tuple{AA, BB}}
2728
const TrivialTensorizer{d} = Tensorizer{NTuple{d,Ones{Int,1,Tuple{OneToInf{Int}}}}}
2829

2930
Base.eltype(a::Tensorizer) = NTuple{length(a.blocks),Int}
3031
Base.eltype(::Tensorizer{<:NTuple{d}}) where {d} = NTuple{d,Int}
3132
dimensions(a::Tensorizer) = map(sum,a.blocks)
3233
Base.length(a::Tensorizer) = mapreduce(sum,*,a.blocks)
3334

35+
36+
function start(a::TrivialTensorizer{d}) where {d}
37+
if d==2
38+
return invoke(start, Tuple{Tensorizer2D}, a)
39+
else
40+
# ((block_dim_1, block_dim_2,...), (itaration_number, iterator, iterator_state)), (itemssofar, length)
41+
return (ones(Int, d),(0, nothing, nothing)), (0,length(a))
42+
end
43+
end
44+
45+
function next(a::TrivialTensorizer{d}, iterator_tuple) where {d}
46+
47+
if d==2
48+
return invoke(next, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
49+
end
50+
51+
(block, (j, iterator, iter_state)), (i,tot) = iterator_tuple
52+
53+
54+
@inline function check_block_finished()
55+
if iterator === nothing
56+
return true
57+
end
58+
# there are N-1 over d-1 combinations in a block
59+
amount_combinations_block = binomial(sum(block)-1, d-1)
60+
# check if all combinations have been iterated over
61+
amount_combinations_block <= j
62+
end
63+
64+
ret = reverse(block)
65+
66+
if check_block_finished() # end of new block
67+
68+
# set up iterator for new block
69+
current_sum = sum(block)
70+
iterator = multiexponents(d, current_sum+1-d)
71+
iter_state = nothing
72+
j = 0
73+
end
74+
75+
# increase block, or initialize new block
76+
res, iter_state = iterate(iterator, iter_state)
77+
block .= res.+1
78+
j = j+1
79+
80+
ret, ((block, (j, iterator, iter_state)), (i,tot))
81+
end
82+
83+
84+
function done(a::TrivialTensorizer{d}, iterator_tuple) where {d}
85+
if d==2
86+
return invoke(done, Tuple{Tensorizer2D, Tuple}, a, iterator_tuple)
87+
end
88+
(_, (i,tot)) = iterator_tuple
89+
return i tot
90+
end
91+
92+
3493
# (blockrow,blockcol), (subrow,subcol), (rowshift,colshift), (numblockrows,numblockcols), (itemssofar, length)
35-
start(a::Tensorizer{Tuple{AA,BB}}) where {AA,BB} = (1,1), (1,1), (0,0), (a.blocks[1][1],a.blocks[2][1]), (0,length(a))
94+
start(a::Tensorizer2D{AA, BB}) where {AA,BB} = (1,1), (1,1), (0,0), (a.blocks[1][1],a.blocks[2][1]), (0,length(a))
3695

37-
function next(a::Tensorizer{Tuple{AA,BB}}, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) where {AA,BB}
96+
function next(a::Tensorizer2D{AA, BB}, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) where {AA,BB}
3897
ret = k+rsh,j+csh
3998
if k==n && j==m # end of block
4099
if J == 1 || K == length(a.blocks[1]) # end of new block
@@ -59,7 +118,7 @@ function next(a::Tensorizer{Tuple{AA,BB}}, ((K,J), (k,j), (rsh,csh), (n,m), (i,t
59118
end
60119

61120

62-
done(a::Tensorizer, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) = i tot
121+
done(a::Tensorizer2D, ((K,J), (k,j), (rsh,csh), (n,m), (i,tot))) = i tot
63122

64123
iterate(a::Tensorizer) = next(a, start(a))
65124
function iterate(a::Tensorizer, st)
@@ -104,6 +163,14 @@ block(ci::CachedIterator{T,TrivialTensorizer{2}},k::Int) where {T} =
104163
block(::TrivialTensorizer{2},n::Int) =
105164
Block(floor(Integer,sqrt(2n) + 1/2))
106165

166+
function block(::TrivialTensorizer{d},n::Int) where {d}
167+
order::Int = 0
168+
while binomial(order+d, d) < n
169+
order = order + 1
170+
end
171+
return Block(order+1)
172+
end
173+
107174
block(sp::Tensorizer{<:Tuple{<:AbstractFill{S},<:AbstractFill{T}}},n::Int) where {S,T} =
108175
Block(floor(Integer,sqrt(2floor(Integer,(n-1)/(getindex_value(sp.blocks[1])*getindex_value(sp.blocks[2])))+1) + 1/2))
109176
_cumsum(x) = cumsum(x)
@@ -211,6 +278,10 @@ struct TensorSpace{SV,D,R} <:AbstractProductSpace{SV,D,R}
211278
spaces::SV
212279
end
213280

281+
# Tensorspace of 2 univariate spaces
282+
const TensorSpace2D{AA, BB, D,R} = TensorSpace{<:Tuple{AA, BB}, D, R} where {AA<:UnivariateSpace, BB<:UnivariateSpace}
283+
const TensorSpaceND{d, D, R} = TensorSpace{<:NTuple{d, <:UnivariateSpace}, D, R}
284+
214285
tensorizer(sp::TensorSpace) = Tensorizer(map(blocklengths,sp.spaces))
215286
blocklengths(S::TensorSpace) = tensorblocklengths(map(blocklengths,S.spaces)...)
216287

@@ -473,6 +544,8 @@ end
473544

474545
fromtensor(S::Space,M::AbstractMatrix) = fromtensor(tensorizer(S),M)
475546
totensor(S::Space,M::AbstractVector) = totensor(tensorizer(S),M)
547+
totensor(SS::TensorSpace{<:NTuple{d}},M::AbstractVector) where {d} =
548+
if d>2; totensoriterator(tensorizer(SS),M) else totensor(tensorizer(SS),M) end
476549

477550
function fromtensor(it::Tensorizer,M::AbstractMatrix)
478551
n,m=size(M)
@@ -496,19 +569,27 @@ function totensor(it::Tensorizer,M::AbstractVector)
496569
B=block(it,n)
497570
ds = dimensions(it)
498571

572+
#ret=zeros(eltype(M),[sum(it.blocks[i][1:min(B.n[1],length(it.blocks[i]))]) for i=1:length(it.blocks)]...)
573+
499574
ret=zeros(eltype(M),sum(it.blocks[1][1:min(B.n[1],length(it.blocks[1]))]),
500575
sum(it.blocks[2][1:min(B.n[1],length(it.blocks[2]))]))
576+
501577
k=1
502-
for (K,J) in it
578+
for index in it
503579
if k > n
504580
break
505581
end
506-
ret[K,J] = M[k]
582+
ret[index...] = M[k]
507583
k += 1
508584
end
509585
ret
510586
end
511587

588+
@inline function totensoriterator(it::TrivialTensorizer{d},M::AbstractVector) where {d}
589+
B=block(it,length(M))
590+
return it, M, B
591+
end
592+
512593
for OP in (:block,:blockstart,:blockstop)
513594
@eval begin
514595
$OP(s::TensorSpace, ::PosInfinity) = ℵ₀
@@ -542,10 +623,12 @@ end
542623

543624
itransform(sp::TensorSpace,cfs::AbstractVector) = vec(itransform!(sp,coefficientmatrix(Fun(sp,cfs))))
544625

545-
evaluate(f::AbstractVector,S::AbstractProductSpace,x) = ProductFun(totensor(S,f),S)(x...)
546-
evaluate(f::AbstractVector,S::AbstractProductSpace,x,y) = ProductFun(totensor(S,f),S)(x,y)
547-
626+
# 2D evaluation functions
627+
evaluate(f::AbstractVector,S::TensorSpace2D,x) = ProductFun(totensor(S,f), S)(x...)
628+
evaluate(f::AbstractVector,S::TensorSpace2D,x,y) = ProductFun(totensor(S,f),S)(x,y)
548629

630+
# ND evaluation functions of Trivial Spaces
631+
evaluate(f::AbstractVector,S::TensorSpaceND,x) = TrivialTensorFun(totensor(S, f)..., S)(x...)
549632

550633
coefficientmatrix(f::Fun{<:AbstractProductSpace}) = totensor(space(f),f.coefficients)
551634

src/Multivariate/TrivialTensorFun.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
3+
4+
struct TrivialTensorFun{d, SS<:TensorSpaceND{d}, T<:Number} <: MultivariateFun{T, d}
5+
space::SS
6+
coefficients::Vector{T}
7+
iterator::TrivialTensorizer{d}
8+
orders::Block{1, Int}
9+
end
10+
11+
12+
function TrivialTensorFun(iter::TrivialTensorizer{d},cfs::Vector{T},blk::Block, sp::TensorSpaceND{d}) where {T<:Number,d}
13+
if any(map(dimension, sp.spaces).!=ℵ₀)
14+
error("This Space is not a Trivial Tensor space!")
15+
end
16+
TrivialTensorFun(sp, cfs, iter, blk)
17+
end
18+
19+
(f::TrivialTensorFun)(x...) = evaluate(f, x...)
20+
21+
# TensorSpace evaluation
22+
function evaluate(f::TrivialTensorFun{d, SS, T},x...) where {d, SS, T}
23+
highest_order = f.orders.n[1]
24+
n = length(f.coefficients)
25+
26+
# this could be lazy evaluated for the sparse case
27+
A = T[Fun(f.space.spaces[i], [zeros(T, k);1])(x[i]) for k=0:highest_order, i=1:d]
28+
result::T = 0
29+
coef_counter::Int = 1
30+
for i in f.iterator
31+
tmp = f.coefficients[coef_counter]
32+
if tmp != 0
33+
tmp_res = 1
34+
for k=1:d
35+
tmp_res *= A[i[k], k]
36+
end
37+
result += tmp * tmp_res
38+
end
39+
coef_counter += 1
40+
if coef_counter > n
41+
break
42+
end
43+
end
44+
return result
45+
end

0 commit comments

Comments
 (0)