Skip to content

Commit d1af860

Browse files
committed
Move back cuarray files
1 parent 481dded commit d1af860

File tree

3 files changed

+37
-34
lines changed

3 files changed

+37
-34
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -101,23 +101,7 @@ function __init__()
101101

102102
@require CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
103103
@require Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
104-
fast_scalar_indexing(::Type{<:CuArrays.CuArray}) = false
105-
@inline allowed_getindex(x::CuArrays.CuArray, i...) = CuArrays.@allowscalar(x[i...])
106-
@inline function allowed_setindex!(x::CuArrays.CuArray, v, i...)
107-
(CuArrays.@allowscalar(x[i...] = v))
108-
end
109-
110-
function Base.setindex(x::CuArrays.CuArray, v, i::Int)
111-
_x = copy(x)
112-
allowed_setindex!(_x, v, i)
113-
return _x
114-
end
115-
116-
function restructure(x::CuArrays.CuArray, y)
117-
return reshape(Adapt.adapt(parameterless_type(x), y), Base.size(x)...)
118-
end
119-
120-
device(::Type{<:CuArrays.CuArray}) = GPU()
104+
include("cuarrays.jl")
121105
end
122106
@require DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" begin
123107
# actually do QR
@@ -129,30 +113,15 @@ function __init__()
129113

130114
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
131115
@require Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
132-
fast_scalar_indexing(::Type{<:CUDA.CuArray}) = false
133-
@inline allowed_getindex(x::CUDA.CuArray, i...) = CUDA.@allowscalar(x[i...])
134-
@inline allowed_setindex!(x::CUDA.CuArray, v, i...) = (CUDA.@allowscalar(x[i...] = v))
135-
136-
function Base.setindex(x::CUDA.CuArray, v, i::Int)
137-
_x = copy(x)
138-
allowed_setindex!(_x, v, i)
139-
return _x
140-
end
141-
142-
function restructure(x::CUDA.CuArray, y)
143-
return reshape(Adapt.adapt(parameterless_type(x), y), Base.size(x)...)
144-
end
145-
146-
device(::Type{<:CUDA.CuArray}) = GPU()
116+
include("cuarrays2.jl")
147117
end
148118
@require DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" begin
149119
# actually do QR
150-
function uu_instance(A::CUDA.CuMatrix{T}) where {T}
120+
function lu_instance(A::CUDA.CuMatrix{T}) where {T}
151121
return CUDA.CUSOLVER.CuQR(similar(A, 0, 0), similar(A, 0))
152122
end
153123
end
154124
end
155-
156125
@require BandedMatrices = "aae01518-5342-5314-be14-df237901396f" begin
157126
struct BandedMatrixIndex <: MatrixIndex
158127
count::Int

src/cuarrays.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
fast_scalar_indexing(::Type{<:CuArrays.CuArray}) = false
3+
@inline allowed_getindex(x::CuArrays.CuArray, i...) = CuArrays.@allowscalar(x[i...])
4+
@inline function allowed_setindex!(x::CuArrays.CuArray, v, i...)
5+
return (CuArrays.@allowscalar(x[i...] = v))
6+
end
7+
8+
function Base.setindex(x::CuArrays.CuArray, v, i::Int)
9+
_x = copy(x)
10+
allowed_setindex!(_x, v, i)
11+
return _x
12+
end
13+
14+
function restructure(x::CuArrays.CuArray, y)
15+
return reshape(Adapt.adapt(parameterless_type(x), y), Base.size(x)...)
16+
end
17+
18+
device(::Type{<:CuArrays.CuArray}) = GPU()

src/cuarrays2.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
fast_scalar_indexing(::Type{<:CUDA.CuArray}) = false
3+
@inline allowed_getindex(x::CUDA.CuArray, i...) = CUDA.@allowscalar(x[i...])
4+
@inline allowed_setindex!(x::CUDA.CuArray, v, i...) = (CUDA.@allowscalar(x[i...] = v))
5+
6+
function Base.setindex(x::CUDA.CuArray, v, i::Int)
7+
_x = copy(x)
8+
allowed_setindex!(_x, v, i)
9+
return _x
10+
end
11+
12+
function restructure(x::CUDA.CuArray, y)
13+
return reshape(Adapt.adapt(parameterless_type(x), y), Base.size(x)...)
14+
end
15+
16+
device(::Type{<:CUDA.CuArray}) = GPU()

0 commit comments

Comments
 (0)