Commit 29cce1c

[NDTensors] JLArrays Extension (#1508)

1 parent 1ef12d6
15 files changed: +201 -58 lines

NDTensors/Project.toml

Lines changed: 7 additions & 3 deletions
@@ -36,6 +36,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
+JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
 MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
@@ -47,6 +48,7 @@ NDTensorsAMDGPUExt = ["AMDGPU", "GPUArraysCore"]
 NDTensorsCUDAExt = ["CUDA", "GPUArraysCore"]
 NDTensorsGPUArraysCoreExt = "GPUArraysCore"
 NDTensorsHDF5Ext = "HDF5"
+NDTensorsJLArraysExt = ["GPUArraysCore", "JLArrays"]
 NDTensorsMappedArraysExt = ["MappedArrays"]
 NDTensorsMetalExt = ["GPUArraysCore", "Metal"]
 NDTensorsOctavianExt = "Octavian"
@@ -70,15 +72,16 @@ GPUArraysCore = "0.1"
 HDF5 = "0.14, 0.15, 0.16, 0.17"
 HalfIntegers = "1"
 InlineStrings = "1"
-LinearAlgebra = "1.6"
+JLArrays = "0.1"
+LinearAlgebra = "<0.0.1, 1.6"
 MacroTools = "0.5"
 MappedArrays = "0.4"
 Metal = "1"
 Octavian = "0.3"
 PackageExtensionCompat = "1"
-Random = "1.6"
+Random = "<0.0.1, 1.6"
 SimpleTraits = "0.9.4"
-SparseArrays = "1.6"
+SparseArrays = "<0.0.1, 1.6"
 SplitApplyCombine = "1.2.2"
 StaticArrays = "0.12, 1.0"
 Strided = "2"
@@ -95,6 +98,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
 HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
+JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
 TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
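
In `[extensions]`, the entry `NDTensorsJLArraysExt = ["GPUArraysCore", "JLArrays"]` declares that the new extension module loads automatically once both trigger packages are present in the environment. A minimal sketch (not part of the commit) of verifying this from a Julia >= 1.9 session:

using NDTensors
using GPUArraysCore, JLArrays

# Returns the extension module once both triggers are loaded, `nothing` otherwise.
@show Base.get_extension(NDTensors, :NDTensorsJLArraysExt)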

NDTensors/ext/NDTensorsAMDGPUExt/set_types.jl

Lines changed: 4 additions & 4 deletions
@@ -1,11 +1,11 @@
 # TypeParameterAccessors definitions
-using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
+using NDTensors.TypeParameterAccessors:
+  TypeParameterAccessors, Position, default_type_parameters
 using NDTensors.GPUArraysCoreExtensions: storagemode
 using AMDGPU: AMDGPU, ROCArray
 
 function TypeParameterAccessors.default_type_parameters(::Type{<:ROCArray})
-  return (Float64, 1, AMDGPU.Mem.HIPBuffer)
+  return (default_type_parameters(AbstractArray)..., AMDGPU.Mem.HIPBuffer)
 end
-TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(eltype)) = Position(1)
-TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(ndims)) = Position(2)
+
 TypeParameterAccessors.position(::Type{<:ROCArray}, ::typeof(storagemode)) = Position(3)

NDTensors/ext/NDTensorsCUDAExt/set_types.jl

Lines changed: 3 additions & 8 deletions
@@ -1,18 +1,13 @@
 # TypeParameterAccessors definitions
 using CUDA: CUDA, CuArray
-using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
+using NDTensors.TypeParameterAccessors:
+  TypeParameterAccessors, Position, default_type_parameters
 using NDTensors.GPUArraysCoreExtensions: storagemode
 
-function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(eltype))
-  return Position(1)
-end
-function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(ndims))
-  return Position(2)
-end
 function TypeParameterAccessors.position(::Type{<:CuArray}, ::typeof(storagemode))
   return Position(3)
 end
 
 function TypeParameterAccessors.default_type_parameters(::Type{<:CuArray})
-  return (Float64, 1, CUDA.Mem.DeviceBuffer)
+  return (default_type_parameters(AbstractArray)..., CUDA.Mem.DeviceBuffer)
 end
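
Both the AMDGPU and CUDA hunks above replace hard-coded default tuples with a splat of the generic `AbstractArray` defaults, keeping only the storage-mode parameter backend-specific. A hedged sketch of the intended equivalence, assuming (as the replaced literals suggest) that the `AbstractArray` defaults are `(Float64, 1)`:

using CUDA: CUDA, CuArray
using NDTensors.TypeParameterAccessors: default_type_parameters

default_type_parameters(AbstractArray)  # assumed: (Float64, 1)
default_type_parameters(CuArray)        # then: (Float64, 1, CUDA.Mem.DeviceBuffer)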

NDTensors/ext/NDTensorsJLArraysExt/NDTensorsJLArraysExt.jl

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+module NDTensorsJLArraysExt
+include("copyto.jl")
+include("indexing.jl")
+include("linearalgebra.jl")
+include("mul.jl")
+include("permutedims.jl")
+end
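
`JLArray` is the CPU-backed reference array type from JLArrays.jl that mimics GPU-array semantics, so this extension lets the GPU code paths in NDTensors be exercised without GPU hardware. For example, scalar indexing is disallowed by default, just as on a real device (a sketch, assuming JLArrays' exported `jl` converter, analogous to `CUDA.cu`):

using GPUArraysCore: @allowscalar
using JLArrays: jl

a = jl(randn(4, 4))   # CPU-backed, but with GPU-array semantics...
# a[1, 1]             # ...so this would throw a scalar-indexing error
@allowscalar a[1, 1]  # explicit opt-in, as the extension does internally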

NDTensors/ext/NDTensorsJLArraysExt/copyto.jl

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+using JLArrays: JLArray
+using NDTensors.Expose: Exposed, expose, unexpose
+using LinearAlgebra: Adjoint
+
+# Same definition as `CuArray`.
+function Base.copy(src::Exposed{<:JLArray,<:Base.ReshapedArray})
+  return reshape(copy(parent(src)), size(unexpose(src)))
+end
+
+function Base.copy(
+  src::Exposed{
+    <:JLArray,<:SubArray{<:Any,<:Any,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}}
+  },
+)
+  return copy(@view copy(expose(parent(src)))[parentindices(unexpose(src))...])
+end
+
+# Catches a bug in `copyto!` in the CUDA backend.
+function Base.copyto!(dest::Exposed{<:JLArray}, src::Exposed{<:JLArray,<:SubArray})
+  copyto!(dest, expose(copy(src)))
+  return unexpose(dest)
+end
+
+# Catches a bug in `copyto!` in the JLArray backend.
+function Base.copyto!(
+  dest::Exposed{<:JLArray}, src::Exposed{<:JLArray,<:Base.ReshapedArray}
+)
+  copyto!(dest, expose(parent(src)))
+  return unexpose(dest)
+end
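
These methods reroute reshaped and viewed sources through contiguous copies before the final `copyto!`, sidestepping the backend bugs noted in the comments. A sketch of the kind of call they serve, assuming the extension is loaded:

using JLArrays: jl
using NDTensors.Expose: expose

src = jl(randn(4, 4))
dest = jl(zeros(2, 4))
# The SubArray source is materialized with `copy` before being copied into `dest`.
copyto!(expose(dest), expose(view(src, 1:2, :)))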

NDTensors/ext/NDTensorsJLArraysExt/indexing.jl

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+using JLArrays: JLArray
+using GPUArraysCore: @allowscalar
+using NDTensors: NDTensors
+using NDTensors.Expose: Exposed, expose, unexpose
+
+function Base.getindex(E::Exposed{<:JLArray})
+  return @allowscalar unexpose(E)[]
+end
+
+function Base.setindex!(E::Exposed{<:JLArray}, x::Number)
+  @allowscalar unexpose(E)[] = x
+  return unexpose(E)
+end
+
+function Base.getindex(E::Exposed{<:JLArray,<:Adjoint}, i, j)
+  return (expose(parent(E))[j, i])'
+end
+
+Base.any(f, E::Exposed{<:JLArray,<:NDTensors.Tensor}) = any(f, data(unexpose(E)))
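
Scalar reads and writes on a `JLArray` must be wrapped in `@allowscalar`, so the zero-argument `getindex`/`setindex!` overloads opt in explicitly. (`Adjoint` resolves here because every file is `include`d into the single `NDTensorsJLArraysExt` module, and `copyto.jl` imports it from `LinearAlgebra`.) A sketch of the scalar path:

using JLArrays: jl
using NDTensors.Expose: expose

v = jl([1.0])
x = expose(v)[]    # scalar read, routed through @allowscalar
expose(v)[] = 2.0  # scalar write, likewise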

NDTensors/ext/NDTensorsJLArraysExt/linearalgebra.jl

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+using Adapt: adapt
+using JLArrays: JLArray, JLMatrix
+using LinearAlgebra: LinearAlgebra, Hermitian, Symmetric, qr, eigen
+using NDTensors: NDTensors
+using NDTensors.Expose: Expose, expose, qr, qr_positive, ql, ql_positive
+using NDTensors.GPUArraysCoreExtensions: cpu
+using NDTensors.TypeParameterAccessors: unwrap_array_type
+
+## TODO: This function exists because of the same issue below. When
+## that issue is resolved, we can rely on the `AbstractArray` version of
+## this operation.
+function Expose.qr(A::Exposed{<:JLArray})
+  Q, L = qr(unexpose(A))
+  return adapt(unwrap_array_type(A), Matrix(Q)), adapt(unwrap_array_type(A), L)
+end
+## TODO: This should work using a `JLArray`, but there is an error converting the `Q` from its packed QR form
+## back into a `JLArray`; see https://github.com/JuliaGPU/GPUArrays.jl/issues/545. To fix, call `cpu` for now.
+function Expose.qr_positive(A::Exposed{<:JLArray})
+  Q, L = qr_positive(expose(cpu(A)))
+  return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
+end
+
+function Expose.ql(A::Exposed{<:JLMatrix})
+  Q, L = ql(expose(cpu(A)))
+  return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
+end
+function Expose.ql_positive(A::Exposed{<:JLMatrix})
+  Q, L = ql_positive(expose(cpu(A)))
+  return adapt(unwrap_array_type(A), copy(Q)), adapt(unwrap_array_type(A), L)
+end
+
+function LinearAlgebra.eigen(A::Exposed{<:JLMatrix,<:Symmetric})
+  q, l = eigen(expose(cpu(A)))
+  return adapt.(unwrap_array_type(A), (q, l))
+end
+
+function LinearAlgebra.eigen(A::Exposed{<:JLMatrix,<:Hermitian})
+  q, l = eigen(expose(Hermitian(cpu(unexpose(A).data))))
+  return adapt.(JLArray, (q, l))
+end
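
Because of the packed-`Q` conversion bug tracked in GPUArrays.jl#545, the factorizations round-trip through the CPU and `adapt` the results back to the original array type. A standalone sketch of that pattern (plain LinearAlgebra calls, not the `Expose` wrappers above):

using Adapt: adapt
using JLArrays: JLArray, jl
using LinearAlgebra: qr

A = jl(randn(4, 3))
F = qr(Array(A))                 # factorize on the CPU
Q = adapt(JLArray, Matrix(F.Q))  # densify Q, then move back to the JLArray backend
R = adapt(JLArray, F.R)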

NDTensors/ext/NDTensorsJLArraysExt/mul.jl

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+using JLArrays: JLArray
+using LinearAlgebra: LinearAlgebra, mul!, transpose
+using NDTensors.Expose: Exposed, expose, unexpose
+
+function LinearAlgebra.mul!(
+  CM::Exposed{<:JLArray,<:LinearAlgebra.Transpose},
+  AM::Exposed{<:JLArray},
+  BM::Exposed{<:JLArray},
+  α,
+  β,
+)
+  mul!(transpose(CM), transpose(BM), transpose(AM), α, β)
+  return unexpose(CM)
+end
+
+function LinearAlgebra.mul!(
+  CM::Exposed{<:JLArray,<:LinearAlgebra.Adjoint},
+  AM::Exposed{<:JLArray},
+  BM::Exposed{<:JLArray},
+  α,
+  β,
+)
+  mul!(CM', BM', AM', α, β)
+  return unexpose(CM)
+end
+
+## Fix an issue in JLArrays.jl where it cannot recognize `Transpose{Reshape{Adjoint{JLArray}}}`
+## as a `JLArray` and calls generic matmul.
+function LinearAlgebra.mul!(
+  CM::Exposed{<:JLArray},
+  AM::Exposed{<:JLArray},
+  BM::Exposed{
+    <:JLArray,
+    <:LinearAlgebra.Transpose{
+      <:Any,<:Base.ReshapedArray{<:Any,<:Any,<:LinearAlgebra.Adjoint}
+    },
+  },
+  α,
+  β,
+)
+  mul!(CM, AM, expose(transpose(copy(expose(parent(BM))))), α, β)
+  return unexpose(CM)
+end
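
The first two methods re-associate the product so the destination becomes a plain `JLArray`: since (A·B)ᵀ = Bᵀ·Aᵀ, writing into a transposed destination can be done by transposing both operands and swapping them (and likewise for adjoints). A sketch of a call that lands on the `Transpose`-destination method, assuming the extension is loaded:

using JLArrays: jl
using LinearAlgebra: mul!, transpose
using NDTensors.Expose: expose

A, B = jl(randn(3, 4)), jl(randn(4, 5))
C = jl(zeros(5, 3))  # C will hold (A * B)ᵀ
mul!(expose(transpose(C)), expose(A), expose(B), true, false)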

NDTensors/ext/NDTensorsJLArraysExt/permutedims.jl

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+using JLArrays: JLArray
+using LinearAlgebra: Adjoint
+using NDTensors.Expose: Exposed, expose, unexpose
+
+function Base.permutedims!(
+  Edest::Exposed{<:JLArray,<:Base.ReshapedArray}, Esrc::Exposed{<:JLArray}, perm
+)
+  Aperm = permutedims(Esrc, perm)
+  copyto!(expose(parent(Edest)), expose(Aperm))
+  return unexpose(Edest)
+end
+
+## Found an issue in CUDA where, if `Edest` is a reshaped `Adjoint`,
+## `.=` can fail. So instead force `Esrc` into the shape of `parent(Edest)`.
+function Base.permutedims!(
+  Edest::Exposed{<:JLArray,<:Base.ReshapedArray{<:Any,<:Any,<:Adjoint}},
+  Esrc::Exposed{<:JLArray},
+  perm,
+  f,
+)
+  Aperm = reshape(permutedims(Esrc, perm), size(parent(Edest)))
+  parent(Edest) .= f.(parent(Edest), Aperm)
+  return unexpose(Edest)
+end
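
The four-argument method fuses the permutation with an elementwise update `f`, reshaping the permuted source to the parent's shape so the broadcast never touches the fragile reshaped-`Adjoint` destination. A sketch, assuming the extension is loaded:

using JLArrays: jl
using NDTensors.Expose: expose

a = jl(randn(6, 1))
dest = reshape(a', 3, 2)  # a Base.ReshapedArray wrapping an Adjoint
src = jl(randn(2, 3))
# Equivalent to dest .= dest .+ permutedims(src, (2, 1)), applied via the parent array.
permutedims!(expose(dest), expose(src), (2, 1), +)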

NDTensors/ext/NDTensorsMetalExt/set_types.jl

Lines changed: 0 additions & 7 deletions
@@ -4,13 +4,6 @@ using Metal: Metal, MtlArray
 using NDTensors.TypeParameterAccessors: TypeParameterAccessors, Position
 using NDTensors.GPUArraysCoreExtensions: storagemode
 
-## TODO remove TypeParameterAccessors when SetParameters is removed
-function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(eltype))
-  return Position(1)
-end
-function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(ndims))
-  return Position(2)
-end
 function TypeParameterAccessors.position(::Type{<:MtlArray}, ::typeof(storagemode))
   return Position(3)
 end
