Skip to content

Commit 7beae37

Browse files
authored
[TypeParameterAccessors] Fix similartype(Diagonal), introduce similartype(Array, NDims(3)) (#1570)
1 parent c105287 commit 7beae37

File tree

6 files changed

+59
-10
lines changed

6 files changed

+59
-10
lines changed

NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ include("unspecify_parameters.jl")
1212
include("set_parameters.jl")
1313
include("specify_parameters.jl")
1414
include("default_parameters.jl")
15+
include("ndims.jl")
1516
include("base/abstractarray.jl")
1617
include("base/similartype.jl")
1718
include("base/array.jl")

NDTensors/src/lib/TypeParameterAccessors/src/base/abstractarray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ end
1313
function set_ndims(type::Type{<:AbstractArray}, param)
1414
return set_type_parameter(type, ndims, param)
1515
end
16+
function set_ndims(type::Type{<:AbstractArray}, param::NDims)
17+
return set_type_parameter(type, ndims, ndims(param))
18+
end
1619

1720
using SimpleTraits: SimpleTraits, @traitdef, @traitimpl
1821

NDTensors/src/lib/TypeParameterAccessors/src/base/similartype.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,23 @@ like `OffsetArrays` or named indices
55
(such as ITensors).
66
"""
77
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
8-
return set_ndims(arraytype, length(dims))
8+
return set_ndims(arraytype, NDims(length(dims)))
9+
end
10+
11+
function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, ndims::NDims)
12+
return similartype(similartype(arraytype, eltype), ndims)
913
end
1014

1115
function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
1216
return similartype(similartype(arraytype, eltype), dims)
1317
end
1418

19+
@traitfn function similartype(
20+
arraytype::Type{ArrayT}
21+
) where {ArrayT; !IsWrappedArray{ArrayT}}
22+
return arraytype
23+
end
24+
1525
@traitfn function similartype(
1626
arraytype::Type{ArrayT}, eltype::Type
1727
) where {ArrayT; !IsWrappedArray{ArrayT}}
@@ -24,19 +34,29 @@ end
2434
return set_indstype(arraytype, dims)
2535
end
2636

27-
function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...)
28-
return similartype(arraytype, dims)
37+
@traitfn function similartype(
38+
arraytype::Type{ArrayT}, ndims::NDims
39+
) where {ArrayT; !IsWrappedArray{ArrayT}}
40+
return set_ndims(arraytype, ndims)
2941
end
3042

31-
function similartype(arraytype::Type{<:AbstractArray})
32-
return similartype(arraytype, eltype(arraytype))
43+
function similartype(
44+
arraytype::Type{<:AbstractArray}, dim1::Base.DimOrInd, dim_rest::Base.DimOrInd...
45+
)
46+
return similartype(arraytype, (dim1, dim_rest...))
3347
end
3448

3549
## Wrapped arrays
50+
@traitfn function similartype(
51+
arraytype::Type{ArrayT}
52+
) where {ArrayT; IsWrappedArray{ArrayT}}
53+
return similartype(unwrap_array_type(arraytype), NDims(arraytype))
54+
end
55+
3656
@traitfn function similartype(
3757
arraytype::Type{ArrayT}, eltype::Type
3858
) where {ArrayT; IsWrappedArray{ArrayT}}
39-
return similartype(unwrap_array_type(arraytype), eltype)
59+
return similartype(unwrap_array_type(arraytype), eltype, NDims(arraytype))
4060
end
4161

4262
@traitfn function similartype(
@@ -45,6 +65,12 @@ end
4565
return similartype(unwrap_array_type(arraytype), dims)
4666
end
4767

68+
@traitfn function similartype(
69+
arraytype::Type{ArrayT}, ndims::NDims
70+
) where {ArrayT; IsWrappedArray{ArrayT}}
71+
return similartype(unwrap_array_type(arraytype), ndims)
72+
end
73+
4874
# This is for uniform `Diag` storage which uses
4975
# a Number as the data type.
5076
# TODO: Delete this when we change to using a
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
struct NDims{ndims} end
2+
Base.ndims(::NDims{ndims}) where {ndims} = ndims
3+
4+
NDims(ndims::Integer) = NDims{ndims}()
5+
NDims(arraytype::Type{<:AbstractArray}) = NDims(ndims(arraytype))
6+
NDims(array::AbstractArray) = NDims(typeof(array))
Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
@eval module $(gensym())
22
using Test: @test, @test_broken, @testset
3-
using LinearAlgebra: Adjoint
4-
using NDTensors.TypeParameterAccessors: similartype
3+
using LinearAlgebra: Adjoint, Diagonal
4+
using NDTensors.TypeParameterAccessors: NDims, similartype
55
@testset "TypeParameterAccessors similartype" begin
66
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
7-
# TODO: Is this a good definition? Probably it should be left unspecified.
8-
@test similartype(Array) == Array{Any}
7+
@test similartype(Array) == Array
98
@test similartype(Array, Float64) == Array{Float64}
109
@test similartype(Array, (2, 2)) == Matrix
10+
@test similartype(Array, NDims(2)) == Matrix
11+
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
12+
@test similartype(Array, Float64, NDims(2)) == Matrix{Float64}
1113
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) ==
1214
Array{Float64,3}
15+
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, NDims(3)) == Array{Float64,3}
1316
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64}
17+
@test similartype(Diagonal{Float32,Vector{Float32}}) == Matrix{Float32}
18+
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64) == Matrix{Float64}
19+
@test similartype(Diagonal{Float32,Vector{Float32}}, (2, 2, 2)) == Array{Float32,3}
20+
@test similartype(Diagonal{Float32,Vector{Float32}}, NDims(3)) == Array{Float32,3}
21+
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64, (2, 2, 2)) ==
22+
Array{Float64,3}
23+
@test similartype(Diagonal{Float32,Vector{Float32}}, Float64, NDims(3)) ==
24+
Array{Float64,3}
1425
end
1526
end

NDTensors/src/lib/TypeParameterAccessors/test/test_wrappers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ using LinearAlgebra:
1111
UnitUpperTriangular,
1212
UpperTriangular
1313
using NDTensors.TypeParameterAccessors:
14+
NDims,
1415
TypeParameter,
1516
is_wrapped_array,
1617
parenttype,
@@ -33,6 +34,7 @@ include("utils/test_inferred.jl")
3334
@test_inferred set_eltype(array, Float32) array
3435
@test_inferred set_eltype(Array{<:Any,2}, Float64) == Matrix{Float64}
3536
@test_inferred set_ndims(Array{Float64}, 2) == Matrix{Float64} wrapped = true
37+
@test_inferred set_ndims(Array{Float64}, NDims(2)) == Matrix{Float64} wrapped = true
3638
@test_inferred set_ndims(Array{Float64}, TypeParameter(2)) == Matrix{Float64}
3739
@test_inferred unwrap_array_type(array_type) == array_type
3840
end

0 commit comments

Comments
 (0)