Draft
Changes from all commits (28 commits)
c8133e4
Add array namespace option for field buffers
kburns May 27, 2025
9d45669
Add array-api-compat to setup.py
kburns May 27, 2025
6db9593
Allow specifying array namespace by string
kburns May 27, 2025
a39b345
Try fixing cupy allocation from buffer
kburns May 27, 2025
dd1f0f7
Fix cupy check
kburns May 27, 2025
62ee03b
Add cupy-based complex fourier MMT
kburns May 27, 2025
68bbd21
Fix transform lookup
kburns May 27, 2025
dce5d99
Make fill_random array and dtype compatible
kburns May 27, 2025
e189f41
Work on cupy real fourier MMTs
kburns May 27, 2025
cf8644d
Generalize Fourier basis for more dtypes
kburns May 27, 2025
2fb0d32
Add cupy complex FFT
kburns May 28, 2025
4fbe35b
Add cupy real fft
kburns May 28, 2025
8c7985d
Fix dtype conversion
kburns May 28, 2025
d6a4525
Add array compat for basic arithmetic
kburns May 28, 2025
6d05ff0
Beginning adding array_compat to operators
kburns May 28, 2025
79d789c
Quick implementation of apply_sparse for cupy
kburns Jul 22, 2025
fb9b3d6
Make einsum in dot compatible with cupy
kburns Jul 22, 2025
1e29a80
Add custom kernel for cupy csr middle dot product
kburns Jul 22, 2025
d240656
Convert local grids/modes to device arrays
kburns Jul 22, 2025
644f3bf
Explicitly cast data norms to float
kburns Jul 22, 2025
ef9091b
Cast grid spacing to device array in cartesian cfl
kburns Jul 22, 2025
426cad7
Convert field data gathers to numpy on gpu
kburns Jul 22, 2025
c9f5bda
Fix subsystem gather/scatter to copy to/from gpu
kburns Jul 22, 2025
63f4033
Allow for non-contiguous device copy
kburns Jul 22, 2025
68e2cb2
Fix cupy csr kernel for double instead of float
kburns Jul 22, 2025
9421231
Move subsystems, coeff systems, and matrices to GPU
kburns Jul 25, 2025
15a2d6e
Build custom cupy superlu wrapper to reuse spsm descriptors
kburns Jul 25, 2025
2e674b7
Move all operator matrices to device. Add Chebyshev transforms
kburns Jul 25, 2025
17 changes: 13 additions & 4 deletions dedalus/core/arithmetic.py
@@ -13,6 +13,7 @@
import numexpr as ne
from collections import defaultdict
from math import prod
import array_api_compat

from .domain import Domain
from .field import Operand, Field
@@ -245,10 +246,11 @@ def choose_layout(self):

def operate(self, out):
"""Perform operation."""
xp = self.array_namespace
arg0, arg1 = self.args
# Set output layout
out.preset_layout(arg0.layout)
np.add(arg0.data, arg1.data, out=out.data)
xp.add(arg0.data, arg1.data, out=out.data)


# used for einsum string manipulation
@@ -664,14 +666,19 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig):
return G

def operate(self, out):
xp = self.array_namespace
arg0, arg1 = self.args
out.preset_layout(arg0.layout)
# Broadcast
arg0_data = self.arg0_ghost_broadcaster.cast(arg0)
arg1_data = self.arg1_ghost_broadcaster.cast(arg1)
# Call einsum
if out.data.size:
np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True)
if array_api_compat.is_cupy_namespace(xp):
# Cupy does not support output keyword
out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True)
else:
xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True)


@alias("cross")
@@ -854,6 +861,7 @@ def __init__(self, arg0, arg1, out=None, **kw):

def operate(self, out):
"""Perform operation."""
xp = self.array_namespace
arg0, arg1 = self.args
# Set output layout
out.preset_layout(arg0.layout)
@@ -863,7 +871,7 @@ def operate(self, out):
# Reshape arg data to broadcast properly for output tensorsig
arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):])
arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):])
np.multiply(arg0_exp_data, arg1_exp_data, out=out.data)
xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data)


class GhostBroadcaster:
@@ -939,11 +947,12 @@ def enforce_conditions(self):

def operate(self, out):
"""Perform operation."""
xp = self.array_namespace
arg0, arg1 = self.args
# Set output layout
out.preset_layout(arg1.layout)
# Multiply argument data
np.multiply(arg0, arg1.data, out=out.data)
xp.multiply(arg0, arg1.data, out=out.data)

def matrix_dependence(self, *vars):
return self.args[1].matrix_dependence(*vars)
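
The einsum branch in the dot-product `operate` above is the recurring pattern this PR uses wherever CuPy lacks NumPy's `out=` keyword; a minimal standalone sketch of the same dispatch (the `einsum_into` helper name is hypothetical, not part of the PR):

# --- example sketch, not part of the diff ---
import array_api_compat

def einsum_into(subscripts, *operands, out):
    """Einsum into a preallocated output, dispatched on array namespace.

    Hypothetical helper mirroring the operate() method above: CuPy's einsum
    lacks the `out=` keyword, so the GPU path allocates and copies instead.
    """
    xp = array_api_compat.array_namespace(*operands)
    if array_api_compat.is_cupy_namespace(xp):
        out[:] = xp.einsum(subscripts, *operands, optimize=True)
    else:
        xp.einsum(subscripts, *operands, out=out, optimize=True)

The NumPy path writes in place; the CuPy path pays for one temporary, keeping behavior identical at the cost of a device-side copy.
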
23 changes: 16 additions & 7 deletions dedalus/core/basis.py
@@ -14,7 +14,7 @@
from ..tools import clenshaw
from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices
from ..tools.dispatch import MultiClass, SkipDispatchException
from ..tools.general import unify, DeferredTuple
from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype
from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct
from .domain import Domain
from .field import Operand, LockedField
@@ -506,7 +506,13 @@ def _native_grid(self, scale):
@CachedMethod
def transform_plan(self, dist, grid_size):
"""Build transform plan."""
return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0)
xp = dist.array_namespace
xp_name = xp.__name__.split('.')[-1]
# Shortcut trivial transforms
if grid_size == 1 or self.size == 1:
return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype)
else:
return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype)

# def weights(self, scales):
# """Gauss-Jacobi weights."""
@@ -915,11 +921,13 @@ def _native_grid(self, scale):
@CachedMethod
def transform_plan(self, dist, grid_size):
"""Build transform plan."""
xp = dist.array_namespace
xp_name = xp.__name__.split('.')[-1]
# Shortcut trivial transforms
if grid_size == 1 or self.size == 1:
return self.transforms['matrix'](grid_size, self.size)
return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype)
else:
return self.transforms[self.library](grid_size, self.size)
return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype)

def forward_transform(self, field, axis, gdata, cdata):
# Transform
@@ -940,9 +948,9 @@ def Fourier(*args, dtype=None, **kw):
"""Factory function dispatching to RealFourier and ComplexFourier based on provided dtype."""
if dtype is None:
raise ValueError("dtype must be specified")
elif dtype == np.float64:
elif is_real_dtype(dtype):
return RealFourier(*args, **kw)
elif dtype == np.complex128:
elif is_complex_dtype(dtype):
return ComplexFourier(*args, **kw)
else:
raise ValueError(f"Unrecognized dtype: {dtype}")
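
The `is_real_dtype` / `is_complex_dtype` helpers are imported from `dedalus.tools.general` and their bodies are not shown in this diff; a plausible sketch of their behavior (an assumption, the actual implementation may differ):

# --- example sketch, not part of the diff ---
import numpy as np

def is_real_dtype(dtype):
    # Assumed behavior: accept any real floating dtype (float32, float64, ...).
    return np.issubdtype(np.dtype(dtype), np.floating)

def is_complex_dtype(dtype):
    # Assumed behavior: accept any complex floating dtype.
    return np.issubdtype(np.dtype(dtype), np.complexfloating)

Dispatching on dtype kind rather than on exact equality with np.float64/np.complex128 is what lets the factory accept, e.g., float32 data in GPU runs.
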
@@ -6081,6 +6089,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL):

@CachedMethod
def cfl_spacing(self):
xp = self.array_namespace
velocity = self.operand
coordsys = velocity.tensorsig[0]
spacing = []
@@ -6102,7 +6111,7 @@ def cfl_spacing(self):
axis_spacing[:] = dealias * native_spacing * basis.COV.stretch
elif basis is None:
axis_spacing = np.inf
spacing.append(axis_spacing)
spacing.append(xp.asarray(axis_spacing))
return spacing

def compute_cfl_frequency(self, velocity, out):
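
Transform plans are now registered under combined library/namespace keys such as "matrix-numpy" or "matrix-cupy"; a sketch of the key construction used in both `transform_plan` methods above:

# --- example sketch, not part of the diff ---
import numpy as np
import array_api_compat

def transform_key(library, xp):
    # 'array_api_compat.numpy' -> 'numpy', 'array_api_compat.cupy' -> 'cupy'
    xp_name = xp.__name__.split('.')[-1]
    return f"{library}-{xp_name}"

xp = array_api_compat.array_namespace(np.zeros(1))
assert transform_key("matrix", xp) == "matrix-numpy"
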
17 changes: 13 additions & 4 deletions dedalus/core/distributor.py
@@ -10,6 +10,7 @@
from math import prod
import numbers
from weakref import WeakSet
import array_api_compat

from .coords import CoordinateSystem, DirectProduct
from ..tools.array import reshape_vector
@@ -74,7 +75,7 @@ class Distributor:
states) and the paths between them (D transforms and R transposes).
"""

def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np):
# Accept single coordsys in place of tuple/list
if not isinstance(coordsystems, (tuple, list)):
coordsystems = (coordsystems,)
@@ -115,6 +116,11 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None):
self._build_layouts()
# Keep set of weak field references
self.fields = WeakSet()
# Array module
if isinstance(array_namespace, str):
self.array_namespace = getattr(array_api_compat, array_namespace)
else:
self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0))

@CachedAttribute
def cs_by_axis(self):
@@ -255,11 +261,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None)
return I

def local_grid(self, basis, scale=None):
xp = self.array_namespace
# TODO: remove from bases and do it all here?
if scale is None:
scale = 1
if basis.dim == 1:
return basis.local_grid(self, scale=scale)
return xp.asarray(basis.local_grid(self, scale=scale))
else:
raise ValueError("Use `local_grids` for multidimensional bases.")

@@ -292,16 +299,18 @@ def local_grid(self, basis, scale=None):
# return tuple(grids)

def local_grids(self, *bases, scales=None):
xp = self.array_namespace
scales = self.remedy_scales(scales)
grids = []
for basis in bases:
basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1]
grids.extend(basis.local_grids(self, scales=basis_scales))
grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales)))
return grids

def local_modes(self, basis):
# TODO: remove from bases and do it all here?
return basis.local_modes(self)
xp = self.array_namespace
return xp.asarray(basis.local_modes(self))

@CachedAttribute
def default_nonconst_groups(self):
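
With the new keyword, each Distributor carries its own array namespace, selected by string or by module; a usage sketch assuming the usual `dedalus.public` setup (only the `array_namespace` argument is new in this PR):

# --- example sketch, not part of the diff ---
import numpy as np
import dedalus.public as d3

xcoord = d3.Coordinate('x')
# Pass "cupy" instead to run on GPU (requires cupy and a CUDA device).
dist = d3.Distributor(xcoord, dtype=np.float64, array_namespace="numpy")
xbasis = d3.RealFourier(xcoord, size=64, bounds=(0, 2*np.pi))

u = dist.Field(name='u', bases=xbasis)
x = dist.local_grid(xbasis)   # returned in the chosen namespace
u['g'] = np.sin(x)            # field buffers live in that namespace too
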
62 changes: 41 additions & 21 deletions dedalus/core/field.py
@@ -7,6 +7,7 @@
from functools import partial, reduce
from collections import defaultdict
import numpy as np
import array_api_compat
from mpi4py import MPI
from scipy import sparse
from scipy.sparse import linalg as splinalg
@@ -473,16 +474,19 @@
def reinitialize(self, **kw):
return self

@staticmethod
def _create_buffer(buffer_size):
def _create_buffer(self, buffer_size):
"""Create buffer for Field data."""
if buffer_size == 0:
# FFTW doesn't like allocating size-0 arrays
return np.zeros((0,), dtype=np.float64)
xp = self.array_namespace
if xp == np:
if buffer_size == 0:
# FFTW doesn't like allocating size-0 arrays
return np.zeros((0,), dtype=np.float64)
else:
# Use FFTW SIMD aligned allocation
alloc_doubles = buffer_size // 8
return fftw.create_buffer(alloc_doubles)
else:
# Use FFTW SIMD aligned allocation
alloc_doubles = buffer_size // 8
return fftw.create_buffer(alloc_doubles)
return xp.zeros(buffer_size)

@CachedAttribute
def _dealias_buffer_size(self):
@@ -516,14 +520,17 @@ def preset_scales(self, scales):

def preset_layout(self, layout):
"""Interpret buffer as data in specified layout."""
xp = self.array_namespace
layout = self.dist.get_layout_object(layout)
self.layout = layout
tens_shape = [vs.dim for vs in self.tensorsig]
local_shape = layout.local_shape(self.domain, self.scales)
total_shape = tuple(tens_shape) + tuple(local_shape)
self.data = np.ndarray(shape=total_shape,
dtype=self.dtype,
buffer=self.buffer)
# Create view into buffer
if array_api_compat.is_cupy_namespace(xp):
self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data)
else:
self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer)
#self.global_start = layout.start(self.domain, self.scales)
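
CuPy's ndarray constructor takes device memory through `memptr` (a `cupy.cuda.MemoryPointer`) rather than a Python buffer object, which is why the view creation above branches; the same pattern as a standalone sketch (the `view_buffer` name is hypothetical):

# --- example sketch, not part of the diff ---
import array_api_compat

def view_buffer(xp, buffer, shape, dtype):
    """Reinterpret a flat preallocated buffer as an ndarray view, no copy."""
    if array_api_compat.is_cupy_namespace(xp):
        # cupy arrays expose their device allocation via `.data`
        return xp.ndarray(shape=shape, dtype=dtype, memptr=buffer.data)
    else:
        # numpy accepts any object supporting the buffer protocol
        return xp.ndarray(shape=shape, dtype=dtype, buffer=buffer)
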


@@ -561,6 +568,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None):
dtype = dist.dtype
from .domain import Domain
self.dist = dist
self.array_namespace = dist.array_namespace
self.name = name
self.tensorsig = tensorsig
self.dtype = dtype
@@ -774,9 +782,15 @@ def allgather_data(self, layout=None):
# Change layout
if layout is not None:
self.change_layout(layout)
# Convert to numpy if on GPU
xp = self.dist.array_namespace
if array_api_compat.is_cupy_namespace(xp):
data = xp.asnumpy(self.data)
else:
data = self.data.copy()
# Shortcut for serial execution
if self.dist.comm.size == 1:
return self.data.copy()
return data
# Build global buffers
tensor_shape = tuple(cs.dim for cs in self.tensorsig)
global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales)
@@ -785,21 +799,27 @@
recv_buff = np.empty_like(send_buff)
# Combine data via allreduce -- easy but not communication-optimal
# Should be optimized using Allgatherv if this is used past startup
send_buff[local_slices] = self.data
send_buff[local_slices] = data
self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM)
return recv_buff

def gather_data(self, root=0, layout=None):
# Change layout
if layout is not None:
self.change_layout(layout)
# Convert to numpy if on GPU
xp = self.dist.array_namespace
if array_api_compat.is_cupy_namespace(xp):
data = xp.asnumpy(self.data)
else:
data = self.data.copy()
# Shortcut for serial execution
if self.dist.comm.size == 1:
return self.data.copy()
return data
# TODO: Shortcut this for constant fields
# Gather data
# Should be optimized via Gatherv eventually
pieces = self.dist.comm.gather(self.data, root=root)
pieces = self.dist.comm.gather(data, root=root)
# Assemble on root node
if self.dist.comm.rank == root:
ext_mesh = self.layout.ext_mesh
@@ -826,7 +846,7 @@ def allreduce_data_norm(self, layout=None, order=2):
if self.dist.comm.size > 1:
norm = self.dist.comm.allreduce(norm, op=MPI.SUM)
norm = norm ** (1 / order)
return norm
return float(norm)

def allreduce_data_max(self, layout=None):
return self.allreduce_data_norm(layout=layout, order=np.inf)
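
The explicit `float(norm)` matters on GPU: with CuPy the reduction produces a zero-dimensional device array, and `float()` forces the device-to-host transfer so callers always receive a plain Python scalar. A two-line illustration (requires cupy and a CUDA device):

# --- example sketch, not part of the diff ---
import cupy as cp

norm = cp.linalg.norm(cp.arange(3.0))  # 0-d array living on the device
host_norm = float(norm)                # synchronizes and copies to host
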
@@ -907,6 +927,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis
**kw : dict
Other keywords passed to the distribution method.
"""
xp = self.dist.array_namespace
init_layout = self.layout
# Set scales if requested
if scales is not None:
Expand All @@ -926,11 +947,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis
spatial_slices = self.layout.slices(self.domain, self.scales)
local_slices = component_slices + spatial_slices
local_data = global_data[local_slices]
if self.is_real:
self.data[:] = local_data
else:
self.data.real[:] = local_data[..., 0]
self.data.imag[:] = local_data[..., 1]
if self.is_complex:
local_data = local_data[..., 0] + 1j * local_data[..., 1]
# Copy to field data
self.data[:] = xp.asarray(local_data, dtype=self.dtype)

def low_pass_filter(self, shape=None, scales=None):
"""
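
Both gather paths above stage device data through host memory first, since the mpi4py calls here expect NumPy arrays; the conversion as a small sketch (`to_host` is a hypothetical name; CUDA-aware MPI could avoid the staging copy, but that is beyond this change):

# --- example sketch, not part of the diff ---
import array_api_compat

def to_host(xp, data):
    """Host-side copy of field data for MPI communication."""
    if array_api_compat.is_cupy_namespace(xp):
        return xp.asnumpy(data)  # device-to-host copy
    else:
        return data.copy()
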
1 change: 1 addition & 0 deletions dedalus/core/future.py
@@ -51,6 +51,7 @@ def __init__(self, *args, out=None):
self.original_args = tuple(args)
self.out = out
self.dist = unify_attributes(args, 'dist', require=False)
self.array_namespace = self.dist.array_namespace
#self.domain = Domain(self.dist, self.bases)
self._grid_layout = self.dist.grid_layout
self._coeff_layout = self.dist.coeff_layout