Skip to content

Commit 66c238d

Browse files
authored
More default_type_parameters definitions (#35)
1 parent 2402cc8 commit 66c238d

9 files changed

+137
-50
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TypeParameterAccessors"
22
uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module TypeParameterAccessorsAMDGPUExt
22

3-
using AMDGPU: ROCArray
3+
using AMDGPU: AMDGPU, ROCArray
44
using TypeParameterAccessors: TypeParameterAccessors, Position
55

66
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1)
77
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2)
8+
function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray})
9+
return (Float64, 1, AMDGPU.Mem.HIPBuffer)
10+
end
811

912
end
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module TypeParameterAccessorsCUDAExt
22

3-
using CUDA: CuArray
3+
using CUDA: CUDA, CuArray
44
using TypeParameterAccessors: TypeParameterAccessors, Position
55

66
TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype)) = Position(1)
77
TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims)) = Position(2)
8+
function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
9+
return (Float64, 1, CUDA.default_memory)
10+
end
811

912
end

ext/TypeParameterAccessorsJLArraysExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ using TypeParameterAccessors: TypeParameterAccessors, Position
55

66
TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(eltype)) = Position(1)
77
TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(ndims)) = Position(2)
8+
TypeParameterAccessors.default_type_parameters(::Type{<:JLArray}) = (Float64, 1)
89

910
end
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module TypeParameterAccessorsMetalExt
22

3-
using Metal: MtlArray
3+
using Metal: Metal, MtlArray
44
using TypeParameterAccessors: TypeParameterAccessors, Position
55

66
TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype)) = Position(1)
77
TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims)) = Position(2)
8+
function TypeParameterAccessors.default_type_parameters(::Type{<:MtlArray})
9+
return (Float64, 1, Metal.DefaultStorageMode)
10+
end
811

912
end

ext/TypeParameterAccessorsStridedViewsExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,8 @@ TypeParameterAccessors.position(::Type{<:StridedView}, ::typeof(ndims)) = Positi
88
function TypeParameterAccessors.position(::Type{<:StridedView}, ::typeof(parenttype))
99
return Position(3)
1010
end
11+
function TypeParameterAccessors.default_type_parameters(::Type{<:StridedView})
12+
return (Float64, 1, Vector{Float64}, typeof(identity))
13+
end
1114

1215
end
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module TypeParameterAccessorsoneAPIExt
22

3-
using oneAPI: oneArray
3+
using oneAPI: oneAPI, oneArray
44
using TypeParameterAccessors: TypeParameterAccessors, Position
55

66
TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(eltype)) = Position(1)
77
TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(ndims)) = Position(2)
8+
function TypeParameterAccessors.default_type_parameters(::Type{<:oneAPI})
9+
return (Float64, 1, oneAPI.oneL0.DeviceBuffer)
10+
end
811

912
end

src/type_parameters.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,18 @@ end
7070
# julia> Base.unwrap_unionall(supertype(MyArray)).parameters
7171
# svec(A, B)
7272
# ```
73-
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
73+
supertype_params = Base.unwrap_unionall(supertype(T)).parameters
7474
supertype_param = supertype_params[Int(supertype_pos)]
75+
if !(supertype_param isa TypeVar)
76+
error("Position not found.")
77+
end
7578
pos = findfirst(param -> (param.name == supertype_param.name), type_params)
7679
if isnothing(pos)
7780
return error("Position not found.")
7881
end
7982
return :(@inline; $(Position(pos)))
8083
end
8184

82-
# Automatically determine the position of a type parameter of a type given
83-
# a supertype and the name of the parameter.
84-
function position_from_supertype(type::Type, supertype_target::Type, name)
85-
return position_from_supertype(type, supertype_target, position(supertype_target, name))
86-
end
87-
8885
function positions(::Type{T}, pos::Tuple) where {T}
8986
return ntuple(length(pos)) do i
9087
return position(T, pos[i])
@@ -256,11 +253,48 @@ function default_type_parameters(::Type{T}, pos) where {T}
256253
return default_type_parameters(T, position(T, pos))
257254
end
258255
function default_type_parameters(::Type{T}, ::Position{pos}) where {T,pos}
259-
return default_type_parameters(T)[pos]
256+
param = default_type_parameters(T)[pos]
257+
if param isa UndefinedDefaultTypeParameter
258+
return error("No default parameter defined at this position.")
259+
end
260+
return param
260261
end
261262
default_type_parameters(::Type{T}, pos::Tuple) where {T} = default_type_parameters.(T, pos)
262-
default_type_parameters(t) = default_type_parameters(typeof(t))
263263
default_type_parameters(t, pos) = default_type_parameters(typeof(t), pos)
264+
default_type_parameters(t) = default_type_parameters(typeof(t))
265+
function default_type_parameters(type::Type)
266+
type′ = unspecify_type_parameters(type)
267+
if type === type′
268+
return default_type_parameters_from_supertype(type′)
269+
end
270+
return default_type_parameters(type′)
271+
end
272+
273+
struct UndefinedDefaultTypeParameter end
274+
275+
# Determine the default type parameters of a type from the default type
276+
# parameters of the supertype of the type. Uses similar logic as
277+
# `position_from_supertype_position` for matching TypeVar names
278+
# between the type and the supertype. Type parameters that exist
279+
# in the type but not the supertype will have a default type parameter
280+
# `UndefinedDefaultTypeParameter()`. Accessing those type parameters
281+
# by name/position will throw an error.
282+
@generated function default_type_parameters_from_supertype(::Type{T}) where {T}
283+
T′ = unspecify_type_parameters(T)
284+
supertype_default_type_params = default_type_parameters(supertype(T′))
285+
type_params = Base.unwrap_unionall(T′).parameters
286+
supertype_params = Base.unwrap_unionall(supertype(T′)).parameters
287+
defaults = Any[UndefinedDefaultTypeParameter() for _ in 1:nparameters(T′)]
288+
for (supertype_param, supertype_default_type_param) in
289+
zip(supertype_params, supertype_default_type_params)
290+
if !(supertype_param isa TypeVar)
291+
continue
292+
end
293+
param_position = findfirst(param -> (param.name == supertype_param.name), type_params)
294+
defaults[param_position] = supertype_default_type_param
295+
end
296+
return :(@inline; $(Tuple(defaults)))
297+
end
264298

265299
"""
266300
set_default_type_parameters(type::Type, [positions::Tuple])

test/test_defaults.jl

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,63 @@
1-
using Test: @test_throws, @testset
1+
using JLArrays: JLArray
2+
using Test: @test, @test_throws, @testset
3+
using TestExtras: @constinferred
24
using TypeParameterAccessors:
35
TypeParameterAccessors,
46
Position,
57
default_type_parameters,
68
set_default_type_parameters,
79
specify_default_type_parameters
8-
using TestExtras: @constinferred
9-
@testset "TypeParameterAccessors defaults" begin
10-
@testset "Erroneously requires wrapping to infer" begin end
10+
11+
const arrayts = (Array, JLArray)
12+
@testset "TypeParameterAccessors defaults $arrayt" for arrayt in arrayts
13+
vectort = arrayt{<:Any,1}
1114
@testset "Get defaults" begin
12-
@test @constinferred(default_type_parameters($Array, 1)) == Float64
13-
@test @constinferred(default_type_parameters($Array, 2)) == 1
14-
@test @constinferred(default_type_parameters($Array)) == (Float64, 1)
15-
@test @constinferred(default_type_parameters($Array, $((2, 1)))) == (1, Float64)
16-
@test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) ==
15+
@test @constinferred(default_type_parameters($arrayt, 1)) == Float64
16+
@test @constinferred(default_type_parameters($arrayt, 2)) == 1
17+
@test @constinferred(default_type_parameters($arrayt)) == (Float64, 1)
18+
@test @constinferred(default_type_parameters($arrayt, $((2, 1)))) == (1, Float64)
19+
@test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) ==
1720
(1, Float64)
18-
@test @constinferred(broadcast($default_type_parameters, $Array, $((2, 1)))) ==
21+
@test @constinferred(broadcast($default_type_parameters, $arrayt, $((2, 1)))) ==
1922
(1, Float64)
20-
@test @constinferred(broadcast($default_type_parameters, $Array, (ndims, eltype))) ==
23+
@test @constinferred(broadcast($default_type_parameters, $arrayt, (ndims, eltype))) ==
2124
(1, Float64)
2225
end
2326

2427
@testset "Set defaults" begin
25-
@test @constinferred(set_default_type_parameters($(Array{Float32}), 1)) ==
26-
Array{Float64}
27-
@test @constinferred(set_default_type_parameters($(Array{Float32}), eltype)) ==
28-
Array{Float64}
29-
@test @constinferred(set_default_type_parameters($(Array{Float32}))) == Vector{Float64}
30-
@test @constinferred(set_default_type_parameters($(Array{Float32}), $((1, 2)))) ==
31-
Vector{Float64}
32-
@test @constinferred(set_default_type_parameters($(Array{Float32}), (eltype, ndims))) ==
33-
Vector{Float64}
34-
@test @constinferred(set_default_type_parameters($Array)) == Vector{Float64}
35-
@test @constinferred(set_default_type_parameters($Array, 1)) == Array{Float64}
36-
@test @constinferred(set_default_type_parameters($Array, $((1, 2)))) == Vector{Float64}
28+
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), 1)) ==
29+
arrayt{Float64}
30+
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), eltype)) ==
31+
arrayt{Float64}
32+
@test @constinferred(set_default_type_parameters($(arrayt{Float32}))) ==
33+
vectort{Float64}
34+
@test @constinferred(set_default_type_parameters($(arrayt{Float32}), $((1, 2)))) ==
35+
vectort{Float64}
36+
@test @constinferred(
37+
set_default_type_parameters($(arrayt{Float32}), (eltype, ndims))
38+
) == vectort{Float64}
39+
@test @constinferred(set_default_type_parameters($arrayt)) == vectort{Float64}
40+
@test @constinferred(set_default_type_parameters($arrayt, 1)) == arrayt{Float64}
41+
@test @constinferred(set_default_type_parameters($arrayt, $((1, 2)))) ==
42+
vectort{Float64}
3743
end
3844

3945
@testset "Specify defaults" begin
40-
@test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64}
41-
@test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64}
42-
@test @constinferred(specify_default_type_parameters($Array, 2)) == Vector
43-
@test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector
44-
@test @constinferred(specify_default_type_parameters($Array)) == Vector{Float64}
45-
@test @constinferred(specify_default_type_parameters($Array, 1)) == Array{Float64}
46-
@test @constinferred(specify_default_type_parameters($Array, eltype)) == Array{Float64}
47-
@test @constinferred(specify_default_type_parameters($Array, 2)) == Vector
48-
@test @constinferred(specify_default_type_parameters($Array, ndims)) == Vector
49-
@test @constinferred(specify_default_type_parameters($Array, $((1, 2)))) ==
50-
Vector{Float64}
51-
@test @constinferred(specify_default_type_parameters($Array, (eltype, ndims))) ==
52-
Vector{Float64}
46+
@test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64}
47+
@test @constinferred(specify_default_type_parameters($arrayt, eltype)) ==
48+
arrayt{Float64}
49+
@test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort
50+
@test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort
51+
@test @constinferred(specify_default_type_parameters($arrayt)) == vectort{Float64}
52+
@test @constinferred(specify_default_type_parameters($arrayt, 1)) == arrayt{Float64}
53+
@test @constinferred(specify_default_type_parameters($arrayt, eltype)) ==
54+
arrayt{Float64}
55+
@test @constinferred(specify_default_type_parameters($arrayt, 2)) == vectort
56+
@test @constinferred(specify_default_type_parameters($arrayt, ndims)) == vectort
57+
@test @constinferred(specify_default_type_parameters($arrayt, $((1, 2)))) ==
58+
vectort{Float64}
59+
@test @constinferred(specify_default_type_parameters($arrayt, (eltype, ndims))) ==
60+
vectort{Float64}
5361
end
5462

5563
@testset "On objects" begin
@@ -60,4 +68,33 @@ using TestExtras: @constinferred
6068
@test @constinferred(default_type_parameters(a, ndims)) == 1
6169
@test @constinferred(default_type_parameters(a)) == (Float64, 1)
6270
end
71+
72+
@testset "Automatic fallback for defaults" begin
73+
struct MyArray{B,A} <: AbstractArray{A,B} end
74+
@test @constinferred(default_type_parameters(MyArray)) === (1, Float64)
75+
@test @constinferred(default_type_parameters(MyArray{2,Float32})) === (1, Float64)
76+
@test @constinferred(default_type_parameters($MyArray, 1)) === 1
77+
@test @constinferred(default_type_parameters($MyArray, 2)) === Float64
78+
@test @constinferred(default_type_parameters(MyArray, eltype)) === Float64
79+
@test @constinferred(default_type_parameters(MyArray, ndims)) === 1
80+
81+
und = TypeParameterAccessors.UndefinedDefaultTypeParameter()
82+
83+
struct MyVector{X,Y,A<:Real} <: AbstractArray{A,1} end
84+
@test @constinferred(default_type_parameters(MyVector)) === (und, und, Float64)
85+
@test_throws ErrorException default_type_parameters(MyVector, 1)
86+
@test_throws ErrorException default_type_parameters(MyVector, 2)
87+
@test @constinferred(default_type_parameters($MyVector, 3)) === Float64
88+
@test @constinferred(default_type_parameters(MyVector, eltype)) === Float64
89+
@test_throws ErrorException default_type_parameters(MyVector, ndims)
90+
91+
struct MyBoolArray{X,Y,Z,B} <: AbstractArray{Bool,B} end
92+
@test @constinferred(default_type_parameters(MyBoolArray)) === (und, und, und, 1)
93+
@test_throws ErrorException default_type_parameters(MyBoolArray, 1)
94+
@test_throws ErrorException default_type_parameters(MyBoolArray, 2)
95+
@test_throws ErrorException default_type_parameters(MyBoolArray, 3)
96+
@test @constinferred(default_type_parameters($MyBoolArray, 4)) === 1
97+
@test_throws ErrorException default_type_parameters(MyBoolArray, eltype)
98+
@test @constinferred(default_type_parameters(MyBoolArray, ndims)) === 1
99+
end
63100
end

0 commit comments

Comments
 (0)