Skip to content

Commit 11a5c98

Browse files
committed
[TypeParameterAccessors] similartype
1 parent 8bb156a commit 11a5c98

File tree

8 files changed

+83
-66
lines changed

8 files changed

+83
-66
lines changed

NDTensors/src/abstractarray/set_types.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,3 @@ TODO: Use `Accessors.jl` notation:
1414
function TypeParameterAccessors.set_ndims(numbertype::Type{<:Number}, ndims)
1515
return numbertype
1616
end
17-
18-
"""
19-
`set_indstype` should be overloaded for
20-
types with structured dimensions,
21-
like `OffsetArrays` or named indices
22-
(such as ITensors).
23-
"""
24-
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
25-
return set_ndims(arraytype, length(dims))
26-
end

NDTensors/src/abstractarray/similar.jl

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using Base: DimOrInd, Dims, OneTo
12
using .TypeParameterAccessors: IsWrappedArray, unwrap_array_type, set_eltype
23

34
## Custom `NDTensors.similar` implementation.
@@ -96,58 +97,3 @@ end
9697
# Use the `size` to determine the dimensions
9798
# NDTensors.similar
9899
similar(array::AbstractArray) = NDTensors.similar(typeof(array), size(array))
99-
100-
## similartype
101-
102-
function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
103-
return similartype(similartype(arraytype, eltype), dims)
104-
end
105-
106-
@traitfn function similartype(
107-
arraytype::Type{ArrayT}, eltype::Type
108-
) where {ArrayT; !IsWrappedArray{ArrayT}}
109-
return set_eltype(arraytype, eltype)
110-
end
111-
112-
@traitfn function similartype(
113-
arraytype::Type{ArrayT}, dims::Tuple
114-
) where {ArrayT; !IsWrappedArray{ArrayT}}
115-
return set_indstype(arraytype, dims)
116-
end
117-
118-
function similartype(arraytype::Type{<:AbstractArray}, dims::DimOrInd...)
119-
return similartype(arraytype, dims)
120-
end
121-
122-
function similartype(arraytype::Type{<:AbstractArray})
123-
return similartype(arraytype, eltype(arraytype))
124-
end
125-
126-
## Wrapped arrays
127-
@traitfn function similartype(
128-
arraytype::Type{ArrayT}, eltype::Type
129-
) where {ArrayT; IsWrappedArray{ArrayT}}
130-
return similartype(unwrap_array_type(arraytype), eltype)
131-
end
132-
133-
@traitfn function similartype(
134-
arraytype::Type{ArrayT}, dims::Tuple
135-
) where {ArrayT; IsWrappedArray{ArrayT}}
136-
return similartype(unwrap_array_type(arraytype), dims)
137-
end
138-
139-
# This is for uniform `Diag` storage which uses
140-
# a Number as the data type.
141-
# TODO: Delete this when we change to using a
142-
# `FillArray` instead. This is a stand-in
143-
# to make things work with the current design.
144-
function similartype(numbertype::Type{<:Number})
145-
return numbertype
146-
end
147-
148-
# Instances
149-
function similartype(array::AbstractArray, eltype::Type, dims...)
150-
return similartype(typeof(array), eltype, dims...)
151-
end
152-
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
153-
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)

NDTensors/src/lib/TypeParameterAccessors/src/TypeParameterAccessors.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include("set_parameters.jl")
1313
include("specify_parameters.jl")
1414
include("default_parameters.jl")
1515
include("base/abstractarray.jl")
16+
include("base/similartype.jl")
1617
include("base/array.jl")
1718
include("base/linearalgebra.jl")
1819
include("base/stridedviews.jl")
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""
2+
`set_indstype` should be overloaded for
3+
types with structured dimensions,
4+
like `OffsetArrays` or named indices
5+
(such as ITensors).
6+
"""
7+
function set_indstype(arraytype::Type{<:AbstractArray}, dims::Tuple)
8+
return set_ndims(arraytype, length(dims))
9+
end
10+
11+
function similartype(arraytype::Type{<:AbstractArray}, eltype::Type, dims::Tuple)
12+
return similartype(similartype(arraytype, eltype), dims)
13+
end
14+
15+
@traitfn function similartype(
16+
arraytype::Type{ArrayT}, eltype::Type
17+
) where {ArrayT; !IsWrappedArray{ArrayT}}
18+
return set_eltype(arraytype, eltype)
19+
end
20+
21+
@traitfn function similartype(
22+
arraytype::Type{ArrayT}, dims::Tuple
23+
) where {ArrayT; !IsWrappedArray{ArrayT}}
24+
return set_indstype(arraytype, dims)
25+
end
26+
27+
function similartype(arraytype::Type{<:AbstractArray}, dims::Base.DimOrInd...)
28+
return similartype(arraytype, dims)
29+
end
30+
31+
function similartype(arraytype::Type{<:AbstractArray})
32+
return similartype(arraytype, eltype(arraytype))
33+
end
34+
35+
## Wrapped arrays
36+
@traitfn function similartype(
37+
arraytype::Type{ArrayT}, eltype::Type
38+
) where {ArrayT; IsWrappedArray{ArrayT}}
39+
return similartype(unwrap_array_type(arraytype), eltype)
40+
end
41+
42+
@traitfn function similartype(
43+
arraytype::Type{ArrayT}, dims::Tuple
44+
) where {ArrayT; IsWrappedArray{ArrayT}}
45+
return similartype(unwrap_array_type(arraytype), dims)
46+
end
47+
48+
# This is for uniform `Diag` storage which uses
49+
# a Number as the data type.
50+
# TODO: Delete this when we change to using a
51+
# `FillArray` instead. This is a stand-in
52+
# to make things work with the current design.
53+
function similartype(numbertype::Type{<:Number})
54+
return numbertype
55+
end
56+
57+
# Instances
58+
function similartype(array::AbstractArray, eltype::Type, dims...)
59+
return similartype(typeof(array), eltype, dims...)
60+
end
61+
similartype(array::AbstractArray, eltype::Type) = similartype(typeof(array), eltype)
62+
similartype(array::AbstractArray, dims...) = similartype(typeof(array), dims...)

NDTensors/src/lib/TypeParameterAccessors/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ using Test: @testset
55
include("test_defaults.jl")
66
include("test_custom_types.jl")
77
include("test_wrappers.jl")
8+
include("test_similartype.jl")
89
end
910
end
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@eval module $(gensym())
2+
using Test: @test, @test_broken, @testset
3+
using LinearAlgebra: Adjoint
4+
using NDTensors.TypeParameterAccessors: similartype
5+
@testset "TypeParameterAccessors similartype" begin
6+
@test similartype(Array, Float64, (2, 2)) == Matrix{Float64}
7+
# TODO: Is this a good definition? Probably it should be left unspecified.
8+
@test similartype(Array) == Array{Any}
9+
@test similartype(Array, Float64) == Array{Float64}
10+
@test similartype(Array, (2, 2)) == Matrix
11+
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64, (2, 2, 2)) ==
12+
Array{Float64,3}
13+
@test similartype(Adjoint{Float32,Matrix{Float32}}, Float64) == Matrix{Float64}
14+
end
15+
end

NDTensors/src/tensor/set_types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
# TODO: Modify the `storagetype` according to `inds`, such as the dimensions?
1919
# TODO: Make a version that accepts `indstype::Type`?
20-
function set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
20+
function TypeParameterAccessors.set_indstype(tensortype::Type{<:Tensor}, inds::Tuple)
2121
return Tensor{eltype(tensortype),length(inds),storagetype(tensortype),typeof(inds)}
2222
end
2323

NDTensors/src/tensor/similar.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using .TypeParameterAccessors: set_indstype
2+
13
# NDTensors.similar
24
similar(tensor::Tensor) = setstorage(tensor, similar(storage(tensor)))
35

0 commit comments

Comments
 (0)