Skip to content

Commit d0b64d1

Browse files
committed
Vendor TypeParameterAccessors
1 parent 40dc0a3 commit d0b64d1

File tree

15 files changed

+652
-1
lines changed

15 files changed

+652
-1
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
module NDTensors
2+
3+
module Vendored
4+
include(joinpath("vendored", "TypeParameterAccessors", "src", "TypeParameterAccessors.jl"))
5+
end
6+
27
#####################################
38
# Imports and exports
49
#
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
module AMDGPUExtensions
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
include("roc.jl")
313

414
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
module CUDAExtensions
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
include("cuda.jl")
313

414
end

NDTensors/src/lib/Expose/src/Expose.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
module Expose
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
using SimpleTraits
313
using LinearAlgebra
414
using Base: ReshapedArray
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
module GPUArraysCoreExtensions
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
include("gpuarrayscore.jl")
313

414
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
11
module MetalExtensions
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
include("metal.jl")
313

414
end
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
module RankFactorization
2+
3+
module Vendored
4+
include(
5+
joinpath(
6+
"..", "..", "..", "vendored", "TypeParameterAccessors", "src",
7+
"TypeParameterAccessors.jl",
8+
)
9+
)
10+
end
11+
212
include("default_kwargs.jl")
313
include("truncate_spectrum.jl")
414
include("spectrum.jl")
15+
516
end
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module TypeParameterAccessors
2+
3+
# Exports
4+
export type_parameters, get_type_parameters
5+
export nparameters, is_parameter_specified
6+
export default_type_parameters
7+
export set_type_parameters, set_default_type_parameters
8+
export specify_type_parameters, specify_default_type_parameters
9+
export unspecify_type_parameters
10+
11+
# Imports
12+
using SimpleTraits: SimpleTraits, @traitdef, @traitimpl
13+
14+
include("type_utils.jl")
15+
include("type_parameters.jl")
16+
17+
# Implementations
18+
include("ndims.jl")
19+
include("base/abstractarray.jl")
20+
include("base/similartype.jl")
21+
include("base/linearalgebra.jl")
22+
23+
end
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
struct Self end
2+
position(a, ::Self) = Position(0)
3+
position(::Type, ::Self) = Position(0)
4+
function set_type_parameters(type::Type, ::Self, param)
5+
return error("Can't set the parent type of an unwrapped array type.")
6+
end
7+
8+
position(::Type{AbstractArray}, ::typeof(eltype)) = Position(1)
9+
position(::Type{AbstractArray}, ::typeof(ndims)) = Position(2)
10+
default_type_parameters(::Type{AbstractArray}) = (Float64, 1)
11+
12+
position(::Type{<:Array}, ::typeof(eltype)) = Position(1)
13+
position(::Type{<:Array}, ::typeof(ndims)) = Position(2)
14+
default_type_parameters(::Type{<:Array}) = (Float64, 1)
15+
16+
position(::Type{<:BitArray}, ::typeof(ndims)) = Position(1)
17+
default_type_parameters(::Type{<:BitArray}) = (1,)
18+
19+
function set_eltype(array::AbstractArray, param)
20+
return convert(set_eltype(typeof(array), param), array)
21+
end
22+
23+
## This will fail if position of `ndims` is not defined for `type`
24+
function set_ndims(type::Type{<:AbstractArray}, param)
25+
return set_type_parameters(type, ndims, param)
26+
end
27+
function set_ndims(type::Type{<:AbstractArray}, param::NDims)
28+
return set_type_parameters(type, ndims, ndims(param))
29+
end
30+
31+
# Trait indicating if the AbstractArray type is an array wrapper.
32+
# Assumes that it implements `NDTensors.parenttype`.
33+
@traitdef IsWrappedArray{ArrayType}
34+
35+
#! format: off
36+
@traitimpl IsWrappedArray{ArrayType} <- is_wrapped_array(ArrayType)
37+
#! format: on
38+
39+
parenttype(type::Type{<:AbstractArray}) = type_parameters(type, parenttype)
40+
parenttype(object::AbstractArray) = parenttype(typeof(object))
41+
position(::Type{<:AbstractArray}, ::typeof(parenttype)) = Self()
42+
43+
is_wrapped_array(arraytype::Type{<:AbstractArray}) = (parenttype(arraytype) arraytype)
44+
@inline is_wrapped_array(array::AbstractArray) = is_wrapped_array(typeof(array))
45+
@inline is_wrapped_array(object) = false
46+
47+
using SimpleTraits: Not, @traitfn
48+
49+
@traitfn function unwrap_array_type(
50+
arraytype::Type{ArrayType}
51+
) where {{ArrayType; IsWrappedArray{ArrayType}}}
52+
return unwrap_array_type(parenttype(arraytype))
53+
end
54+
55+
@traitfn function unwrap_array_type(
56+
arraytype::Type{ArrayType}
57+
) where {{ArrayType; !IsWrappedArray{ArrayType}}}
58+
return arraytype
59+
end
60+
61+
# For working with instances.
62+
unwrap_array_type(array::AbstractArray) = unwrap_array_type(typeof(array))
63+
64+
function set_parenttype(t::Type, param)
65+
return set_type_parameters(t, parenttype, param)
66+
end
67+
68+
@traitfn function set_eltype(
69+
type::Type{ArrayType}, param
70+
) where {{ArrayType <: AbstractArray; IsWrappedArray{ArrayType}}}
71+
new_parenttype = set_eltype(parenttype(type), param)
72+
# Need to set both in one `set_type_parameters` call to avoid
73+
# conflicts in type parameter constraints of certain wrapper types.
74+
return set_type_parameters(type, (eltype, parenttype), (param, new_parenttype))
75+
end
76+
77+
@traitfn function set_eltype(
78+
type::Type{ArrayType}, param
79+
) where {{ArrayType <: AbstractArray; !IsWrappedArray{ArrayType}}}
80+
return set_type_parameters(type, eltype, param)
81+
end
82+
83+
for wrapper in [:PermutedDimsArray, :(Base.ReshapedArray), :SubArray]
84+
@eval begin
85+
position(type::Type{<:$wrapper}, ::typeof(eltype)) = Position(1)
86+
position(type::Type{<:$wrapper}, ::typeof(ndims)) = Position(2)
87+
end
88+
end
89+
for wrapper in [:(Base.ReshapedArray), :SubArray]
90+
@eval position(type::Type{<:$wrapper}, ::typeof(parenttype)) = Position(3)
91+
end
92+
for wrapper in [:PermutedDimsArray]
93+
@eval position(type::Type{<:$wrapper}, ::typeof(parenttype)) = Position(5)
94+
end
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using LinearAlgebra:
2+
Adjoint,
3+
Diagonal,
4+
Hermitian,
5+
LowerTriangular,
6+
Symmetric,
7+
Transpose,
8+
UnitLowerTriangular,
9+
UnitUpperTriangular,
10+
UpperTriangular
11+
12+
for wrapper in [
13+
:Transpose,
14+
:Adjoint,
15+
:Symmetric,
16+
:Hermitian,
17+
:UpperTriangular,
18+
:LowerTriangular,
19+
:UnitUpperTriangular,
20+
:UnitLowerTriangular,
21+
:Diagonal,
22+
]
23+
@eval position(::Type{<:$wrapper}, ::typeof(eltype)) = Position(1)
24+
@eval position(::Type{<:$wrapper}, ::typeof(parenttype)) = Position(2)
25+
end

0 commit comments

Comments
 (0)