Skip to content

Commit 6ccf0d6

Browse files
authored
improve TensorSpace inference (#261)
1 parent 85f3efb commit 6ccf0d6

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.7.32"
3+
version = "0.7.33"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Multivariate/TensorSpace.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,13 @@ tensor_eval_type(::Type{Vector{Any}},::Type{Vector{Any}}) = Vector{Any}
222222
tensor_eval_type(::Type{Vector{Any}},_) = Vector{Any}
223223
tensor_eval_type(_,::Type{Vector{Any}}) = Vector{Any}
224224

225-
225+
# Specialize some common cases to avoid mapreduce, which has inference issues
226+
_typeofproddomain(sp::Tuple{Any}) = typeof(domain(sp[1]))
227+
_typeofproddomain(sp::Tuple{Any,Any}) = typeof(domain(sp[1]) × domain(sp[2]))
228+
_typeofproddomain(sp) = typeof(mapreduce(domain,×,sp))
226229
TensorSpace(sp::Tuple) =
227-
TensorSpace{typeof(sp),typeof(mapreduce(domain,×,sp)),
228-
mapreduce(rangetype,(a,b)->tensor_eval_type(a,b),sp)}(sp)
229-
230+
TensorSpace{typeof(sp), _typeofproddomain(sp),
231+
mapreduce(rangetype,tensor_eval_type,sp)}(sp)
230232

231233
dimension(sp::TensorSpace) = mapreduce(dimension,*,sp.spaces)
232234

test/SpacesTest.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,12 @@ using LinearAlgebra
253253
@test dimension(a^3) == dimension(a)^3
254254
@test @inferred(domain(a^3)) == domain(a)^3
255255
@test_broken @inferred(points(a^3)) == vec(Vec.(points(a), points(a)', reshape(points(a), 1,1,4)))
256+
257+
p = PointSpace(1:4)
258+
d = domain(p)
259+
@test domain(TensorSpace(p)) == d
260+
@test components(domain(TensorSpace(p, p))) == (d, d)
261+
@test components(domain(TensorSpace(p, p, p))) == (d, d, d)
256262
end
257263

258264
@testset "ConstantSpace" begin

0 commit comments

Comments
 (0)