diff --git a/dedalus/core/arithmetic.py b/dedalus/core/arithmetic.py index 55c70e53..647c47a6 100644 --- a/dedalus/core/arithmetic.py +++ b/dedalus/core/arithmetic.py @@ -13,6 +13,7 @@ import numexpr as ne from collections import defaultdict from math import prod +import array_api_compat from .domain import Domain from .field import Operand, Field @@ -245,10 +246,11 @@ def choose_layout(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) - np.add(arg0.data, arg1.data, out=out.data) + xp.add(arg0.data, arg1.data, out=out.data) # used for einsum string manipulation @@ -664,6 +666,7 @@ def GammaCoord(self, A_tensorsig, B_tensorsig, C_tensorsig): return G def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args out.preset_layout(arg0.layout) # Broadcast @@ -671,7 +674,11 @@ def operate(self, out): arg1_data = self.arg1_ghost_broadcaster.cast(arg1) # Call einsum if out.data.size: - np.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) + if array_api_compat.is_cupy_namespace(xp): + # Cupy does not support output keyword + out.data[:] = xp.einsum(self.einsum_str, arg0_data, arg1_data, optimize=True) + else: + xp.einsum(self.einsum_str, arg0_data, arg1_data, out=out.data, optimize=True) @alias("cross") @@ -854,6 +861,7 @@ def __init__(self, arg0, arg1, out=None, **kw): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg0.layout) @@ -863,7 +871,7 @@ def operate(self, out): # Reshape arg data to broadcast properly for output tensorsig arg0_exp_data = arg0_data.reshape(self.arg0_exp_tshape + arg0_data.shape[len(arg0.tensorsig):]) arg1_exp_data = arg1_data.reshape(self.arg1_exp_tshape + arg1_data.shape[len(arg1.tensorsig):]) - np.multiply(arg0_exp_data, arg1_exp_data, out=out.data) + xp.multiply(arg0_exp_data, arg1_exp_data, out=out.data) class GhostBroadcaster: @@ -939,11 +947,12 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0, arg1 = self.args # Set output layout out.preset_layout(arg1.layout) # Multiply argument data - np.multiply(arg0, arg1.data, out=out.data) + xp.multiply(arg0, arg1.data, out=out.data) def matrix_dependence(self, *vars): return self.args[1].matrix_dependence(*vars) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index 3f6167c9..fe44931c 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -14,7 +14,7 @@ from ..tools import clenshaw from ..tools.array import reshape_vector, axindex, axslice, interleave_matrices from ..tools.dispatch import MultiClass, SkipDispatchException -from ..tools.general import unify, DeferredTuple +from ..tools.general import unify, DeferredTuple, is_real_dtype, is_complex_dtype from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate, DirectProduct from .domain import Domain from .field import Operand, LockedField @@ -506,7 +506,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" - return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0) + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] + # Shortcut trivial transforms + if grid_size == 1 or self.size == 1: + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) + else: + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, self.a, self.b, self.a0, self.b0, dist.array_namespace, dist.dtype) # def weights(self, scales): # """Gauss-Jacobi weights.""" @@ -915,11 +921,13 @@ def _native_grid(self, scale): @CachedMethod def transform_plan(self, dist, grid_size): """Build transform plan.""" + xp = dist.array_namespace + xp_name = xp.__name__.split('.')[-1] # Shortcut trivial transforms if grid_size == 1 or self.size == 1: - return self.transforms['matrix'](grid_size, self.size) + return self.transforms[f"matrix-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) else: - return self.transforms[self.library](grid_size, self.size) + return self.transforms[f"{self.library}-{xp_name}"](grid_size, self.size, dist.array_namespace, dist.dtype) def forward_transform(self, field, axis, gdata, cdata): # Transform @@ -940,9 +948,9 @@ def Fourier(*args, dtype=None, **kw): """Factory function dispatching to RealFourier and ComplexFourier based on provided dtype.""" if dtype is None: raise ValueError("dtype must be specified") - elif dtype == np.float64: + elif is_real_dtype(dtype): return RealFourier(*args, **kw) - elif dtype == np.complex128: + elif is_complex_dtype(dtype): return ComplexFourier(*args, **kw) else: raise ValueError(f"Unrecognized dtype: {dtype}") @@ -6081,6 +6089,7 @@ class CartesianAdvectiveCFL(operators.AdvectiveCFL): @CachedMethod def cfl_spacing(self): + xp = self.array_namespace velocity = self.operand coordsys = velocity.tensorsig[0] spacing = [] @@ -6102,7 +6111,7 @@ def cfl_spacing(self): axis_spacing[:] = dealias * native_spacing * basis.COV.stretch elif basis is None: axis_spacing = np.inf - spacing.append(axis_spacing) + spacing.append(xp.asarray(axis_spacing)) return spacing def compute_cfl_frequency(self, velocity, out): diff --git a/dedalus/core/distributor.py b/dedalus/core/distributor.py index c4cc766f..99ada2e9 100644 --- a/dedalus/core/distributor.py +++ b/dedalus/core/distributor.py @@ -10,6 +10,7 @@ from math import prod import numbers from weakref import WeakSet +import array_api_compat from .coords import CoordinateSystem, DirectProduct from ..tools.array import reshape_vector @@ -74,7 +75,7 @@ class Distributor: states) and the paths between them (D transforms and R transposes). """ - def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): + def __init__(self, coordsystems, comm=None, mesh=None, dtype=None, array_namespace=np): # Accept single coordsys in place of tuple/list if not isinstance(coordsystems, (tuple, list)): coordsystems = (coordsystems,) @@ -115,6 +116,11 @@ def __init__(self, coordsystems, comm=None, mesh=None, dtype=None): self._build_layouts() # Keep set of weak field references self.fields = WeakSet() + # Array module + if isinstance(array_namespace, str): + self.array_namespace = getattr(array_api_compat, array_namespace) + else: + self.array_namespace = array_api_compat.array_namespace(array_namespace.zeros(0)) @CachedAttribute def cs_by_axis(self): @@ -255,11 +261,12 @@ def IdentityTensor(self, coordsys_in, coordsys_out=None, bases=None, dtype=None) return I def local_grid(self, basis, scale=None): + xp = self.array_namespace # TODO: remove from bases and do it all here? if scale is None: scale = 1 if basis.dim == 1: - return basis.local_grid(self, scale=scale) + return xp.asarray(basis.local_grid(self, scale=scale)) else: raise ValueError("Use `local_grids` for multidimensional bases.") @@ -292,16 +299,18 @@ def local_grid(self, basis, scale=None): # return tuple(grids) def local_grids(self, *bases, scales=None): + xp = self.array_namespace scales = self.remedy_scales(scales) grids = [] for basis in bases: basis_scales = scales[self.first_axis(basis):self.last_axis(basis)+1] - grids.extend(basis.local_grids(self, scales=basis_scales)) + grids.extend(xp.asarray(basis.local_grids(self, scales=basis_scales))) return grids def local_modes(self, basis): # TODO: remove from bases and do it all here? - return basis.local_modes(self) + xp = self.array_namespace + return xp.asarray(basis.local_modes(self)) @CachedAttribute def default_nonconst_groups(self): diff --git a/dedalus/core/field.py b/dedalus/core/field.py index 415edcf6..93330740 100644 --- a/dedalus/core/field.py +++ b/dedalus/core/field.py @@ -7,6 +7,7 @@ from functools import partial, reduce from collections import defaultdict import numpy as np +import array_api_compat from mpi4py import MPI from scipy import sparse from scipy.sparse import linalg as splinalg @@ -473,16 +474,19 @@ def evaluate(self): def reinitialize(self, **kw): return self - @staticmethod - def _create_buffer(buffer_size): + def _create_buffer(self, buffer_size): """Create buffer for Field data.""" - if buffer_size == 0: - # FFTW doesn't like allocating size-0 arrays - return np.zeros((0,), dtype=np.float64) + xp = self.array_namespace + if xp == np: + if buffer_size == 0: + # FFTW doesn't like allocating size-0 arrays + return np.zeros((0,), dtype=np.float64) + else: + # Use FFTW SIMD aligned allocation + alloc_doubles = buffer_size // 8 + return fftw.create_buffer(alloc_doubles) else: - # Use FFTW SIMD aligned allocation - alloc_doubles = buffer_size // 8 - return fftw.create_buffer(alloc_doubles) + return xp.zeros(buffer_size) @CachedAttribute def _dealias_buffer_size(self): @@ -516,14 +520,17 @@ def preset_scales(self, scales): def preset_layout(self, layout): """Interpret buffer as data in specified layout.""" + xp = self.array_namespace layout = self.dist.get_layout_object(layout) self.layout = layout tens_shape = [vs.dim for vs in self.tensorsig] local_shape = layout.local_shape(self.domain, self.scales) total_shape = tuple(tens_shape) + tuple(local_shape) - self.data = np.ndarray(shape=total_shape, - dtype=self.dtype, - buffer=self.buffer) + # Create view into buffer + if array_api_compat.is_cupy_namespace(xp): + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, memptr=self.buffer.data) + else: + self.data = xp.ndarray(shape=total_shape, dtype=self.dtype, buffer=self.buffer) #self.global_start = layout.start(self.domain, self.scales) @@ -561,6 +568,7 @@ def __init__(self, dist, bases=None, name=None, tensorsig=None, dtype=None): dtype = dist.dtype from .domain import Domain self.dist = dist + self.array_namespace = dist.array_namespace self.name = name self.tensorsig = tensorsig self.dtype = dtype @@ -774,9 +782,15 @@ def allgather_data(self, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # Build global buffers tensor_shape = tuple(cs.dim for cs in self.tensorsig) global_shape = tensor_shape + self.layout.global_shape(self.domain, self.scales) @@ -785,7 +799,7 @@ def allgather_data(self, layout=None): recv_buff = np.empty_like(send_buff) # Combine data via allreduce -- easy but not communication-optimal # Should be optimized using Allgatherv if this is used past startup - send_buff[local_slices] = self.data + send_buff[local_slices] = data self.dist.comm.Allreduce(send_buff, recv_buff, op=MPI.SUM) return recv_buff @@ -793,13 +807,19 @@ def gather_data(self, root=0, layout=None): # Change layout if layout is not None: self.change_layout(layout) + # Convert to numpy if on GPU + xp = self.dist.array_namespace + if array_api_compat.is_cupy_namespace(xp): + data = xp.asnumpy(self.data) + else: + data = self.data.copy() # Shortcut for serial execution if self.dist.comm.size == 1: - return self.data.copy() + return data # TODO: Shortcut this for constant fields # Gather data # Should be optimized via Gatherv eventually - pieces = self.dist.comm.gather(self.data, root=root) + pieces = self.dist.comm.gather(data, root=root) # Assemble on root node if self.dist.comm.rank == root: ext_mesh = self.layout.ext_mesh @@ -826,7 +846,7 @@ def allreduce_data_norm(self, layout=None, order=2): if self.dist.comm.size > 1: norm = self.dist.comm.allreduce(norm, op=MPI.SUM) norm = norm ** (1 / order) - return norm + return float(norm) def allreduce_data_max(self, layout=None): return self.allreduce_data_norm(layout=layout, order=np.inf) @@ -907,6 +927,7 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis **kw : dict Other keywords passed to the distribution method. """ + xp = self.dist.array_namespace init_layout = self.layout # Set scales if requested if scales is not None: @@ -926,11 +947,10 @@ def fill_random(self, layout=None, scales=None, seed=None, chunk_size=2**20, dis spatial_slices = self.layout.slices(self.domain, self.scales) local_slices = component_slices + spatial_slices local_data = global_data[local_slices] - if self.is_real: - self.data[:] = local_data - else: - self.data.real[:] = local_data[..., 0] - self.data.imag[:] = local_data[..., 1] + if self.is_complex: + local_data = local_data[..., 0] + 1j * local_data[..., 1] + # Copy to field data + self.data[:] = xp.asarray(local_data, dtype=self.dtype) def low_pass_filter(self, shape=None, scales=None): """ diff --git a/dedalus/core/future.py b/dedalus/core/future.py index 58f9cd9d..ab07e8ff 100644 --- a/dedalus/core/future.py +++ b/dedalus/core/future.py @@ -51,6 +51,7 @@ def __init__(self, *args, out=None): self.original_args = tuple(args) self.out = out self.dist = unify_attributes(args, 'dist', require=False) + self.array_namespace = self.dist.array_namespace #self.domain = Domain(self.dist, self.bases) self._grid_layout = self.dist.grid_layout self._coeff_layout = self.dist.coeff_layout diff --git a/dedalus/core/operators.py b/dedalus/core/operators.py index 9a9a993d..3655fe72 100644 --- a/dedalus/core/operators.py +++ b/dedalus/core/operators.py @@ -15,6 +15,7 @@ from math import prod from ..libraries import dedalus_sphere import logging +import array_api_compat logger = logging.getLogger(__name__.split('.')[-1]) from .domain import Domain @@ -378,11 +379,12 @@ def enforce_conditions(self): arg0.require_grid_space() def operate(self, out): + xp = self.array_namespace arg0, arg1 = self.args # Multiply in grid layout out.preset_layout(arg0.layout) if out.data.size: - np.power(arg0.data, arg1, out.data) + xp.power(arg0.data, arg1, out.data) def new_operands(self, arg0, arg1, **kw): return Power(arg0, arg1) @@ -498,8 +500,9 @@ def enforce_conditions(self): self.args[i].change_layout(self.layout) def operate(self, out): + xp = self.array_namespace out.preset_layout(self.layout) - np.copyto(out.data, self.func(*self.args, **self.kw)) + xp.copyto(out.data, self.func(*self.args, **self.kw)) class UnaryGridFunction(NonlinearOperator, FutureField): @@ -812,10 +815,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg0 = self.args[0] out.preset_layout(arg0.layout) out.lock_to_layouts(self.layouts) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) def new_operand(self, operand, **kw): return Lock(operand, *self.layouts, **kw) @@ -947,6 +951,20 @@ def subspace_matrix(self, layout): # Caching layer to allow insertion of other arguments return self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + @CachedMethod + def subspace_matrix_device(self, layout): + """Build matrix operating on local subspace data on device.""" + # Caching layer to allow insertion of other arguments + matrix = self._subspace_matrix(layout, self.input_basis, self.output_basis, self.first_axis) + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupy as cp + import cupyx.scipy.sparse as csp + if sparse.issparse(matrix): + matrix = csp.csr_matrix(matrix) + else: + matrix = cp.array(matrix) + return matrix + def group_matrix(self, group): return self._group_matrix(group, self.input_basis, self.output_basis) @@ -987,7 +1005,7 @@ def operate(self, out): # Apply matrix if arg.data.size and out.data.size: data_axis = self.last_axis + len(arg.tensorsig) - apply_matrix(self.subspace_matrix(layout), arg.data, data_axis, out=out.data) + apply_matrix(self.subspace_matrix_device(layout), arg.data, data_axis, out=out.data) else: out.data.fill(0) @@ -1522,9 +1540,10 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) class Convert(SpectralOperator, metaclass=MultiClass): @@ -1624,12 +1643,13 @@ def subspace_matrix(self, layout): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] layout = arg.layout # Copy for grid space if layout.grid_space[self.last_axis]: out.preset_layout(layout) - np.copyto(out.data, arg.data) + xp.copyto(out.data, arg.data) # Revert to matrix application for coeff space else: super().operate(out) @@ -1772,9 +1792,10 @@ def base(self): def operate(self, out): """Perform operation.""" + xp = self.array_namespace arg = self.args[0] out.preset_layout(arg.layout) - np.einsum('ii...', arg.data, out=out.data) + xp.einsum('ii...', arg.data, out=out.data) class SphericalTrace(Trace): @@ -1971,6 +1992,7 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" + xp = self.array_namespace operand = self.args[0] # Set output layout out.preset_layout(operand.layout) @@ -3485,10 +3507,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductDivergence(Divergence): @@ -3534,10 +3557,11 @@ def subproblem_matrix(self, subproblem): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalDivergence(Divergence, SphericalEllOperator): @@ -3739,10 +3763,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductCurl(Curl): @@ -3826,10 +3851,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalCurl(Curl, SphericalEllOperator): @@ -4052,10 +4078,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class DirectProductLaplacian(Laplacian): @@ -4097,10 +4124,11 @@ def enforce_conditions(self): def operate(self, out): """Perform operation.""" # OPTIMIZE: this has an extra copy + xp = self.array_namespace arg0 = self.args[0] # Set output layout out.preset_layout(arg0.layout) - np.copyto(out.data, arg0.data) + xp.copyto(out.data, arg0.data) class SphericalLaplacian(Laplacian, SphericalEllOperator): diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index 191d63e0..bd0a6c0b 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -11,13 +11,20 @@ from mpi4py import MPI import uuid from math import prod +import array_api_compat from .domain import Domain -from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv +from ..tools.array import zeros_with_pattern, expand_pattern, sparse_block_diag, copyto, perm_matrix, drop_empty_rows, apply_sparse, assert_sparse_pinv, copy_to_device, copy_from_device from ..tools.cache import CachedAttribute, CachedMethod from ..tools.general import replace, OrderedSet from ..tools.progress import log_progress +try: + import cupy as cp + import cupyx.scipy.sparse as csp +except ImportError: + pass + import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -118,6 +125,7 @@ def __init__(self, solver, group): self.solver = solver self.problem = problem = solver.problem self.dist = solver.dist + self.array_namespace = solver.dist.array_namespace self.dtype = problem.dtype self.group = group # Determine matrix group using solver matrix dependence @@ -191,11 +199,12 @@ def field_size(self, field): @CachedMethod def _gather_scatter_setup(self, fields): + xp = self.array_namespace # Allocate vector fsizes = tuple(self.field_size(f) for f in fields) fslices = tuple(self.field_slices(f) for f in fields) fshapes = tuple(self.field_shape(f) for f in fields) - data = np.empty(sum(fsizes), dtype=self.dtype) + data = xp.empty(sum(fsizes), dtype=self.dtype) # Make views into data fviews = [] i0 = 0 @@ -248,6 +257,7 @@ def __init__(self, solver, subsystems, group): self.subsystems = subsystems self.group = group self.dist = problem.dist + self.array_namespace = self.dist.array_namespace self.domain = problem.variables[0].domain # HACK self.dtype = problem.dtype # Cross reference from subsystems @@ -279,7 +289,8 @@ def size(self): @CachedAttribute def _compressed_buffer(self): - return np.zeros(self.shape, dtype=self.dtype) + xp = self.array_namespace + return xp.zeros(self.shape, dtype=self.dtype) def coeff_slices(self, domain): return self.subsystems[0].coeff_slices(domain) @@ -300,9 +311,10 @@ def field_size(self, field): return self.subsystems[0].field_size(field) def _build_buffer_views(self, fields): + xp = self.array_namespace # Allocate buffer fsizes = tuple(self.field_size(f) for f in fields) - buffer = np.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) + buffer = xp.zeros((sum(fsizes), len(self.subsystems)), dtype=self.dtype) # Make views into buffer views = [] i0 = 0 @@ -342,7 +354,7 @@ def gather_inputs(self, fields, out=None): # Gather from fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply right preconditioner inverse to compress inputs if out is None: out = self._compressed_buffer @@ -354,7 +366,7 @@ def gather_outputs(self, fields, out=None): # Gather from fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(buffer_view, field_view) + copyto(buffer_view, field_view) # Apply left preconditioner to compress outputs if out is None: out = self._compressed_buffer @@ -368,7 +380,7 @@ def scatter_inputs(self, data, fields): # Scatter to fields views = self._input_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def scatter_outputs(self, data, fields): """Precondition and scatter subproblem data out to output-like field list.""" @@ -377,7 +389,7 @@ def scatter_outputs(self, data, fields): # Scatter to fields views = self._output_field_views(tuple(fields)) for buffer_view, field_view in views: - np.copyto(field_view, buffer_view) + copyto(field_view, buffer_view) def inclusion_matrices(self, bases): """List of inclusion matrices.""" @@ -555,24 +567,45 @@ def build_matrices(self, names): left_perm = left_permutation(self, eqns, bc_top=solver.bc_top, interleave_components=solver.interleave_components).tocsr() right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() - # Preconditioners + # Preconditioners on CPU # TODO: remove astype casting, requires dealing with used types in apply_sparse - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) - self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) - self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) + pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + pre_left_pinv = pre_left.T.tocsr().astype(dtype) + pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + pre_right = pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses - assert_sparse_pinv(self.pre_left, self.pre_left_pinv) - assert_sparse_pinv(self.pre_right, self.pre_right_pinv) + assert_sparse_pinv(pre_left, pre_left_pinv) + assert_sparse_pinv(pre_right, pre_right_pinv) # Precondition matrices for name in matrices: - matrices[name] = self.pre_left @ matrices[name] @ self.pre_right + matrices[name] = pre_left @ matrices[name] @ pre_right - # Store minimal CSR matrices for fast dot products + # Store minimal CSR matrices on CPU for name, matrix in matrices.items(): - setattr(self, '{:}_min'.format(name), matrix.tocsr()) + setattr(self, f'{name}_min', matrix.tocsr()) + + # Store device copies for fast dot products + xp = solver.dist.array_namespace + if array_api_compat.is_numpy_namespace(xp): + self.pre_left = pre_left + self.pre_left_pinv = pre_left_pinv + self.pre_right_pinv = pre_right_pinv + self.pre_right = pre_right + # Reference current CPU matrices + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', getattr(self, f'{name}_min')) + elif array_api_compat.is_cupy_namespace(xp): + # Copy to device + self.pre_left = csp.csr_matrix(pre_left) + self.pre_left_pinv = csp.csr_matrix(pre_left_pinv) + self.pre_right_pinv = csp.csr_matrix(pre_right_pinv) + self.pre_right = csp.csr_matrix(pre_right) + for name, matrix in matrices.items(): + setattr(self, f'{name}_min_device', csp.csr_matrix(matrix)) + else: + raise ValueError("Unsupported array namespace: {}".format(xp)) # Store expanded CSR matrices for fast recombination if len(matrices) > 1: diff --git a/dedalus/core/system.py b/dedalus/core/system.py index 23cbb86b..f28cb206 100644 --- a/dedalus/core/system.py +++ b/dedalus/core/system.py @@ -12,45 +12,44 @@ class CoeffSystem: """ - Representation of a collection of fields that don't need to be transformed, - and are therefore stored as a contigous set of coefficient data for - efficient pencil and group manipulation. + Contiguous buffer for data from all subproblems. Parameters ---------- - nfields : int - Number of fields to represent - domain : domain object - Problem domain + subproblems : list of Subproblem objects + Subproblems to represent + dtype : dtype + Data type + array_namespace : array namespace + Array namespace Attributes ---------- data : ndarray - Contiguous buffer for field coefficients - - """ - - """ - var buffer - + Contiguous buffer for data from all subproblems + views : dict + Nested dictionary of views for each subproblem and subsystem """ - def __init__(self, subproblems, dtype): + def __init__(self, subproblems, dtype, array_namespace): + xp = array_namespace # Build buffer total_size = sum(sp.LHS.shape[1]*len(sp.subsystems) for sp in subproblems) - self.data = np.zeros(total_size, dtype=dtype) + self.data = xp.zeros(total_size, dtype=dtype) # Build views i0 = i1 = 0 self.views = views = {} for sp in subproblems: views[sp] = views_sp = {} + # View for each individual subsystem i00 = i0 for ss in sp.subsystems: i1 += sp.LHS.shape[1] views_sp[ss] = self.data[i0:i1] i0 = i1 i11 = i1 + # View combining all subsystems as rows in a matrix if i11 - i00 > 0: views_sp[None] = self.data[i00:i11].reshape((sp.LHS.shape[1], -1)) else: diff --git a/dedalus/core/timesteppers.py b/dedalus/core/timesteppers.py index 81da4c10..162a2d32 100644 --- a/dedalus/core/timesteppers.py +++ b/dedalus/core/timesteppers.py @@ -2,10 +2,9 @@ from collections import deque, OrderedDict import numpy as np -from scipy.linalg import blas from .system import CoeffSystem -from ..tools.array import apply_sparse +from ..tools.array import apply_sparse, get_axpy # Public interface @@ -71,7 +70,8 @@ class MultistepIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create deque for storing recent timesteps self.dt = deque([0.] * self.steps) @@ -81,16 +81,16 @@ def __init__(self, solver): self.LX = LX = deque() self.F = F = deque() for j in range(self.amax): - MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + MX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.bmax): - LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + LX.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) for j in range(self.cmax): - F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype)) + F.append(CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp)) # Attributes self._iteration = 0 self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -143,8 +143,8 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Evaluate F(X0) evaluator.evaluate_scheduled(iteration=iteration, wall_time=wall_time, sim_time=sim_time, timestep=dt) @@ -539,15 +539,16 @@ class RungeKuttaIMEX: def __init__(self, solver): self.solver = solver - self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype) + xp = solver.dist.array_namespace + self.RHS = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) # Create coefficient systems for multistep history - self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype) - self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] - self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype) for i in range(self.stages)] + self.MX0 = CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) + self.LX = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] + self.F = [CoeffSystem(solver.subproblems, dtype=solver.dtype, array_namespace=xp) for i in range(self.stages)] self._LHS_params = None - self.axpy = blas.get_blas_funcs('axpy', dtype=solver.dtype) + self.axpy = get_axpy(xp, solver.dtype) def step(self, dt, wall_time): """Advance solver by one timestep.""" @@ -584,11 +585,12 @@ def step(self, dt, wall_time): # Compute M.X(n,0) and L.X(n,0) # Ensure coeff space before subsystem gathers + # TODO: add option to evaluate this matrix-free (e.g for high-bandwidth NCCs when using fast transforms) evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.M_min, spX, axis=0, out=MX0.get_subdata(sp)) - apply_sparse(sp.L_min, spX, axis=0, out=LX0.get_subdata(sp)) + apply_sparse(sp.M_min_device, spX, axis=0, out=MX0.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LX0.get_subdata(sp)) # Compute stages # (M + k Hii L).X(n,i) = M.X(n,0) + k Aij F(n,j) - k Hij L.X(n,j) @@ -601,7 +603,7 @@ def step(self, dt, wall_time): evaluator.require_coeff_space(state_fields) for sp in subproblems: spX = sp.gather_inputs(state_fields) - apply_sparse(sp.L_min, spX, axis=0, out=LXi.get_subdata(sp)) + apply_sparse(sp.L_min_device, spX, axis=0, out=LXi.get_subdata(sp)) # Compute F(n,i-1), only doing output on first evaluation if i == 1: diff --git a/dedalus/core/transforms.py b/dedalus/core/transforms.py index 00758fb2..0433d802 100644 --- a/dedalus/core/transforms.py +++ b/dedalus/core/transforms.py @@ -8,13 +8,16 @@ import scipy.fftpack from ..libraries import dedalus_sphere from math import prod +import array_api_compat from . import basis from ..libraries.fftw import fftw_wrappers as fftw from ..tools import jacobi -from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse +from ..tools.array import apply_matrix, apply_dense, axslice, solve_upper_sparse, apply_sparse, copyto from ..tools.cache import CachedAttribute from ..tools.cache import CachedMethod +from ..tools.general import float_to_complex +from ..tools.linalg_gpu import cupy_solve_upper_csr, CustomCupyUpperTriangularSolver import logging logger = logging.getLogger(__name__.split('.')[-1]) @@ -93,31 +96,39 @@ class JacobiTransform(SeparableTransform): Jacobi "a" parameter for the quadrature grid. b0 : int Jacobi "b" parameter for the quadrature grid. + array_namespace : array namespace + Array namespace for the transform. + dtype : dtype + Data type for the transform. Notes ----- TODO: We need to define the normalization we use here. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, dealias_before_converting=None): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, dealias_before_converting=None): self.N = grid_size self.M = coeff_size self.a = a self.b = b self.a0 = a0 self.b0 = b0 + self.array_namespace = array_namespace + self.dtype = dtype if dealias_before_converting is None: dealias_before_converting = GET_DEALIAS_BEFORE_CONVERTING() self.dealias_before_converting = dealias_before_converting -@register_transform(basis.Jacobi, 'matrix') +@register_transform(basis.Jacobi, 'matrix-numpy') +@register_transform(basis.Jacobi, 'matrix-cupy') class JacobiMMT(JacobiTransform, SeparableMatrixTransform): """Jacobi polynomial MMTs.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -141,11 +152,12 @@ def forward_matrix(self): # Truncate to specified coeff_size forward_matrix = forward_matrix[:M, :] # Ensure C ordering for fast dot products - return np.asarray(forward_matrix, order='C') + return xp.asarray(forward_matrix, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N, M = self.N, self.M a, a0 = self.a, self.a0 b, b0 = self.b, self.b0 @@ -155,7 +167,7 @@ def backward_matrix(self): # Zero higher polynomials for transforms with grid_size < coeff_size polynomials[N:, :] = 0 # Transpose and ensure C ordering for fast dot products - return np.asarray(polynomials.T, order='C') + return xp.asarray(polynomials.T, order='C', dtype=self.dtype) class ComplexFourierTransform(SeparableTransform): @@ -191,50 +203,56 @@ class ComplexFourierTransform(SeparableTransform): If M is even, the ordering is [0, 1, 2, ..., KM, -KM, -KM+1, ..., -1]. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): self.N = grid_size self.M = coeff_size self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace M = self.M KM = self.KM - k = np.arange(M) + k = xp.arange(M) # Wrap around Nyquist mode return (k + KM) % M - KM -@register_transform(basis.ComplexFourier, 'matrix') +@register_transform(basis.ComplexFourier, 'matrix-numpy') +@register_transform(basis.ComplexFourier, 'matrix-cupy') class ComplexFourierMMT(ComplexFourierTransform, SeparableMatrixTransform): """Complex-to-complex Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[:, None] - X = np.arange(self.N)[None, :] + X = xp.arange(self.N)[None, :] dX = self.N / 2 / np.pi - quadrature = np.exp(-1j*K*X/dX) / self.N + quadrature = xp.exp(-1j*K*X/dX) / self.N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - quadrature *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + quadrature *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace K = self.wavenumbers[None, :] - X = np.arange(self.N)[:, None] + X = xp.arange(self.N)[:, None] dX = self.N / 2 / np.pi - functions = np.exp(1j*K*X/dX) + functions = xp.exp(1j*K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size - functions *= np.abs(K) <= self.Kmax - # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + functions *= xp.abs(K) <= self.Kmax + # Ensure C ordering for fast dot products, cast to specified dtype + return xp.asarray(functions, order='C', dtype=self.dtype) class ComplexFFT(ComplexFourierTransform): @@ -242,32 +260,33 @@ class ComplexFFT(ComplexFourierTransform): def resize_coeffs(self, data_in, data_out, axis, rescale): """Resize and rescale coefficients in standard FFT format by intermediate padding/truncation.""" + xp = self.array_namespace M = self.M Kmax = self.Kmax if Kmax == 0: posfreq = axslice(axis, 0, 1) badfreq = axslice(axis, 1, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 else: posfreq = axslice(axis, 0, Kmax+1) badfreq = axslice(axis, Kmax+1, -Kmax) negfreq = axslice(axis, -Kmax, None) if rescale is None: - np.copyto(data_out[posfreq], data_in[posfreq]) + xp.copyto(data_out[posfreq], data_in[posfreq]) data_out[badfreq] = 0 - np.copyto(data_out[negfreq], data_in[negfreq]) + xp.copyto(data_out[negfreq], data_in[negfreq]) else: - np.multiply(data_in[posfreq], rescale, data_out[posfreq]) + xp.multiply(data_in[posfreq], rescale, data_out[posfreq]) data_out[badfreq] = 0 - np.multiply(data_in[negfreq], rescale, data_out[negfreq]) + xp.multiply(data_in[negfreq], rescale, data_out[negfreq]) -@register_transform(basis.ComplexFourier, 'scipy') +@register_transform(basis.ComplexFourier, 'scipy-numpy') class ScipyComplexFFT(ComplexFFT): """Complex-to-complex FFT using scipy.fft.""" @@ -289,6 +308,34 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +@register_transform(basis.ComplexFourier, 'scipy-cupy') +class CupyComplexFFT(ComplexFFT): + """Complex-to-complex FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call FFT + temp = self.cufft.fft(gdata, axis=axis) # Creates temporary + # Resize and rescale for unit-amplitude normalization + self.resize_coeffs(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_coeffs(cdata, temp, axis, rescale=self.N) + # Call FFT + temp = self.cufft.ifft(temp, axis=axis, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + class FFTWBase: """Abstract base class for FFTW transforms.""" @@ -299,7 +346,7 @@ def __init__(self, *args, rigor=None, **kw): super().__init__(*args, **kw) -@register_transform(basis.ComplexFourier, 'fftw') +@register_transform(basis.ComplexFourier, 'fftw-numpy') class FFTWComplexFFT(FFTWBase, ComplexFFT): """Complex-to-complex FFT using FFTW.""" @@ -368,7 +415,7 @@ class RealFourierTransform(SeparableTransform): where the k = 0 minus-sine mode is zeroed in both directions. """ - def __init__(self, grid_size, coeff_size): + def __init__(self, grid_size, coeff_size, array_namespace, dtype): if coeff_size % 2 != 0: pass#raise ValueError("coeff_size must be even.") self.N = grid_size @@ -376,55 +423,61 @@ def __init__(self, grid_size, coeff_size): self.KN = (self.N - 1) // 2 self.KM = (self.M - 1) // 2 self.Kmax = min(self.KN, self.KM) + self.array_namespace = array_namespace + self.dtype = dtype @property def wavenumbers(self): """One-dimensional global wavenumber array.""" + xp = self.array_namespace # Repeat k's for cos and msin parts - return np.repeat(np.arange(self.KM+1), 2) + return xp.repeat(xp.arange(self.KM+1), 2) -@register_transform(basis.RealFourier, 'matrix') +@register_transform(basis.RealFourier, 'matrix-numpy') +@register_transform(basis.RealFourier, 'matrix-cupy') class RealFourierMMT(RealFourierTransform, SeparableMatrixTransform): """Real-to-real Fourier MMT.""" @CachedAttribute def forward_matrix(self): """Build forward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[::2, None] - X = np.arange(N)[None, :] + X = xp.arange(N)[None, :] dX = N / 2 / np.pi - quadrature = np.zeros((M, N)) - quadrature[0::2] = (2 / N) * np.cos(K*X/dX) - quadrature[1::2] = -(2 / N) * np.sin(K*X/dX) + quadrature = xp.zeros((M, N)) + quadrature[0::2] = (2 / N) * xp.cos(K*X/dX) + quadrature[1::2] = -(2 / N) * xp.sin(K*X/dX) quadrature[0] = 1 / N # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size quadrature *= self.wavenumbers[:,None] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(quadrature, order='C') + return xp.asarray(quadrature, order='C', dtype=self.dtype) @CachedAttribute def backward_matrix(self): """Build backward transform matrix.""" + xp = self.array_namespace N = self.N M = max(2, self.M) # Account for sin and cos parts of m=0 Kmax = self.Kmax K = self.wavenumbers[None, ::2] - X = np.arange(N)[:, None] + X = xp.arange(N)[:, None] dX = N / 2 / np.pi - functions = np.zeros((N, M)) - functions[:, 0::2] = np.cos(K*X/dX) - functions[:, 1::2] = -np.sin(K*X/dX) + functions = xp.zeros((N, M)) + functions[:, 0::2] = xp.cos(K*X/dX) + functions[:, 1::2] = -xp.sin(K*X/dX) # Zero Nyquist and higher modes for transforms with grid_size <= coeff_size functions *= self.wavenumbers[None, :] <= self.Kmax # Ensure C ordering for fast dot products - return np.asarray(functions, order='C') + return xp.asarray(functions, order='C', dtype=self.dtype) -@register_transform(basis.RealFourier, 'fftpack') +@register_transform(basis.RealFourier, 'fftpack-numpy') class FFTPACKRealFFT(RealFourierTransform): """Real-to-real FFT using scipy.fftpack.""" @@ -471,48 +524,54 @@ class RealFFT(RealFourierTransform): def unpack_rescale(self, temp, cdata, axis, rescale): """Unpack complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 cos data meancos = axslice(axis, 0, 1) - np.multiply(temp[meancos].real, rescale, cdata[meancos]) + xp.multiply(temp[meancos].real, rescale, cdata[meancos]) # Zero k = 0 msin data cdata[axslice(axis, 1, 2)] = 0 # Unpack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] - np.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) - np.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) + xp.multiply(temp_posfreq.real, 2*rescale, cdata_posfreq_cos) + xp.multiply(temp_posfreq.imag, 2*rescale, cdata_posfreq_msin) # Zero k > Kmax data cdata[axslice(axis, 2*(Kmax+1), None)] = 0 def repack_rescale(self, cdata, temp, axis, rescale): """Repack into complex coefficients and rescale for unit-amplitude normalization.""" + xp = self.array_namespace Kmax = self.Kmax # Scale k = 0 data meancos = axslice(axis, 0, 1) if rescale is None: - np.copyto(temp[meancos], cdata[meancos]) + xp.copyto(temp[meancos], cdata[meancos]) else: - np.multiply(cdata[meancos], rescale, temp[meancos]) + xp.multiply(cdata[meancos], rescale, temp[meancos]) # Repack and scale 1 < k <= Kmax data temp_posfreq = temp[axslice(axis, 1, Kmax+1)] cdata_posfreq_cos = cdata[axslice(axis, 2, 2*(Kmax+1), 2)] cdata_posfreq_msin = cdata[axslice(axis, 3, 2*(Kmax+1), 2)] if rescale is None: - np.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (1 / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (1 / 2), temp_posfreq.imag) else: - np.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) - np.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) + xp.multiply(cdata_posfreq_cos, (rescale / 2), temp_posfreq.real) + xp.multiply(cdata_posfreq_msin, (rescale / 2), temp_posfreq.imag) # Zero k > Kmax data temp[axslice(axis, Kmax+1, None)] = 0 -@register_transform(basis.RealFourier, 'scipy') +@register_transform(basis.RealFourier, 'scipy-numpy') class ScipyRealFFT(RealFFT): """Real-to-real FFT using scipy.fft.""" + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + def forward(self, gdata, cdata, axis): """Apply forward transform along specified axis.""" # Call RFFT @@ -526,7 +585,7 @@ def backward(self, cdata, gdata, axis): # Rescale all modes and combine into complex form shape = list(gdata.shape) shape[axis] = N // 2 + 1 - temp = np.empty(shape=shape, dtype=np.complex128) # Creates temporary + temp = np.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary # Repack into complex form and rescale self.repack_rescale(cdata, temp, axis, rescale=N) # Call IRFFT @@ -534,7 +593,39 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) -@register_transform(basis.RealFourier, 'fftw') +@register_transform(basis.RealFourier, 'scipy-cupy') +class CupyRealFFT(RealFFT): + """Real-to-real FFT using scipy.fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + self.complex_dtype = float_to_complex(self.dtype) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call RFFT + temp = self.cufft.rfft(gdata, axis=axis) # Creates temporary + # Unpack from complex form and rescale + self.unpack_rescale(temp, cdata, axis, rescale=1/self.N) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + N = self.N + # Rescale all modes and combine into complex form + shape = list(gdata.shape) + shape[axis] = N // 2 + 1 + temp = xp.empty(shape=shape, dtype=self.complex_dtype) # Creates temporary + # Repack into complex form and rescale + self.repack_rescale(cdata, temp, axis, rescale=N) + # Call IRFFT + temp = self.cufft.irfft(temp, axis=axis, n=N, overwrite_x=True) # Creates temporary + xp.copyto(gdata, temp) + + +@register_transform(basis.RealFourier, 'fftw-numpy') class FFTWRealFFT(FFTWBase, RealFFT): """Real-to-real FFT using FFTW.""" @@ -565,7 +656,7 @@ def backward(self, cdata, gdata, axis): plan.backward(temp, gdata) -@register_transform(basis.RealFourier, 'fftw_hc') +@register_transform(basis.RealFourier, 'fftw_hc-numpy') class FFTWHalfComplexFFT(FFTWBase, RealFourierTransform): """Real-to-real FFT using FFTW half-complex DFT.""" @@ -768,6 +859,33 @@ def backward(self, cdata, gdata, axis): np.copyto(gdata, temp) +class CupyDCT(FastCosineTransform): + """Fast cosine transform using cupy fft.""" + + def __init__(self, *args, **kw): + import cupyx.scipy.fft as cufft + self.cufft = cufft + super().__init__(*args, **kw) + + def forward(self, gdata, cdata, axis): + """Apply forward transform along specified axis.""" + # Call DCT + temp = self.cufft.dct(gdata, type=2, axis=axis) # Creates temporary + # Resize and rescale for unit-ampltidue normalization + self.resize_rescale_forward(temp, cdata, axis, self.Kmax) + + def backward(self, cdata, gdata, axis): + """Apply backward transform along specified axis.""" + xp = self.array_namespace + # Resize and rescale for unit-amplitude normalization + # Need temporary to avoid overwriting problems + temp = xp.empty_like(gdata) # Creates temporary + self.resize_rescale_backward(cdata, temp, axis, self.Kmax) + # Call IDCT + temp = self.cufft.dct(temp, type=3, axis=axis, overwrite_x=True) # Creates temporary + copyto(gdata, temp) + + #@register_transform(basis.Cosine, 'fftw') class FFTWDCT(FFTWBase, FastCosineTransform): """Fast cosine transform using FFTW.""" @@ -804,11 +922,11 @@ class FastChebyshevTransform(JacobiTransform): Subclasses should inherit from this class, then a FastCosineTransform subclass. """ - def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): + def __init__(self, grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw): if not a0 == b0 == -1/2: raise ValueError("Fast Chebshev transform requires a0 == b0 == -1/2.") # Jacobi initialization - super().__init__(grid_size, coeff_size, a, b, a0, b0, **kw) + super().__init__(grid_size, coeff_size, a, b, a0, b0, array_namespace, dtype, **kw) # DCT initialization to set scaling factors if a != a0 or b != b0: # Modify coeff_size to avoid truncation before conversion @@ -840,6 +958,13 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0, **kw): self.backward_conversion.sum_duplicates() # for faster solve_upper self.resize_rescale_forward = self._resize_rescale_forward_convert self.resize_rescale_backward = self._resize_rescale_backward_convert + if array_api_compat.is_cupy_namespace(self.array_namespace): + import cupyx.scipy.sparse as csp + self.forward_conversion = csp.csr_matrix(self.forward_conversion) + self.backward_conversion = csp.csr_matrix(self.backward_conversion) + self.forward_conversion.sum_duplicates() + self.backward_conversion.sum_duplicates() + self.backward_conversion_LU = CustomCupyUpperTriangularSolver(self.backward_conversion) def _resize_rescale_forward(self, data_in, data_out, axis, Kmax): """Resize by padding/trunction and rescale to unit amplitude.""" @@ -881,7 +1006,10 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): # Truncate input before conversion data_in[badfreq] = 0 # Ultraspherical conversion - solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) + if array_api_compat.is_cupy_namespace(self.array_namespace): + cupy_solve_upper_csr(self.backward_conversion_LU, data_in, axis, out=data_in) + else: + solve_upper_sparse(self.backward_conversion, data_in, axis, out=data_in) # Change sign of odd modes if Kmax_orig > 0: posfreq_odd = axslice(axis, 1, Kmax_orig+1, 2) @@ -890,18 +1018,24 @@ def _resize_rescale_backward_convert(self, data_in, data_out, axis, Kmax_DCT): super().resize_rescale_backward(data_in, data_out, axis, Kmax_orig) -@register_transform(basis.Jacobi, 'scipy_dct') +@register_transform(basis.Jacobi, 'scipy_dct-numpy') class ScipyFastChebyshevTransform(FastChebyshevTransform, ScipyDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance -@register_transform(basis.Jacobi, 'fftw_dct') +@register_transform(basis.Jacobi, 'fftw_dct-numpy') class FFTWFastChebyshevTransform(FastChebyshevTransform, FFTWDCT): """Fast ultraspherical transform using scipy.fft and spectral conversion.""" pass # Implementation is complete via inheritance +@register_transform(basis.Jacobi, 'scipy_dct-cupy') +class CupyFastChebyshevTransform(FastChebyshevTransform, CupyDCT): + """Fast ultraspherical transform using cupy fft and spectral conversion.""" + pass # Implementation is complete via inheritance + + # class ScipyDST(PolynomialTransform): # def forward_reduced(self): diff --git a/dedalus/libraries/matsolvers.py b/dedalus/libraries/matsolvers.py index f301d4f2..ede93a1b 100644 --- a/dedalus/libraries/matsolvers.py +++ b/dedalus/libraries/matsolvers.py @@ -5,7 +5,12 @@ import scipy.sparse as sp import scipy.sparse.linalg as spla from functools import partial - +import array_api_compat +try: + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False matsolvers = {} def add_solver(solver): @@ -144,6 +149,21 @@ def __init__(self, matrix, solver=None): relax=self.relax, panel_size=self.panel_size, options=self.options) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + # Avoid cupy splu which requires GPU matrices but transfers them to factorize on CPU + # Run same typecheck as cupy splu + if matrix.dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(self.LU.dtype)) + # Build cupy factorization from scipy factorization of CPU matrices + self.LU = cupy_spla.SuperLU(self.LU) + self.LU.spsm_L_descr = None + self.LU.spsm_U_descr = None + self.solve = self.cupy_solve + + def cupy_solve(self, vector): + from dedalus.tools.linalg_gpu import custom_SuperLU_solve + return custom_SuperLU_solve(self.LU, vector, trans=self.trans) def solve(self, vector): return self.LU.solve(vector, trans=self.trans) @@ -225,6 +245,9 @@ class SparseInverse(SparseSolver): def __init__(self, matrix, solver=None): self.matrix_inverse = spla.inv(matrix.tocsc()) + # Cupy conversion + if array_api_compat.is_cupy_namespace(solver.dist.array_namespace): + self.matrix_inverse = cupy_spla.inv(matrix.tocsc()) def solve(self, vector): return self.matrix_inverse @ vector diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index ab9caf88..e137f75d 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -5,7 +5,10 @@ import scipy.sparse as sp from scipy.sparse import _sparsetools from scipy.sparse import linalg as spla +from scipy.linalg import blas from math import prod +from ..tools import linalg_gpu +import array_api_compat from .config import config from . import linalg as cython_linalg @@ -76,10 +79,20 @@ def expand_pattern(input, pattern): def apply_matrix(matrix, array, axis, **kw): """Apply matrix along any axis of an array.""" - if sparse.isspmatrix(matrix): - return apply_sparse(matrix, array, axis, **kw) + xp = array_api_compat.array_namespace(array) + if array_api_compat.is_numpy_namespace(xp): + if sparse.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) + elif array_api_compat.is_cupy_namespace(xp): + import cupyx.scipy.sparse as csp + if csp.issparse(matrix): + return apply_sparse(matrix, array, axis, **kw) + else: + return apply_dense(matrix, array, axis, **kw) else: - return apply_dense(matrix, array, axis, **kw) + raise ValueError("Unsupported array type") def apply_dense_einsum(matrix, array, axis, optimize=True, **kw): @@ -173,14 +186,14 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= Apply sparse matrix along any axis of an array. Must be out of place if ouptut is specified. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") + xp = array_api_compat.array_namespace(array) + matrix.sum_duplicates() + matrix.has_canonical_format = True # Check output if out is None: out_shape = list(array.shape) out_shape[axis] = matrix.shape[0] - out = np.empty(out_shape, dtype=array.dtype) + out = xp.empty(out_shape, dtype=array.dtype) elif out is array: raise ValueError("Cannot apply in place") # Check shapes @@ -189,17 +202,27 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= raise ValueError("Axis out of bounds.") if matrix.shape[1] != array.shape[axis] or matrix.shape[0] != out.shape[axis]: raise ValueError("Matrix shape mismatch.") - # Old way if requested - if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: - out.fill(0) - return csr_matvecs(matrix, array, out) - # Promote datatypes - # TODO: find way to optimize this with fused types - matrix_data = matrix.data - if matrix_data.dtype != out.dtype: - matrix_data = matrix_data.astype(out.dtype) - # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + # Old way if requested + if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: + out.fill(0) + return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) + # Call cython routine + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + # TODO: check matrix format here without import cupy + linalg_gpu.cupy_apply_csr(matrix, array, axis, out) + else: + raise ValueError("Unsupported array type") return out @@ -208,28 +231,40 @@ def solve_upper_sparse(matrix, rhs, axis, out=None, check_shapes=False, num_thre Solve upper triangular sparse matrix along any axis of an array. Matrix assumed to be nonzero on the diagonals. """ - # Check matrix - if not isinstance(matrix, sparse.csr_matrix): - raise ValueError("Matrix must be in CSR format.") - if not matrix._has_canonical_format: # avoid property hook (without underscore) - matrix.sum_duplicates() - # Setup output = rhs + xp = array_api_compat.array_namespace(rhs) + matrix.sum_duplicates() + matrix.has_canonical_format = True + # Check output if out is None: - out = np.copy(rhs) - elif out is not rhs: - np.copyto(out, rhs) - # Promote datatypes - matrix_data = matrix.data - if matrix_data.dtype != rhs.dtype: - matrix_data = matrix_data.astype(rhs.dtype) - # Check shapes - if check_shapes: - if not (0 <= axis < rhs.ndim): - raise ValueError("Axis out of bounds.") - if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): - raise ValueError("Matrix shape mismatch.") - # Call cython routine - cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + out = xp.empty_like(rhs) + # Dispatch on array type + if array_api_compat.is_numpy_namespace(xp): + # Check matrix + if not isinstance(matrix, sparse.csr_matrix): + raise ValueError("Matrix must be in CSR format.") + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + # Setup output = rhs + copyto(out, rhs) + # Promote datatypes + matrix_data = matrix.data + if matrix_data.dtype != rhs.dtype: + matrix_data = matrix_data.astype(rhs.dtype) + # Check shapes + if check_shapes: + if not (0 <= axis < rhs.ndim): + raise ValueError("Axis out of bounds.") + if not (matrix.shape[0] == matrix.shape[1] == rhs.shape[axis]): + raise ValueError("Matrix shape mismatch.") + # Call cython routine + cython_linalg.solve_upper_csr(matrix.indptr, matrix.indices, matrix_data, out, axis, num_threads) + elif array_api_compat.is_cupy_namespace(xp): + if not matrix._has_canonical_format: # avoid property hook (without underscore) + matrix.sum_duplicates() + linalg_gpu.cupy_solve_upper_csr(matrix, rhs, axis, out) + else: + raise ValueError("Unsupported array type") + return out def csr_matvec(A_csr, x_vec, out_vec): @@ -353,6 +388,22 @@ def copyto(dest, src): dest[:] = src +def copy_to_device(dest, src): + xp = array_api_compat.array_namespace(dest) + if array_api_compat.is_cupy_namespace(xp): + src = xp.asarray(src) + dest[:] = src + else: + dest[:] = src + + +def copy_from_device(dest, src): + if array_api_compat.is_cupy_array(src): + src.get(out=dest) + else: + dest[:] = src + + def perm_matrix(perm, M=None, source_index=False, sparse=True): """ Build sparse permutation matrix from permutation vector. @@ -474,3 +525,12 @@ def assert_sparse_pinv(A, B): if not sparse_allclose((B @ A).conj().T, B @ A): raise AssertionError("Not a pseudoinverse") + +def get_axpy(array_namespace, dtype): + if array_api_compat.is_numpy_namespace(array_namespace): + return blas.get_blas_funcs('axpy', dtype=dtype) + elif array_api_compat.is_cupy_namespace(array_namespace): + from cupy.cublas import axpy as cublas_axpy + return cublas_axpy + else: + raise ValueError("Unsupported array namespace") diff --git a/dedalus/tools/general.py b/dedalus/tools/general.py index 18eb5ee4..9b8b5746 100644 --- a/dedalus/tools/general.py +++ b/dedalus/tools/general.py @@ -124,3 +124,15 @@ def is_complex_dtype(dtype): dtype = dtype.type return np.iscomplexobj(dtype()) + +def float_to_complex(dtype): + itemsize = np.dtype(dtype).itemsize + complex_dtype = np.dtype(f'complex{16*itemsize}') + return complex_dtype.type + + +def complex_to_float(dtype): + itemsize = np.dtype(dtype).itemsize + float_dtype = np.dtype(f'float{4*itemsize}') + return float_dtype.type + diff --git a/dedalus/tools/linalg_gpu.py b/dedalus/tools/linalg_gpu.py new file mode 100644 index 00000000..a64f51f1 --- /dev/null +++ b/dedalus/tools/linalg_gpu.py @@ -0,0 +1,545 @@ +"""Linear algebra routines using cupy.""" + +import numpy as np +import math +try: + import cupy as cp + import cupyx.scipy.sparse as csp + import cupyx.scipy.sparse.linalg as cupy_spla + cupy_available = True +except ImportError: + cupy_available = False + + +def cupy_apply_csr(matrix, array, axis, out): + """Apply CSR matrix to arbitrary axis of array.""" + if not cupy_available: + raise ImportError("cupy must be installed to use GPU linear algebra") + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + print('WARNING: converting matrix to CSR format') + matrix = csp.csr_matrix(matrix) + #raise ValueError("Matrix must be in CSR format.") + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + out[:] = matrix.dot(array) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + out[:,0] = matrix.dot(array[:,0]) + else: + out[:] = matrix.dot(array) + elif axis == 1: + if array.shape[0] == 1: + out[0,:] = matrix.dot(array[0,:]) + else: + out[:] = matrix.dot(array.T).T + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = array.shape[0] + N2 = array.shape[1] + N3 = array.shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + temp = matrix.dot(x1) + out[:] = temp.reshape(out.shape) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + temp = matrix.dot(x2) + out[:] = temp.reshape(out.shape) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + temp = matrix.dot(x2.T).T + out[:] = temp.reshape(out.shape) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape(((N1, matrix.shape[0], N3))) + cupy_apply_csr_mid(matrix, x3, y3) + + +# Kernel for applying CSR matrix with parallelization over n1 and n3 +apply_csr_mid_kernel = cp.RawKernel( + r''' + extern "C" __global__ void apply_csr_mid_kernel( + const double* data, // CSR data of shape (nnz,) + const int* indices, // CSR column indices (nnz,) + const int* indptr, // CSR row pointers (N2o + 1,) + const double* input, // shape (N1, N2i, N3) + double* output, // shape (N1, N2o, N3) + int N1, int N2i, int N2o, int N3) + { + int n1 = blockIdx.x * blockDim.x + threadIdx.x ; // batch index + int n3 = blockIdx.y * blockDim.y + threadIdx.y; // output column index + + if (n1 >= N1 || n3 >= N3) return; + + // Loop over output rows = CSR matrix rows + for (int i = 0; i < N2o; ++i) { + double acc = 0; + int start = indptr[i]; + int end = indptr[i + 1]; + + for (int k = start; k < end; ++k) { + int j = indices[k]; // input column + double val = data[k]; + acc += val * input[n1 * N2i * N3 + j * N3 + n3]; + } + + output[n1 * N2o * N3 + i * N3 + n3] = acc; + } + } + ''', + 'apply_csr_mid_kernel') + + +def cupy_apply_csr_mid(matrix, array, out): + N1, N2i, N3 = array.shape + N2o = matrix.shape[0] + # Choose thread/block config + threads_y = min(1024, N3) # maximize concurrency along n3 + threads_x = 1024 // threads_y # make block have 1024 threads + blockdim = (threads_x, threads_y) + blocks_x = (N1 + threads_x - 1) // threads_x + blocks_y = (N3 + threads_y - 1) // threads_y + griddim = (blocks_x, blocks_y) + # Launch kernel + apply_csr_mid_kernel(griddim, blockdim, (matrix.data, matrix.indices, matrix.indptr, array, out, N1, N2i, N2o, N3)) + + +def custom_spsm(a, b, alpha=1.0, lower=True, unit_diag=False, transa=False, spsm_descr=None): + """Custom spsm wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves a sparse triangular linear system op(a) * x = alpha * op(b). + + Args: + a (cupyx.scipy.sparse.csr_matrix or cupyx.scipy.sparse.coo_matrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): Dense matrix with dimension ``(M, K)``. + alpha (float or complex): Coefficient. + lower (bool): + True: ``a`` is lower triangle matrix. + False: ``a`` is upper triangle matrix. + unit_diag (bool): + True: diagonal part of ``a`` has unit elements. + False: diagonal part of ``a`` has non-unit elements. + transa (bool or str): True, False, 'N', 'T' or 'H'. + 'N' or False: op(a) == ``a``. + 'T' or True: op(a) == ``a.T``. + 'H': op(a) == ``a.conj().T``. + """ + import cupyx + from cupyx import cusparse + import cupy as _cupy + import numpy as _numpy + from cupy._core import _dtype + from cupy_backends.cuda.libs import cusparse as _cusparse + from cupy.cuda import device as _device + from cupyx.cusparse import SpMatDescriptor, DnMatDescriptor + if not cusparse.check_availability('spsm'): + raise RuntimeError('spsm is not available.') + + # Canonicalise transa + if transa is False: + transa = 'N' + elif transa is True: + transa = 'T' + elif transa not in 'NTH': + raise ValueError(f'Unknown transa (actual: {transa})') + + # Check A's type and sparse format + if cupyx.scipy.sparse.isspmatrix_csr(a): + pass + elif cupyx.scipy.sparse.isspmatrix_csc(a): + if transa == 'N': + a = a.T + transa = 'T' + elif transa == 'T': + a = a.T + transa = 'N' + elif transa == 'H': + a = a.conj().T + transa = 'N' + lower = not lower + elif cupyx.scipy.sparse.isspmatrix_coo(a): + pass + else: + raise ValueError('a must be CSR, CSC or COO sparse matrix') + assert a.has_canonical_format + + # Check B's ndim + if b.ndim == 1: + is_b_vector = True + b = b.reshape(-1, 1) + elif b.ndim == 2: + is_b_vector = False + else: + raise ValueError('b.ndim must be 1 or 2') + + # Check shapes + if not (a.shape[0] == a.shape[1] == b.shape[0]): + raise ValueError('mismatched shape') + + # Check dtypes + dtype = a.dtype + if dtype.char not in 'fdFD': + raise TypeError('Invalid dtype (actual: {})'.format(dtype)) + if dtype != b.dtype: + raise TypeError('dtype mismatch') + + # Prepare fill mode + if lower is True: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_LOWER + elif lower is False: + fill_mode = _cusparse.CUSPARSE_FILL_MODE_UPPER + else: + raise ValueError('Unknown lower (actual: {})'.format(lower)) + + # Prepare diag type + if unit_diag is False: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_NON_UNIT + elif unit_diag is True: + diag_type = _cusparse.CUSPARSE_DIAG_TYPE_UNIT + else: + raise ValueError('Unknown unit_diag (actual: {})'.format(unit_diag)) + + # Prepare op_a + if transa == 'N': + op_a = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif transa == 'T': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: # transa == 'H' + if dtype.char in 'fd': + op_a = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + op_a = _cusparse.CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE + + # Prepare op_b + if b._f_contiguous: + op_b = _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE + elif b._c_contiguous: + if _cusparse.get_build_version() < 11701: # earlier than CUDA 11.6 + raise ValueError('b must be F-contiguous.') + b = b.T + op_b = _cusparse.CUSPARSE_OPERATION_TRANSPOSE + else: + raise ValueError('b must be F-contiguous or C-contiguous.') + + # Allocate space for matrix C. Note that it is known cusparseSpSM requires + # the output matrix zero initialized. + m, _ = a.shape + if op_b == _cusparse.CUSPARSE_OPERATION_NON_TRANSPOSE: + _, n = b.shape + else: + n, _ = b.shape + c_shape = m, n + c = _cupy.zeros(c_shape, dtype=a.dtype, order='f') + + # Prepare descriptors and other parameters + handle = _device.get_cusparse_handle() + mat_a = SpMatDescriptor.create(a) + mat_b = DnMatDescriptor.create(b) + mat_c = DnMatDescriptor.create(c) + if spsm_descr is None: + spsm_descr = _cusparse.spSM_createDescr() + new_spsm_descr = True + else: + spsm_descr, buff = spsm_descr + new_spsm_descr = False + alpha = _numpy.array(alpha, dtype=c.dtype).ctypes + cuda_dtype = _dtype.to_cuda_dtype(c.dtype) + algo = _cusparse.CUSPARSE_SPSM_ALG_DEFAULT + + try: + # Specify Lower|Upper fill mode + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_FILL_MODE, fill_mode) + + # Specify Unit|Non-Unit diagonal type + mat_a.set_attribute(_cusparse.CUSPARSE_SPMAT_DIAG_TYPE, diag_type) + + # Allocate the workspace needed by the succeeding phases + if new_spsm_descr: + buff_size = _cusparse.spSM_bufferSize( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr) + buff = _cupy.empty(buff_size, dtype=_cupy.int8) + + # Perform the analysis phase + if new_spsm_descr: + _cusparse.spSM_analysis( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Executes the solve phase + _cusparse.spSM_solve( + handle, op_a, op_b, alpha.data, mat_a.desc, mat_b.desc, + mat_c.desc, cuda_dtype, algo, spsm_descr, buff.data.ptr) + + # Reshape back if B was a vector + if is_b_vector: + c = c.reshape(-1) + + return c, (spsm_descr, buff) + + finally: + # Destroy matrix/vector descriptors + #_cusparse.spSM_destroyDescr(spsm_descr) + pass + + +def custom_SuperLU_solve(self, rhs, trans='N', spsm_descr=None): + """Custom SuperLU solve wrapper to save spsm_descr, since spsm_analysis takes lots of time.""" + """Solves linear system of equations with one or several right-hand sides. + + Args: + rhs (cupy.ndarray): Right-hand side(s) of equation with dimension + ``(M)`` or ``(M, K)``. + trans (str): 'N', 'T' or 'H'. + 'N': Solves ``A * x = rhs``. + 'T': Solves ``A.T * x = rhs``. + 'H': Solves ``A.conj().T * x = rhs``. + + Returns: + cupy.ndarray: + Solution vector(s) + """ # NOQA + from cupyx import cusparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + if not isinstance(rhs, cupy.ndarray): + raise TypeError('ojb must be cupy.ndarray') + if rhs.ndim not in (1, 2): + raise ValueError('rhs.ndim must be 1 or 2 (actual: {})'. + format(rhs.ndim)) + if rhs.shape[0] != self.shape[0]: + raise ValueError('shape mismatch (self.shape: {}, rhs.shape: {})' + .format(self.shape, rhs.shape)) + if trans not in ('N', 'T', 'H'): + raise ValueError('trans must be \'N\', \'T\', or \'H\'') + + if cusparse.check_availability('spsm') and _should_use_spsm(rhs): + def spsm(A, B, lower, transa, spsm_descr): + return custom_spsm(A, B, lower=lower, transa=transa, spsm_descr=spsm_descr) + sm = spsm + else: + raise NotImplementedError + + x = rhs.astype(self.L.dtype) + if trans == 'N': + if self.perm_r is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_r_rev].T # want to keep f-order + else: + x = x[self._perm_r_rev] + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + if self.perm_c is not None: + x = x[self.perm_c] + else: + if self.perm_c is not None: + if x.ndim == 2 and x._f_contiguous: + x = x.T[:, self._perm_c_rev].T # want to keep f-order + else: + x = x[self._perm_c_rev] + x, self.spsm_U_descr = sm(self.U, x, lower=False, transa=trans, spsm_descr=self.spsm_U_descr) + x, self.spsm_L_descr = sm(self.L, x, lower=True, transa=trans, spsm_descr=self.spsm_L_descr) + if self.perm_r is not None: + x = x[self.perm_r] + + if not x._f_contiguous: + # For compatibility with SciPy + x = x.copy(order='F') + return x + + +class CustomCupyUpperTriangularSolver: + """Hacky class to save spsm_descr for reuse in spsm for triangular solves.""" + + def __init__(self, matrix): + # Check matrix format + if not isinstance(matrix, csp.csr_matrix): + # TODO: avoid this explicit conversion + matrix = csp.csr_matrix(matrix) + print('WARNING: converting matrix to CSR format') + #raise ValueError("Matrix must be in CSR format.") + self.matrix = matrix + self.spsm_descr = None + + def solve(self, b, lower=True, overwrite_A=False, overwrite_b=False, + unit_diagonal=False): + """Solves a sparse triangular system ``A x = b``. + + Args: + A (cupyx.scipy.sparse.spmatrix): + Sparse matrix with dimension ``(M, M)``. + b (cupy.ndarray): + Dense vector or matrix with dimension ``(M)`` or ``(M, K)``. + lower (bool): + Whether ``A`` is a lower or upper triangular matrix. + If True, it is lower triangular, otherwise, upper triangular. + overwrite_A (bool): + (not supported) + overwrite_b (bool): + Allows overwriting data in ``b``. + unit_diagonal (bool): + If True, diagonal elements of ``A`` are assumed to be 1 and will + not be referenced. + + Returns: + cupy.ndarray: + Solution to the system ``A x = b``. The shape is the same as ``b``. + """ + from cupyx import cusparse + from cupyx.scipy import sparse + import cupy + from cupyx.scipy.sparse.linalg._solve import _should_use_spsm + + A = self.matrix + + if not (cusparse.check_availability('spsm') or + cusparse.check_availability('csrsm2')): + raise NotImplementedError + + if not sparse.isspmatrix(A): + raise TypeError('A must be cupyx.scipy.sparse.spmatrix') + if not isinstance(b, cupy.ndarray): + raise TypeError('b must be cupy.ndarray') + if A.shape[0] != A.shape[1]: + raise ValueError(f'A must be a square matrix (A.shape: {A.shape})') + if b.ndim not in [1, 2]: + raise ValueError(f'b must be 1D or 2D array (b.shape: {b.shape})') + if A.shape[0] != b.shape[0]: + raise ValueError('The size of dimensions of A must be equal to the ' + 'size of the first dimension of b ' + f'(A.shape: {A.shape}, b.shape: {b.shape})') + if A.dtype.char not in 'fdFD': + raise TypeError(f'unsupported dtype (actual: {A.dtype})') + + if cusparse.check_availability('spsm') and _should_use_spsm(b): + if not (sparse.isspmatrix_csr(A) or + sparse.isspmatrix_csc(A) or + sparse.isspmatrix_coo(A)): + warnings.warn('CSR, CSC or COO format is required. Converting to ' + 'CSR format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + x, self.spsm_descr = custom_spsm(A, b, lower=lower, unit_diag=unit_diagonal, spsm_descr=self.spsm_descr) + elif cusparse.check_availability('csrsm2'): + if not (sparse.isspmatrix_csr(A) or sparse.isspmatrix_csc(A)): + warnings.warn('CSR or CSC format is required. Converting to CSR ' + 'format.', sparse.SparseEfficiencyWarning) + A = A.tocsr() + A.sum_duplicates() + + if (overwrite_b and A.dtype == b.dtype and + (b._c_contiguous or b._f_contiguous)): + x = b + else: + x = b.astype(A.dtype, copy=True) + + cusparse.csrsm2(A, x, lower=lower, unit_diag=unit_diagonal) + else: + assert False + + if x.dtype.char in 'fF': + # Note: This is for compatibility with SciPy. + dtype = numpy.promote_types(x.dtype, 'float64') + x = x.astype(dtype) + return x + + +def cupy_solve_upper_csr(matrix, array, axis, out): + """Solve upper triangular CSR matrix along specified axis of an array.""" + # Switch by dimension + ndim = array.ndim + if ndim == 1: + if axis == 0: + cupy_solve_upper_csr_vec(matrix, array, out) + else: + raise ValueError("axis must be 0 for 1D arrays") + elif ndim == 2: + if axis == 0: + if array.shape[1] == 1: + cupy_solve_upper_csr_vec(matrix, array[:,0], out[:,0]) + else: + cupy_solve_upper_csr_first(matrix, array, out) + elif axis == 1: + if array.shape[0] == 1: + cupy_solve_upper_csr_vec(matrix, array[0,:], out[0,:]) + else: + cupy_solve_upper_csr_last(matrix, array, out) + else: + raise ValueError("axis must be 0 or 1 for 2D arrays") + else: + # Treat as 3D array with specified axis in the middle + # Compute equivalent shape (N1, N2, N3) + if ndim == 3 and axis == 1: + N1 = shape[0] + N2 = shape[1] + N3 = shape[2] + else: + N1 = int(np.prod(array.shape[:axis])) + N2 = array.shape[axis] + N3 = int(np.prod(array.shape[axis+1:])) + # Dispatch to cupy routines + if N1 == 1: + if N3 == 1: + # (1, N2, 1) -> (N2,) + x1 = array.reshape((N2,)) + y1 = out.reshape((N2,)) + cupy_solve_upper_csr_vec(matrix, x1, y1) + else: + # (1, N2, N3) -> (N2, N3) + x2 = array.reshape((N2, N3)) + y2 = out.reshape((N2, N3)) + cupy_solve_upper_csr_first(matrix, x2, y2) + else: + if N3 == 1: + # (N1, N2, 1) -> (N1, N2) + x2 = array.reshape((N1, N2)) + y2 = out.reshape((N1, N2)) + cupy_solve_upper_csr_last(matrix, x2, y2) + else: + # (N1, N2, N3) + x3 = array.reshape((N1, N2, N3)) + y3 = out.reshape((N1, N2, N3)) + cupy_solve_upper_csr_mid(matrix, x3, y3) + + +def cupy_solve_upper_csr_vec(matrix, vec, out): + """Solve upper triangular CSR matrix along a vector.""" + out[:] = matrix.solve(vec, lower=False) + + +def cupy_solve_upper_csr_first(matrix, array, out): + """Solve upper triangular CSR matrix along first axis of 2D array.""" + out[:] = matrix.solve(array, lower=False) + + +def cupy_solve_upper_csr_last(matrix, array, out): + """Solve upper triangular CSR matrix along last axis of 2D array.""" + out.T[:] = matrix.solve(array.T, lower=False) + + +def cupy_solve_upper_csr_mid(matrix, array, out): + """Solve upper triangular CSR matrix along middle axis of 3D array.""" + raise NotImplementedError diff --git a/setup.py b/setup.py index 1cf1d9dd..583009a6 100644 --- a/setup.py +++ b/setup.py @@ -181,6 +181,7 @@ def read(rel_path): # Runtime requirements install_requires = [ + "array-api-compat", "docopt", "h5py >= 3.0.0", "matplotlib",