Skip to content

Commit ee294bc

Browse files
authored
Activate GPU device before resizing arrays (#91)
1 parent 1db7d91 commit ee294bc

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
55

66
## Unreleased
77

8+
### Fixed
9+
10+
- Fix multi-GPU issue on AMDGPU.
11+
On AMDGPU (not sure about other backends), when an array on device X is `resize!`d while
12+
device Y is activated, the array is then silently "transferred" to device Y. For this
13+
reason, we now resize arrays on the same device where they were initially created. This
14+
might be a bug in AMDGPU.jl.
15+
816
## [0.32.16] - 2026-01-28
917

1018
### Changed

src/BiotSavart/BiotSavart.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ function do_longrange!(
438438
# Copy point data to the cache (possibly on a GPU).
439439
@assert pointdata_cpu !== pointdata # they are different objects
440440
GC.@preserve pointdata begin # see docs for KA.copyto! (it shouldn't really be needed here)
441+
# Resize output arrays
442+
noutputs = length(pointdata_cpu.nodes)
443+
foreach(v -> resize_no_copy!(v, noutputs), outputs)
444+
441445
@timeit to "Copy point charges (host -> device)" begin
442446
# Only copy fields needed for long-range computations
443447
copy_host_to_device!(pointdata.nodes, pointdata_cpu.nodes)
@@ -491,6 +495,10 @@ function do_shortrange!(cache::ShortRangeCache, outputs::NamedTuple, pointdata_c
491495

492496
@timeit to "Short-range component (async)" begin
493497
GC.@preserve pointdata begin # see docs for KA.copyto! (it shouldn't really be needed here)
498+
# Resize output arrays
499+
noutputs = length(pointdata_cpu.nodes)
500+
foreach(v -> resize_no_copy!(v, noutputs), outputs) # resize output arrays
501+
494502
if LIA === Val(true) || LIA === Val(:only)
495503
@timeit to "Copy point charges (host -> device)" begin
496504
# For now, only copy what we need for local term
@@ -543,15 +551,13 @@ function _compute_on_nodes!(
543551
# TODO: skip unneeded quantities if LIA === Val(:only) or LIA === Val(false)
544552
@timeit to "Add point charges" add_point_charges!(pointdata, fs, params) # done on the CPU
545553

546-
noutputs = sum(length, fs) # total number of interpolation points
547554
channel = Channel{Symbol}(2) # 2 is the length of the channel (for :shortrange + :longrange)
548555
tasks = Task[]
549556

550557
if with_longrange
551558
let cache = cache.longrange
552559
# Select elements of outputs with the same names as in `fields` (in this case :velocity and/or :streamfunction).
553560
local outputs = NamedTuple{keys(fields)}(cache.outputs)
554-
foreach(v -> resize_no_copy!(v, noutputs), outputs) # resize output arrays
555561
# Compute long-range part asynchronously (e.g. on a GPU).
556562
local task = Threads.@spawn :interactive try
557563
do_longrange!(cache, outputs, pointdata; callback_vorticity)
@@ -567,7 +573,6 @@ function _compute_on_nodes!(
567573
let cache = cache.shortrange
568574
# Select elements of outputs with the same names as in `fields` (in this case :velocity and/or :streamfunction).
569575
local outputs = NamedTuple{keys(fields)}(cache.outputs)
570-
foreach(v -> resize_no_copy!(v, noutputs), outputs) # resize output arrays
571576
# Compute short-range part asynchronously (e.g. on a GPU).
572577
local task = Threads.@spawn :interactive try
573578
do_shortrange!(cache, outputs, pointdata; LIA)

src/BiotSavart/shortrange/cache_common.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ end
1313
function ShortRangeCacheCommon(params::ParamsShortRange, pointdata_in::PointData)
1414
(; backend,) = params
1515
ka_backend = KA.get_backend(backend) # CPU, CUDABackend, ROCBackend, ...
16+
# Make sure we've activated the device (e.g. GPU id) where short-range computations will
17+
# be performed. We need arrays to be allocated in that device.
18+
expected_device = KA.device(backend) # 1, 2, ...
19+
@assert KA.device(ka_backend) == expected_device
1620
pointdata = adapt(ka_backend, pointdata_in) # create PointData replica on the device if needed
1721
if pointdata === pointdata_in # basically if ka_backend isa CPU
1822
pointdata = copy(pointdata_in) # make sure pointdata and pointdata_in are not aliased!

0 commit comments

Comments
 (0)