Skip to content

Commit 24b3add

Browse files
authored
GPU arrays (#29)
1 parent ca6b2c0 commit 24b3add

8 files changed

+136
-59
lines changed

Project.toml

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,38 @@
11
name = "TypeParameterAccessors"
22
uuid = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.0"
4+
version = "0.3.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
99

1010
[weakdeps]
11-
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
11+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
12+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1213
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
14+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
15+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
16+
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
17+
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
18+
19+
[extensions]
20+
TypeParameterAccessorsAMDGPUExt = "AMDGPU"
21+
TypeParameterAccessorsCUDAExt = "CUDA"
22+
TypeParameterAccessorsFillArraysExt = "FillArrays"
23+
TypeParameterAccessorsJLArraysExt = "JLArrays"
24+
TypeParameterAccessorsMetalExt = "Metal"
25+
TypeParameterAccessorsStridedViewsExt = "StridedViews"
26+
TypeParameterAccessorsoneAPIExt = "oneAPI"
1327

1428
[compat]
29+
AMDGPU = "0, 1"
30+
CUDA = "3, 4, 5"
1531
FillArrays = "1.13"
32+
JLArrays = "0.1, 0.2"
1633
LinearAlgebra = "1.10"
34+
Metal = "0, 1"
1735
SimpleTraits = "0.9.4"
1836
StridedViews = "0.3.2"
1937
julia = "1.10"
20-
21-
[extensions]
22-
TypeParameterAccessorsStridedViewsExt = "StridedViews"
23-
TypeParameterAccessorsFillArraysExt = "FillArrays"
38+
oneAPI = "0, 1, 2"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TypeParameterAccessorsAMDGPUExt
2+
3+
using AMDGPU: ROCArray
4+
using TypeParameterAccessors: TypeParameterAccessors, Position
5+
6+
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1)
7+
TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2)
8+
9+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TypeParameterAccessorsCUDAExt
2+
3+
using CUDA: CuArray
4+
using TypeParameterAccessors: TypeParameterAccessors, Position
5+
6+
TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype)) = Position(1)
7+
TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims)) = Position(2)
8+
9+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TypeParameterAccessorsJLArraysExt
2+
3+
using JLArrays: JLArray
4+
using TypeParameterAccessors: TypeParameterAccessors, Position
5+
6+
TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(eltype)) = Position(1)
7+
TypeParameterAccessors.position(::Type{<:JLArray}, ::typeof(ndims)) = Position(2)
8+
9+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TypeParameterAccessorsMetalExt
2+
3+
using Metal: MtlArray
4+
using TypeParameterAccessors: TypeParameterAccessors, Position
5+
6+
TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype)) = Position(1)
7+
TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims)) = Position(2)
8+
9+
end
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module TypeParameterAccessorsoneAPIExt
2+
3+
using oneAPI: oneArray
4+
using TypeParameterAccessors: TypeParameterAccessors, Position
5+
6+
TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(eltype)) = Position(1)
7+
TypeParameterAccessors.position(::Type{<:oneArray}, ::typeof(ndims)) = Position(2)
8+
9+
end

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
67
StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
@@ -12,8 +13,8 @@ TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1213
[compat]
1314
Aqua = "0.8.9"
1415
FillArrays = "1.13"
15-
StridedViews = "0.3"
1616
SafeTestsets = "0.1"
17+
StridedViews = "0.3"
1718
Suppressor = "0.2"
1819
Test = "1.10"
1920
TestExtras = "0.3"

test/test_basics.jl

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,80 @@
1+
using JLArrays: JLArray, JLMatrix, JLVector
12
using Test: @test, @test_throws, @test_broken, @testset
23
using TestExtras: @constinferred
34
using TypeParameterAccessors:
45
set_type_parameters, specify_type_parameters, type_parameters, unspecify_type_parameters
56

6-
@testset "Get parameters" begin
7-
@test @constinferred(type_parameters($(AbstractArray{Float64}), 1)) == Float64
8-
@test @constinferred(type_parameters($(AbstractArray{Float64}), eltype)) == Float64
9-
@test @constinferred(type_parameters($(AbstractMatrix{Float64}), ndims)) == 2
7+
const anyarrayts = (
8+
(arrayt=Array, matrixt=Matrix, vectort=Vector),
9+
(arrayt=JLArray, matrixt=JLMatrix, vectort=JLVector),
10+
)
11+
@testset "basics (arrayts=$anyarrayt)" for anyarrayt in anyarrayts
12+
(; arrayt, matrixt, vectort) = anyarrayt
1013

11-
@test @constinferred(type_parameters($(Array{Float64}), 1)) == Float64
12-
@test @constinferred(type_parameters($(Val{3}))) == (3,)
14+
@testset "Get parameters" begin
15+
@test @constinferred(type_parameters($(AbstractArray{Float64}), 1)) == Float64
16+
@test @constinferred(type_parameters($(AbstractArray{Float64}), eltype)) == Float64
17+
@test @constinferred(type_parameters($(AbstractMatrix{Float64}), ndims)) == 2
1318

14-
# @test_throws ErrorException type_parameter(Array, 1)
15-
@test @constinferred(type_parameters($(Array{Float64}), eltype)) == Float64
16-
@test @constinferred(type_parameters($(Matrix{Float64}), ndims)) == 2
17-
@test @constinferred(type_parameters($(Matrix{Float64}), (ndims, eltype))) == (2, Float64)
18-
# TODO: Not inferrable without interpolating positions:
19-
# https://github.com/ITensor/TypeParameterAccessors.jl/issues/21.
20-
@test @constinferred(type_parameters($(Matrix{Float64}), $((2, eltype)))) == (2, Float64)
21-
@test @constinferred(type_parameters($(Matrix{Float64}), (ndims, eltype))) == (2, Float64)
22-
# @test_throws ErrorException type_parameters(Array{Float64}, ndims) == 2
23-
@test @constinferred(broadcast($type_parameters, $(Matrix{Float64}), $((2, eltype)))) ==
24-
(2, Float64)
25-
end
19+
@test @constinferred(type_parameters($(arrayt{Float64}), 1)) == Float64
20+
@test @constinferred(type_parameters($(Val{3}))) == (3,)
2621

27-
@testset "Set parameters" begin
28-
@test @constinferred(set_type_parameters($Array, 1, $Float64)) == Array{Float64}
29-
@test @constinferred(set_type_parameters($Array, 2, 2)) == Matrix
30-
@test @constinferred(set_type_parameters($Array, $eltype, $Float32)) == Array{Float32}
31-
@test @constinferred(set_type_parameters($Array, $((eltype, 2)), $((Float32, 3)))) ==
32-
Array{Float32,3}
33-
end
22+
# @test_throws ErrorException type_parameter(arrayt, 1)
23+
@test @constinferred(type_parameters($(arrayt{Float64}), eltype)) == Float64
24+
@test @constinferred(type_parameters($(matrixt{Float64}), ndims)) == 2
25+
@test @constinferred(type_parameters($(matrixt{Float64}), (ndims, eltype))) ==
26+
(2, Float64)
27+
# TODO: Not inferrable without interpolating positions:
28+
# https://github.com/ITensor/TypeParameterAccessors.jl/issues/21.
29+
@test @constinferred(type_parameters($(matrixt{Float64}), $((2, eltype)))) ==
30+
(2, Float64)
31+
@test @constinferred(type_parameters($(matrixt{Float64}), (ndims, eltype))) ==
32+
(2, Float64)
33+
# @test_throws ErrorException type_parameters(arrayt{Float64}, ndims) == 2
34+
@test @constinferred(
35+
broadcast($type_parameters, $(matrixt{Float64}), $((2, eltype)))
36+
) == (2, Float64)
37+
end
3438

35-
@testset "Specify parameters" begin
36-
@test @constinferred(specify_type_parameters($Array, 1, $Float64)) == Array{Float64}
37-
@test @constinferred(specify_type_parameters($Matrix, $((2, 1)), $((4, Float32)))) ==
38-
Matrix{Float32}
39-
@test @constinferred(specify_type_parameters($Array, $((Float64, 2)))) == Matrix{Float64}
40-
@test @constinferred(specify_type_parameters($Array, $eltype, $Float32)) == Array{Float32}
41-
@test @constinferred(specify_type_parameters($Array, $((eltype, 2)), $((Float32, 3)))) ==
42-
Array{Float32,3}
43-
end
39+
@testset "Set parameters" begin
40+
@test @constinferred(set_type_parameters($arrayt, 1, $Float64)) == arrayt{Float64}
41+
@test @constinferred(set_type_parameters($arrayt, 2, 2)) == matrixt
42+
@test @constinferred(set_type_parameters($arrayt, $eltype, $Float32)) == arrayt{Float32}
43+
@test @constinferred(set_type_parameters($arrayt, $((eltype, 2)), $((Float32, 3)))) ==
44+
arrayt{Float32,3}
45+
end
4446

45-
@testset "Unspecify parameters" begin
46-
@test @constinferred(unspecify_type_parameters($Vector, 2)) == Array
47-
@test @constinferred(unspecify_type_parameters($(Vector{Float64}), eltype)) == Vector
48-
@test @constinferred(unspecify_type_parameters($(Vector{Float64}))) == Array
49-
@test @constinferred(unspecify_type_parameters($(Vector{Float64}), $((eltype, 2)))) ==
50-
Array
51-
end
47+
@testset "Specify parameters" begin
48+
@test @constinferred(specify_type_parameters($arrayt, 1, $Float64)) == arrayt{Float64}
49+
@test @constinferred(specify_type_parameters($matrixt, $((2, 1)), $((4, Float32)))) ==
50+
matrixt{Float32}
51+
@test @constinferred(specify_type_parameters($arrayt, $((Float64, 2)))) ==
52+
matrixt{Float64}
53+
@test @constinferred(specify_type_parameters($arrayt, $eltype, $Float32)) ==
54+
arrayt{Float32}
55+
@test @constinferred(
56+
specify_type_parameters($arrayt, $((eltype, 2)), $((Float32, 3)))
57+
) == arrayt{Float32,3}
58+
end
59+
60+
@testset "Unspecify parameters" begin
61+
@test @constinferred(unspecify_type_parameters($vectort, 2)) == arrayt
62+
@test @constinferred(unspecify_type_parameters($(vectort{Float64}), eltype)) == vectort
63+
@test @constinferred(unspecify_type_parameters($(vectort{Float64}))) == arrayt
64+
@test @constinferred(unspecify_type_parameters($(vectort{Float64}), $((eltype, 2)))) ==
65+
arrayt
66+
end
5267

53-
@testset "On objects" begin
54-
@test @constinferred(type_parameters($(Val{3}()))) == (3,)
55-
@test @constinferred(type_parameters($(Val{Float32}()))) == (Float32,)
56-
a = randn(Float32, (2, 2, 2))
57-
@test @constinferred(type_parameters(a, 1)) == Float32
58-
@test @constinferred(type_parameters(a, eltype)) == Float32
59-
@test @constinferred(type_parameters(a, 2)) == 3
60-
@test @constinferred(type_parameters(a, ndims)) == 3
61-
@test @constinferred(type_parameters(a)) == (Float32, 3)
62-
@test @constinferred(broadcast($type_parameters, $(Ref(a)), $((2, eltype)))) ==
63-
(3, Float32)
68+
@testset "On objects" begin
69+
@test @constinferred(type_parameters($(Val{3}()))) == (3,)
70+
@test @constinferred(type_parameters($(Val{Float32}()))) == (Float32,)
71+
a = arrayt(randn(Float32, (2, 2, 2)))
72+
@test @constinferred(type_parameters(a, 1)) == Float32
73+
@test @constinferred(type_parameters(a, eltype)) == Float32
74+
@test @constinferred(type_parameters(a, 2)) == 3
75+
@test @constinferred(type_parameters(a, ndims)) == 3
76+
@test @constinferred(type_parameters(a)) == (Float32, 3)
77+
@test @constinferred(broadcast($type_parameters, $(Ref(a)), $((2, eltype)))) ==
78+
(3, Float32)
79+
end
6480
end

0 commit comments

Comments
 (0)