Skip to content

Commit 4996de6

Browse files
authored
improve inference in space tensor product (#203)
* improve inference in space tensor product * Fix ones for ArraySpace
1 parent 7d7c6db commit 4996de6

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.7.4"
3+
version = "0.7.5"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Multivariate/TensorSpace.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,11 @@ setdomain(sp::TensorSpace, d::ProductDomain) = TensorSpace(setdomain.(factors(sp
254254
*(A::Space, B::Space) = AB
255255
function ^(A::Space, p::Integer)
256256
p >= 1 || throw(ArgumentError("exponent must be >= 1, received $p"))
257-
p == 1 ? A : foldl(*, ntuple(_ -> A, p))
257+
# Enumerate common cases to help with constant propagation
258+
p == 1 ? A :
259+
p == 2 ? A * A :
260+
p == 3 ? A * A * A :
261+
foldl(*, ntuple(_ -> A, p))
258262
end
259263

260264

src/Spaces/ArraySpace.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ Fun(M::UniformScaling,sp::MatrixSpace) = Fun(M.λ*Matrix(I,size(sp)...),sp)
234234

235235

236236

237-
ones(::Type{T},A::ArraySpace) where {T<:Number} = Fun(ones.(T,spaces(A)))
238-
ones(A::ArraySpace) = Fun(ones.(spaces(A)))
237+
ones(::Type{T},A::ArraySpace) where {T<:Number} = Fun(ones.(T,A.spaces))
238+
ones(A::ArraySpace) = ones(Float64, A)
239239

240240

241241
## EuclideanSpace

test/SpacesTest.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,19 @@ using LinearAlgebra
6262
end
6363
end
6464
@testset "intpow" begin
65+
sp = PointSpace(1:3)
66+
f = Fun(sp, Float64[1:3;])
6567
@test ApproxFunBase.intpow(f, 0) == f^0 == Fun(space(f), ones(ncoefficients(f)))
6668
for n in 1:3
6769
@test ApproxFunBase.intpow(f, n) == f^n == reduce(*, fill(f, n))
6870
end
6971
@test ApproxFunBase.intpow(f,-2) == f^-2 == 1/(f*f)
72+
73+
@test sp^1 == sp
74+
@test sp^2 == sp * sp
75+
@test sp^3 == sp * sp * sp
76+
@test sp^4 == sp * sp * sp * sp
77+
@test sp^5 == sp * sp * sp * sp * sp
7078
end
7179

7280
@testset "Fun accepts callables" begin

0 commit comments

Comments
 (0)