Skip to content

Commit a147370

Browse files
Merge pull request #54 from SciML/cuda
allow CUDA
2 parents 79ff5bf + e98e220 commit a147370

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

src/ArrayInterface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,16 @@ function __init__()
551551
end
552552
end
553553

554+
@require CUDA="3a865a2d-5b23-5a0f-bc46-62713ec82fae" begin
555+
@require Adapt="79e6a3ab-5dfb-504d-930d-738a2a938a0e" begin
556+
include("cuarrays2.jl")
557+
end
558+
@require DiffEqBase="2b5f629d-d688-5b77-993f-72d75c75574e" begin
559+
# actually do QR
560+
lu_instance(A::CUDA.CuMatrix{T}) where T = CUDA.CUSOLVER.CuQR(similar(A, 0, 0), similar(A, 0))
561+
end
562+
end
563+
554564
@require BandedMatrices="aae01518-5342-5314-be14-df237901396f" begin
555565
function findstructralnz(x::BandedMatrices.BandedMatrix)
556566
l,u=BandedMatrices.bandwidths(x)

src/cuarrays2.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
fast_scalar_indexing(::Type{<:CUDA.CuArray}) = false
2+
@inline allowed_getindex(x::CUDA.CuArray,i...) = CUDA.@allowscalar(x[i...])
3+
@inline allowed_setindex!(x::CUDA.CuArray,v,i...) = (CUDA.@allowscalar(x[i...] = v))
4+
5+
function Base.setindex(x::CUDA.CuArray,v,i::Int)
6+
_x = copy(x)
7+
allowed_setindex!(_x,v,i)
8+
_x
9+
end
10+
11+
function restructure(x::CUDA.CuArray,y)
12+
reshape(Adapt.adapt(parameterless_type(x),y),size(x)...)
13+
end

0 commit comments

Comments
 (0)