Added support for CUDA #228

Open
albertomercurio wants to merge 2 commits into JuliaNLSolvers:master from albertomercurio:master
@albertomercurio

Hello,

I added support for CUDA GPUs, as requested in issue #218.

The following code should now work:

using CUDA
CUDA.allowscalar(false)
using LsqFit

function model(x,p)
    a, b = Array(p)
    @. a*exp(-x*b)
end

xdata = collect(range(0, stop=10, length=20))
ydata = model(xdata, [1.0, 2.0]) + 0.01*randn(length(xdata))
p0 = [0.5, 0.5]

xdata = CuArray(xdata)
ydata = CuArray(ydata)
p0 = CuArray(p0)

fit = curve_fit(model, xdata, ydata, p0)
fit.param


2-element CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}:
 1.0155557857704927
 1.9932445695606529

The main bottleneck is that I need to call Array(p) inside the model function: most of the time I need to unpack the individual parameters, and scalar indexing is not allowed for CUDA arrays. So the parameters have to be transferred from the GPU to the CPU.
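To illustrate the restriction, here is a minimal sketch (not part of the PR's changes) of why the host copy is needed. It assumes a working CUDA device. The second variant, which slices the parameters with one-element ranges so that everything stays on the GPU, is only an untested idea: whether it plays well with the Jacobian computation in LsqFit is an assumption.

```julia
using CUDA
CUDA.allowscalar(false)

p = CuArray([1.0, 2.0])
x = CuArray(collect(range(0, stop=10, length=20)))

# With allowscalar(false), scalar access such as p[1] or `a, b = p`
# raises an error, because each element read would be a separate
# device-to-host transfer.

# Variant 1 (as in the example above): copy the parameters to the host once.
a, b = Array(p)
y_host_params = @. a * exp(-x * b)   # broadcast still runs on the GPU

# Variant 2 (assumption, not tested with curve_fit): range indexing returns
# one-element CuArrays, which broadcast against x without leaving the GPU.
y_gpu_params = p[1:1] .* exp.(-x .* p[2:2])

# Both variants should agree up to floating-point rounding.
Array(y_host_params) ≈ Array(y_gpu_params)
```

Since p only has a handful of elements, the host copy in variant 1 is cheap compared to evaluating the model on a large xdata.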

Another bottleneck is that the Jacobian computation calls ArrayInterfaceCore.allowed_getindex, which again transfers the element to the CPU.

However, it works! It would be good if someone could test this feature in the future. I expect it to pay off for very large datasets, such as fits on 2D data.

@albertomercurio
Author

Is anyone available to review this PR? I made only a few changes.

@pkofod
Member

pkofod commented Jun 12, 2023

Is ArrayInterfacesCUDA necessary for diagind? You seem to have committed a Manifest, and I would also ask you to add some tests of the functionality. Thank you for the contribution, even if it took me way too long to notice.

@singularitti
Contributor

May I help to make this work?
