diff --git a/conftest.py b/conftest.py index c0339f722b..0ed5b9ad78 100644 --- a/conftest.py +++ b/conftest.py @@ -14,7 +14,7 @@ from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock, retrieve_iteration_tree) from devito.tools import as_tuple -from devito.petsc.utils import PetscOSError, get_petsc_dir +from devito.petsc.config import PetscOSError, get_petsc_dir try: from mpi4py import MPI # noqa diff --git a/devito/core/cpu.py b/devito/core/cpu.py index a93d3a3fef..eaa9266dc8 100644 --- a/devito/core/cpu.py +++ b/devito/core/cpu.py @@ -14,6 +14,8 @@ check_stability, PetscTarget) from devito.tools import timed_pass +from devito.petsc.iet.passes import lower_petsc_symbols + __all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator', 'Cpu64AdvOmpOperator', 'Cpu64FsgCOperator', 'Cpu64FsgOmpOperator', 'Cpu64CustomOperator', 'Cpu64CustomCXXOperator', 'Cpu64AdvCXXOperator', @@ -143,6 +145,9 @@ def _specialize_iet(cls, graph, **kwargs): # Symbol definitions cls._Target.DataManager(**kwargs).process(graph) + # Lower PETSc symbols + lower_petsc_symbols(graph, **kwargs) + return graph @@ -222,6 +227,9 @@ def _specialize_iet(cls, graph, **kwargs): # Symbol definitions cls._Target.DataManager(**kwargs).process(graph) + # Lower PETSc symbols + lower_petsc_symbols(graph, **kwargs) + # Linearize n-dimensional Indexeds linearize(graph, **kwargs) diff --git a/devito/ir/iet/algorithms.py b/devito/ir/iet/algorithms.py index 52f48e28b1..01e3b4976e 100644 --- a/devito/ir/iet/algorithms.py +++ b/devito/ir/iet/algorithms.py @@ -4,7 +4,7 @@ Section, HaloSpot, ExpressionBundle) from devito.tools import timed_pass from devito.petsc.types import MetaData -from devito.petsc.iet.utils import petsc_iet_mapper +from devito.petsc.iet.nodes import petsc_iet_mapper __all__ = ['iet_build'] diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 1179cd8936..a72772b854 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1068,7 +1068,7 @@ class FindSymbols(LazyVisitor[Any, list[Any], None]): Drive the search. Accepted: - `symbolics`: Collect all AbstractFunction objects, default - `basics`: Collect all Basic objects - - `abstractsymbols`: Collect all AbstractSymbol objects + - `symbols`: Collect all AbstractSymbol objects - `dimensions`: Collect all Dimensions - `indexeds`: Collect all Indexed objects - `indexedbases`: Collect all IndexedBase objects diff --git a/devito/passes/iet/languages/C.py b/devito/passes/iet/languages/C.py index 7822bf1680..6f2850c3a9 100644 --- a/devito/passes/iet/languages/C.py +++ b/devito/passes/iet/languages/C.py @@ -7,7 +7,8 @@ from devito.passes.iet.langbase import LangBB from devito.symbolics import c_complex, c_double_complex from devito.tools import dtype_to_cstr -from devito.petsc.utils import petsc_type_mappings + +from devito.petsc.config import petsc_type_mappings __all__ = ['CBB', 'CDataManager', 'COrchestrator'] @@ -82,3 +83,6 @@ class PetscCPrinter(CPrinter): _restrict_keyword = '' type_mappings = {**CPrinter.type_mappings, **petsc_type_mappings} + + def _print_Pi(self, expr): + return 'PETSC_PI' diff --git a/devito/petsc/config.py b/devito/petsc/config.py new file mode 100644 index 0000000000..e2fe3ed443 --- /dev/null +++ b/devito/petsc/config.py @@ -0,0 +1,86 @@ +import os +import ctypes +from pathlib import Path + +from petsctools import get_petscvariables, MissingPetscException + +from devito.tools import memoized_func + + +class PetscOSError(OSError): + pass + + +@memoized_func +def get_petsc_dir(): + petsc_dir = os.environ.get('PETSC_DIR') + if petsc_dir is None: + raise PetscOSError("PETSC_DIR environment variable not set") + else: + petsc_dir = (Path(petsc_dir),) + + petsc_arch = os.environ.get('PETSC_ARCH') + if petsc_arch is not None: + petsc_dir += (petsc_dir[0] / petsc_arch,) + + petsc_installed = petsc_dir[-1] / 'include' / 'petscconf.h' + if not petsc_installed.is_file(): + raise PetscOSError("PETSc is not installed") + + return petsc_dir + + +@memoized_func +def core_metadata(): + petsc_dir = get_petsc_dir() + + petsc_include = tuple([arch / 'include' for arch in petsc_dir]) + petsc_lib = tuple([arch / 'lib' for arch in petsc_dir]) + + return { + 'includes': ('petscsnes.h', 'petscdmda.h'), + 'include_dirs': petsc_include, + 'libs': ('petsc'), + 'lib_dirs': petsc_lib, + 'ldflags': tuple([f"-Wl,-rpath,{lib}" for lib in petsc_lib]) + } + + +try: + petsc_variables = get_petscvariables() +except MissingPetscException: + petsc_variables = {} + + +def get_petsc_type_mappings(): + try: + petsc_precision = petsc_variables['PETSC_PRECISION'] + except KeyError: + printer_mapper = {} + petsc_type_to_ctype = {} + else: + petsc_scalar = 'PetscScalar' + # TODO: Check to see whether Petsc is compiled with + # 32-bit or 64-bit integers + printer_mapper = {ctypes.c_int: 'PetscInt'} + + if petsc_precision == 'single': + printer_mapper[ctypes.c_float] = petsc_scalar + elif petsc_precision == 'double': + printer_mapper[ctypes.c_double] = petsc_scalar + + # Used to construct ctypes.Structures that wrap PETSc objects + petsc_type_to_ctype = {v: k for k, v in printer_mapper.items()} + # Add other PETSc types + petsc_type_to_ctype.update({ + 'KSPType': ctypes.c_char_p, + 'KSPConvergedReason': petsc_type_to_ctype['PetscInt'], + 'KSPNormType': petsc_type_to_ctype['PetscInt'], + }) + return printer_mapper, petsc_type_to_ctype + + +petsc_type_mappings, petsc_type_to_ctype = get_petsc_type_mappings() + + +petsc_languages = ['petsc'] diff --git a/devito/petsc/iet/builder.py b/devito/petsc/iet/builder.py new file mode 100644 index 0000000000..e1c178c059 --- /dev/null +++ b/devito/petsc/iet/builder.py @@ -0,0 +1,341 @@ +import math + +from devito.ir.iet import DummyExpr, BlankLine +from devito.symbolics import (Byref, FieldFromPointer, VOID, + FieldFromComposite, Null) + +from devito.petsc.iet.nodes import ( + FormFunctionCallback, MatShellSetOp, PETScCall, petsc_call +) + + +def make_core_petsc_calls(objs, comm): + call_mpi = petsc_call_mpi('MPI_Comm_size', [comm, Byref(objs['size'])]) + return call_mpi, BlankLine + + +class BuilderBase: + def __init__(self, **kwargs): + self.inject_solve = kwargs.get('inject_solve') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.callback_builder = kwargs.get('callback_builder') + self.field_data = self.inject_solve.expr.rhs.field_data + self.formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix + self.calls = self._setup() + + @property + def snes_ctx(self): + """ + The [optional] context for private data for the function evaluation routine. + https://petsc.org/main/manualpages/SNES/SNESSetFunction/ + """ + return VOID(self.solver_objs['dmda'], stars='*') + + def _setup(self): + sobjs = self.solver_objs + dmda = sobjs['dmda'] + + snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])]) + + snes_options_prefix = petsc_call( + 'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']] + ) if self.formatted_prefix else None + + set_options = petsc_call( + self.callback_builder._set_options_efunc.name, [] + ) + + snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda]) + + create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])]) + + snes_set_jac = petsc_call( + 'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'], + sobjs['Jac'], 'MatMFFDComputeJacobian', Null] + ) + + global_x = petsc_call('DMCreateGlobalVector', + [dmda, Byref(sobjs['xglobal'])]) + + target = self.field_data.target + field_from_ptr = FieldFromPointer( + target.function._C_field_data, target.function._C_symbol + ) + + local_size = math.prod( + v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space + ) + # TODO: Check - VecCreateSeqWithArray + local_x = petsc_call('VecCreateMPIWithArray', + [sobjs['comm'], 1, local_size, 'PETSC_DECIDE', + field_from_ptr, Byref(sobjs['xlocal'])]) + + # TODO: potentially also need to set the DM and local/global map to xlocal + + get_local_size = petsc_call('VecGetSize', + [sobjs['xlocal'], Byref(sobjs['localsize'])]) + + global_b = petsc_call('DMCreateGlobalVector', + [dmda, Byref(sobjs['bglobal'])]) + + snes_get_ksp = petsc_call('SNESGetKSP', + [sobjs['snes'], Byref(sobjs['ksp'])]) + + matvec = self.callback_builder.main_matvec_callback + matvec_operation = petsc_call( + 'MatShellSetOperation', + [sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)] + ) + formfunc = self.callback_builder._F_efunc + formfunc_operation = petsc_call( + 'SNESSetFunction', + [sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void), + self.snes_ctx] + ) + + snes_set_options = petsc_call( + 'SNESSetFromOptions', [sobjs['snes']] + ) + + dmda_calls = self._create_dmda_calls(dmda) + + mainctx = sobjs['userctx'] + + call_struct_callback = petsc_call( + self.callback_builder.user_struct_callback.name, [Byref(mainctx)] + ) + + # TODO: maybe don't need to explictly set this + mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda]) + + calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) + + base_setup = dmda_calls + ( + snes_create, + snes_options_prefix, + set_options, + snes_set_dm, + create_matrix, + snes_set_jac, + global_x, + local_x, + get_local_size, + global_b, + snes_get_ksp, + matvec_operation, + formfunc_operation, + snes_set_options, + call_struct_callback, + mat_set_dm, + calls_set_app_ctx, + BlankLine + ) + extended_setup = self._extend_setup() + return base_setup + extended_setup + + def _extend_setup(self): + """ + Hook for subclasses to add additional setup calls. + """ + return () + + def _create_dmda_calls(self, dmda): + dmda_create = self._create_dmda(dmda) + dm_setup = petsc_call('DMSetUp', [dmda]) + dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) + return dmda_create, dm_setup, dm_mat_type + + def _create_dmda(self, dmda): + sobjs = self.solver_objs + grid = self.field_data.grid + nspace_dims = len(grid.dimensions) + + # MPI communicator + args = [sobjs['comm']] + + # Type of ghost nodes + args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)]) + + # Stencil type + if nspace_dims > 1: + args.append('DMDA_STENCIL_BOX') + + # Global dimensions + args.extend(list(grid.shape)[::-1]) + # No.of processors in each dimension + if nspace_dims > 1: + args.extend(list(grid.distributor.topology)[::-1]) + + # Number of degrees of freedom per node + args.append(dmda.dofs) + # "Stencil width" -> size of overlap + # TODO: Instead, this probably should be + # extracted from field_data.target._size_outhalo? + stencil_width = self.field_data.space_order + + args.append(stencil_width) + args.extend([Null]*nspace_dims) + + # The distributed array object + args.append(Byref(dmda)) + + # The PETSc call used to create the DMDA + dmda = petsc_call(f'DMDACreate{nspace_dims}d', args) + + return dmda + + +class CoupledBuilder(BuilderBase): + def _setup(self): + # TODO: minimise code duplication with superclass + objs = self.objs + sobjs = self.solver_objs + dmda = sobjs['dmda'] + + snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])]) + + snes_options_prefix = petsc_call( + 'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']] + ) if self.formatted_prefix else None + + set_options = petsc_call( + self.callback_builder._set_options_efunc.name, [] + ) + + snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda]) + + create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])]) + + snes_set_jac = petsc_call( + 'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'], + sobjs['Jac'], 'MatMFFDComputeJacobian', Null] + ) + + global_x = petsc_call('DMCreateGlobalVector', + [dmda, Byref(sobjs['xglobal'])]) + + local_x = petsc_call('DMCreateLocalVector', [dmda, Byref(sobjs['xlocal'])]) + + get_local_size = petsc_call('VecGetSize', + [sobjs['xlocal'], Byref(sobjs['localsize'])]) + + snes_get_ksp = petsc_call('SNESGetKSP', + [sobjs['snes'], Byref(sobjs['ksp'])]) + + matvec = self.callback_builder.main_matvec_callback + matvec_operation = petsc_call( + 'MatShellSetOperation', + [sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)] + ) + formfunc = self.callback_builder._F_efunc + formfunc_operation = petsc_call( + 'SNESSetFunction', + [sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void), + self.snes_ctx] + ) + + snes_set_options = petsc_call( + 'SNESSetFromOptions', [sobjs['snes']] + ) + + dmda_calls = self._create_dmda_calls(dmda) + + mainctx = sobjs['userctx'] + + call_struct_callback = petsc_call( + self.callback_builder.user_struct_callback.name, [Byref(mainctx)] + ) + + # TODO: maybe don't need to explictly set this + mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda]) + + calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) + + create_field_decomp = petsc_call( + 'DMCreateFieldDecomposition', + [dmda, Byref(sobjs['nfields']), Null, Byref(sobjs['fields']), + Byref(sobjs['subdms'])] + ) + submat_cb = self.callback_builder.submatrices_callback + matop_create_submats_op = petsc_call( + 'MatShellSetOperation', + [sobjs['Jac'], 'MATOP_CREATE_SUBMATRICES', + MatShellSetOp(submat_cb.name, void, void)] + ) + + call_coupled_struct_callback = petsc_call( + 'PopulateMatContext', + [Byref(sobjs['jacctx']), sobjs['subdms'], sobjs['fields']] + ) + + shell_set_ctx = petsc_call( + 'MatShellSetContext', [sobjs['Jac'], Byref(sobjs['jacctx']._C_symbol)] + ) + + create_submats = petsc_call( + 'MatCreateSubMatrices', + [sobjs['Jac'], sobjs['nfields'], sobjs['fields'], + sobjs['fields'], 'MAT_INITIAL_MATRIX', + Byref(FieldFromComposite(objs['Submats'].base, sobjs['jacctx']))] + ) + + targets = self.field_data.targets + + deref_dms = [ + DummyExpr(sobjs[f'da{t.name}'], sobjs['subdms'].indexed[i]) + for i, t in enumerate(targets) + ] + + xglobals = [petsc_call( + 'DMCreateGlobalVector', + [sobjs[f'da{t.name}'], Byref(sobjs[f'xglobal{t.name}'])] + ) for t in targets] + + xlocals = [] + for t in targets: + target_xloc = sobjs[f'xlocal{t.name}'] + local_size = math.prod( + v for v, dim in zip(t.shape_allocated, t.dimensions) if dim.is_Space + ) + field_from_ptr = FieldFromPointer( + t.function._C_field_data, t.function._C_symbol + ) + # TODO: Check - VecCreateSeqWithArray? + xlocals.append(petsc_call( + 'VecCreateMPIWithArray', + [sobjs['comm'], 1, local_size, 'PETSC_DECIDE', + field_from_ptr, Byref(target_xloc)] + )) + + coupled_setup = dmda_calls + ( + snes_create, + snes_options_prefix, + set_options, + snes_set_dm, + create_matrix, + snes_set_jac, + global_x, + local_x, + get_local_size, + snes_get_ksp, + matvec_operation, + formfunc_operation, + snes_set_options, + call_struct_callback, + mat_set_dm, + calls_set_app_ctx, + create_field_decomp, + matop_create_submats_op, + call_coupled_struct_callback, + shell_set_ctx, + create_submats) + \ + tuple(deref_dms) + tuple(xglobals) + tuple(xlocals) + (BlankLine,) + return coupled_setup + + +def petsc_call_mpi(specific_call, call_args): + return PETScCall('PetscCallMPI', [PETScCall(specific_call, arguments=call_args)]) + + +void = VOID._dtype diff --git a/devito/petsc/iet/callbacks.py b/devito/petsc/iet/callbacks.py new file mode 100644 index 0000000000..4d06063242 --- /dev/null +++ b/devito/petsc/iet/callbacks.py @@ -0,0 +1,1132 @@ +from collections import OrderedDict + +from devito.ir.iet import ( + Call, FindSymbols, List, Uxreplace, CallableBody, Dereference, DummyExpr, + BlankLine, Callable, Iteration, PointerCast, Definition +) +from devito.symbolics import ( + Byref, FieldFromPointer, IntDiv, Deref, Mod, String, Null, VOID +) +from devito.symbolics.unevaluation import Mul +from devito.types.basic import AbstractFunction +from devito.types import Dimension, Temp, TempArray +from devito.tools import filter_ordered + +from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call +from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct +from devito.petsc.iet.type_builder import objs +from devito.petsc.types.macros import petsc_func_begin_user +from devito.petsc.types.modes import InsertMode + + +class BaseCallbackBuilder: + """ + Build IET routines to generate PETSc callback functions. + """ + def __init__(self, **kwargs): + + self.rcompile = kwargs.get('rcompile', None) + self.sregistry = kwargs.get('sregistry', None) + self.concretize_mapper = kwargs.get('concretize_mapper', {}) + self.time_dependence = kwargs.get('time_dependence') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.inject_solve = kwargs.get('inject_solve') + self.solve_expr = self.inject_solve.expr.rhs + + self._efuncs = OrderedDict() + self._struct_params = [] + + self._set_options_efunc = None + self._clear_options_efunc = None + self._main_matvec_callback = None + self._user_struct_callback = None + self._F_efunc = None + self._b_efunc = None + + self._J_efuncs = [] + self._initial_guesses = [] + + self._make_core() + self._efuncs = self._uxreplace_efuncs() + + @property + def efuncs(self): + return self._efuncs + + @property + def struct_params(self): + return self._struct_params + + @property + def filtered_struct_params(self): + return filter_ordered(self.struct_params) + + @property + def main_matvec_callback(self): + """ + The matrix-vector callback for the full Jacobian. + This is the function set in the main Kernel via: + PetscCall(MatShellSetOperation(J, MATOP_MULT, (void (*)(void))...)); + The callback has the signature `(Mat, Vec, Vec)`. + """ + return self._J_efuncs[0] + + @property + def J_efuncs(self): + """ + List of matrix-vector callbacks. + Each callback has the signature `(Mat, Vec, Vec)`. Typically, this list + contains a single element, but in mixed systems it can include multiple + callbacks, one for each subblock. + """ + return self._J_efuncs + + @property + def initial_guesses(self): + return self._initial_guesses + + @property + def user_struct_callback(self): + return self._user_struct_callback + + @property + def solver_parameters(self): + return self.solve_expr.solver_parameters + + @property + def field_data(self): + return self.solve_expr.field_data + + @property + def formatted_prefix(self): + return self.solve_expr.formatted_prefix + + @property + def arrays(self): + return self.field_data.arrays + + @property + def target(self): + return self.field_data.target + + def _make_core(self): + self._make_options_callback() + self._make_matvec(self.field_data.jacobian) + self._make_formfunc() + self._make_formrhs() + if self.field_data.initial_guess.exprs: + self._make_initial_guess() + self._make_user_struct_callback() + + def _make_petsc_callable(self, prefix, body, parameters=()): + return PETScCallable( + self.sregistry.make_name(prefix=prefix), + body, + retval=self.objs['err'], + parameters=parameters + ) + + def _make_callable_body(self, body, standalones=(), stacks=(), casts=()): + return CallableBody( + List(body=body), + init=(petsc_func_begin_user,), + standalones=standalones, + stacks=stacks, + casts=casts, + retstmt=(Call('PetscFunctionReturn', arguments=[0]),) + ) + + def _make_options_callback(self): + """ + Create two callbacks: one to set PETSc options and one + to clear them. + Options are only set/cleared if they were not specifed via + command line arguments. + """ + params = self.solver_parameters + prefix = self.inject_solve.expr.rhs.formatted_prefix + + set_body, clear_body = [], [] + + for k, v in params.items(): + option = f'-{prefix}{k}' + + # TODO: Revisit use of a global variable here. + # Consider replacing this with a call to `PetscGetArgs`, though + # initial attempts failed, possibly because the argv pointer is + # created in Python?.. + import devito.petsc.initialize + if option in devito.petsc.initialize._petsc_clargs: + # Ensures that the command line args take priority + continue + + option_name = String(option) + # For options without a value e.g `ksp_view`, pass Null + option_value = Null if v is None else String(str(v)) + set_body.append( + petsc_call('PetscOptionsSetValue', [Null, option_name, option_value]) + ) + clear_body.append( + petsc_call('PetscOptionsClearValue', [Null, option_name]) + ) + + set_body = self._make_callable_body(set_body) + clear_body = self._make_callable_body(clear_body) + + set_callback = self._make_petsc_callable('SetPetscOptions', set_body) + clear_callback = self._make_petsc_callable('ClearPetscOptions', clear_body) + + self._set_options_efunc = set_callback + self._efuncs[set_callback.name] = set_callback + self._clear_options_efunc = clear_callback + self._efuncs[clear_callback.name] = clear_callback + + def _make_matvec(self, jacobian, prefix='MatMult'): + # Compile `matvecs` into an IET via recursive compilation + matvecs = jacobian.matvecs + irs, _ = self.rcompile( + matvecs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body = self._create_matvec_body( + List(body=irs.uiet.body), jacobian + ) + objs = self.objs + cb = self._make_petsc_callable( + prefix, body, parameters=(objs['J'], objs['X'], objs['Y']) + ) + self._J_efuncs.append(cb) + self._efuncs[cb.name] = cb + + def _create_matvec_body(self, body, jacobian): + linsolve_expr = self.inject_solve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + xlocal = objs['xloc'] + ylocal = objs['yloc'] + y_matvec = self.arrays[jacobian.row_target]['y'] + x_matvec = self.arrays[jacobian.col_target]['x'] + + body = self.time_dependence.uxreplace_time(body) + + fields = get_user_struct_fields(body) + + mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(dmda)]) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + zero_y_memory = zero_vector(objs['Y']) if jacobian.zero_memory else None + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(xlocal)] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['X'], insert_values, xlocal] + ) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, objs['X'], insert_values, xlocal + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(ylocal)] + ) + + zero_ylocal_memory = zero_vector(ylocal) + + vec_get_array_y = petsc_call( + 'VecGetArray', [ylocal, Byref(y_matvec._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [xlocal, Byref(x_matvec._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [ylocal, Byref(y_matvec._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [xlocal, Byref(x_matvec._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, ylocal, add_values, objs['Y'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, ylocal, add_values, objs['Y'] + ]) + + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(xlocal)] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(ylocal)] + ) + + # TODO: Some of the calls are placed in the `stacks` argument of the + # `CallableBody` to ensure that they precede the `cast` statements. The + # 'casts' depend on the calls, so this order is necessary. By doing this, + # you avoid having to manually construct the `casts` and can allow + # Devito to handle their construction. This is a temporary solution and + # should be revisited + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) + ) + + stacks = ( + zero_y_memory, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + zero_ylocal_memory, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + derefs = dereference_funcs(ctx, fields) + + # Force the struct definition to appear at the very start, since + # stacks, allocs etc may rely on its information + struct_definition = [ + Definition(ctx), Definition(dmda), mat_get_dm, dm_get_app_context + ] + + body = self._make_callable_body( + body, standalones=struct_definition, stacks=stacks+derefs + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} + body = Uxreplace(subs).visit(body) + + self._struct_params.extend(fields) + return body + + def _make_formfunc(self): + objs = self.objs + F_exprs = self.field_data.residual.F_exprs + # Compile `F_exprs` into an IET via recursive compilation + irs, _ = self.rcompile( + F_exprs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body_formfunc = self._create_formfunc_body( + List(body=irs.uiet.body) + ) + parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr']) + cb = self._make_petsc_callable('FormFunction', body_formfunc, parameters) + + self._F_efunc = cb + self._efuncs[cb.name] = cb + + def _create_formfunc_body(self, body): + linsolve_expr = self.inject_solve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + arrays = self.arrays + target = self.target + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + body = self.time_dependence.uxreplace_time(body) + + fields = get_user_struct_fields(body) + self._struct_params.extend(fields) + + f_formfunc = arrays[target]['f'] + x_formfunc = arrays[target]['x'] + + dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + zero_f_memory = zero_vector(objs['F']) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['xloc'])] + ) + + global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['X'], insert_values, objs['xloc']] + ) + + global_to_local_end = petsc_call( + 'DMGlobalToLocalEnd', [dmda, objs['X'], insert_values, objs['xloc']] + ) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['floc'])] + ) + + vec_get_array_y = petsc_call( + 'VecGetArray', [objs['floc'], Byref(f_formfunc._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + vec_restore_array_y = petsc_call( + 'VecRestoreArray', [objs['floc'], Byref(f_formfunc._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, objs['floc'], add_values, objs['F'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, objs['floc'], add_values, objs['F'] + ]) + + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['xloc'])] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['floc'])] + ) + + body = body._rebuild( + body=body.body + + (vec_restore_array_y, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) + ) + + stacks = ( + zero_f_memory, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_y, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + derefs = dereference_funcs(ctx, fields) + + # Force the struct definition to appear at the very start, since + # stacks, allocs etc may rely on its information + struct_definition = [Definition(ctx), dm_cast, dm_get_app_context] + + body = self._make_callable_body( + body, standalones=struct_definition, stacks=stacks+derefs + ) + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} + + return Uxreplace(subs).visit(body) + + def _make_formrhs(self): + b_exprs = self.field_data.residual.b_exprs + sobjs = self.solver_objs + + # Compile `b_exprs` into an IET via recursive compilation + irs, _ = self.rcompile( + b_exprs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body = self._create_form_rhs_body( + List(body=irs.uiet.body) + ) + objs = self.objs + cb = self._make_petsc_callable( + 'FormRHS', body, parameters=(sobjs['callbackdm'], objs['B']) + ) + self._b_efunc = cb + self._efuncs[cb.name] = cb + + def _create_form_rhs_body(self, body): + linsolve_expr = self.inject_solve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + target = self.target + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + dm_get_local = petsc_call( + 'DMGetLocalVector', [dmda, Byref(sobjs['blocal'])] + ) + + dm_global_to_local_begin = petsc_call( + 'DMGlobalToLocalBegin', [dmda, objs['B'], insert_values, sobjs['blocal']] + ) + + dm_global_to_local_end = petsc_call( + 'DMGlobalToLocalEnd', [dmda, objs['B'], insert_values, sobjs['blocal']] + ) + + b_arr = self.field_data.arrays[target]['b'] + + vec_get_array = petsc_call( + 'VecGetArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + body = self.time_dependence.uxreplace_time(body) + + fields = get_user_struct_fields(body) + self._struct_params.extend(fields) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, sobjs['blocal'], insert_values, objs['B'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, sobjs['blocal'], insert_values, objs['B'] + ]) + + vec_restore_array = petsc_call( + 'VecRestoreArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] + ) + + dm_restore_local_bvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(sobjs['blocal'])] + ) + + body = body._rebuild(body=body.body + ( + dm_local_to_global_begin, dm_local_to_global_end, vec_restore_array, + dm_restore_local_bvec + )) + + stacks = ( + dm_get_local, + dm_global_to_local_begin, + dm_global_to_local_end, + vec_get_array, + dm_get_local_info + ) + + # Dereference function data in struct + derefs = dereference_funcs(ctx, fields) + + # Force the struct definition to appear at the very start, since + # stacks, allocs etc may rely on its information + struct_definition = [Definition(ctx), dm_get_app_context] + + body = self._make_callable_body( + [body], standalones=struct_definition, stacks=stacks+derefs + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for + i in fields if not isinstance(i.function, AbstractFunction)} + + return Uxreplace(subs).visit(body) + + def _make_initial_guess(self): + exprs = self.field_data.initial_guess.exprs + sobjs = self.solver_objs + objs = self.objs + + # Compile initital guess `eqns` into an IET via recursive compilation + irs, _ = self.rcompile( + exprs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body = self._create_initial_guess_body( + List(body=irs.uiet.body) + ) + cb = self._make_petsc_callable( + 'FormInitialGuess', body, parameters=(sobjs['callbackdm'], objs['xloc']) + ) + self._initial_guesses.append(cb) + self._efuncs[cb.name] = cb + + def _create_initial_guess_body(self, body): + linsolve_expr = self.inject_solve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + target = self.target + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + x_arr = self.field_data.arrays[target]['x'] + + vec_get_array = petsc_call( + 'VecGetArray', [objs['xloc'], Byref(x_arr._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + body = self.time_dependence.uxreplace_time(body) + + fields = get_user_struct_fields(body) + self._struct_params.extend(fields) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + vec_restore_array = petsc_call( + 'VecRestoreArray', [objs['xloc'], Byref(x_arr._C_symbol)] + ) + + body = body._rebuild(body=body.body + (vec_restore_array,)) + + stacks = ( + vec_get_array, + dm_get_local_info + ) + + # Dereference function data in struct + derefs = dereference_funcs(ctx, fields) + + # Force the struct definition to appear at the very start, since + # stacks, allocs etc may rely on its information + struct_definition = [Definition(ctx), dm_get_app_context] + + body = self._make_callable_body( + body, standalones=struct_definition, stacks=stacks+derefs + ) + + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for + i in fields if not isinstance(i.function, AbstractFunction)} + + return Uxreplace(subs).visit(body) + + def _make_user_struct_callback(self): + """ + This is the struct initialised inside the main kernel and + attached to the DM via DMSetApplicationContext. + """ + mainctx = self.solver_objs['userctx'] = MainUserStruct( + name=self.sregistry.make_name(prefix='ctx'), + pname=self.sregistry.make_name(prefix='UserCtx'), + fields=self.filtered_struct_params, + liveness='lazy', + modifier=None + ) + body = [ + DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol) + for i in mainctx.callback_fields + ] + struct_callback_body = self._make_callable_body(body) + cb = Callable( + self.sregistry.make_name(prefix='PopulateUserContext'), + struct_callback_body, self.objs['err'], + parameters=[mainctx] + ) + self._efuncs[cb.name] = cb + self._user_struct_callback = cb + + def _uxreplace_efuncs(self): + sobjs = self.solver_objs + callback_user_struct = CallbackUserStruct( + name=sobjs['userctx'].name, + pname=sobjs['userctx'].pname, + fields=self.filtered_struct_params, + liveness='lazy', + modifier=' *', + parent=sobjs['userctx'] + ) + mapper = {} + visitor = Uxreplace({self.objs['dummyctx']: callback_user_struct}) + for k, v in self._efuncs.items(): + mapper.update({k: visitor.visit(v)}) + return mapper + + +class CoupledCallbackBuilder(BaseCallbackBuilder): + def __init__(self, **kwargs): + self._submatrices_callback = None + super().__init__(**kwargs) + + @property + def submatrices_callback(self): + return self._submatrices_callback + + @property + def jacobian(self): + return self.inject_solve.expr.rhs.field_data.jacobian + + @property + def main_matvec_callback(self): + """ + This is the matrix-vector callback associated with the whole Jacobian i.e + is set in the main kernel via + `PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))MyMatShellMult));` + """ + return self._main_matvec_callback + + def _make_core(self): + for sm in self.field_data.jacobian.nonzero_submatrices: + self._make_matvec(sm, prefix=f'{sm.name}_MatMult') + + self._make_options_callback() + self._make_whole_matvec() + self._make_whole_formfunc() + self._make_user_struct_callback() + self._create_submatrices() + self._efuncs['PopulateMatContext'] = self.objs['dummyefunc'] + + def _make_whole_matvec(self): + objs = self.objs + body = self._whole_matvec_body() + + parameters = (objs['J'], objs['X'], objs['Y']) + cb = self._make_petsc_callable( + 'WholeMatMult', List(body=body), parameters=parameters + ) + self._main_matvec_callback = cb + self._efuncs[cb.name] = cb + + def _whole_matvec_body(self): + objs = self.objs + sobjs = self.solver_objs + + jctx = objs['ljacctx'] + ctx_main = petsc_call('MatShellGetContext', [objs['J'], Byref(jctx)]) + + nonzero_submats = self.jacobian.nonzero_submatrices + + zero_y_memory = zero_vector(objs['Y']) + + calls = () + for sm in nonzero_submats: + name = sm.name + ctx = sobjs[f'{name}ctx'] + X = sobjs[f'{name}X'] + Y = sobjs[f'{name}Y'] + rows = objs['rows'].base + cols = objs['cols'].base + sm_indexed = objs['Submats'].indexed[sm.linear_idx] + + calls += ( + DummyExpr(sobjs[name], FieldFromPointer(sm_indexed, jctx)), + petsc_call('MatShellGetContext', [sobjs[name], Byref(ctx)]), + petsc_call( + 'VecGetSubVector', + [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] + ), + petsc_call( + 'VecGetSubVector', + [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] + ), + petsc_call('MatMult', [sobjs[name], X, Y]), + petsc_call( + 'VecRestoreSubVector', + [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] + ), + petsc_call( + 'VecRestoreSubVector', + [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] + ), + ) + body = (ctx_main, zero_y_memory, BlankLine) + calls + return self._make_callable_body(body) + + def _make_whole_formfunc(self): + objs = self.objs + F_exprs = self.field_data.residual.F_exprs + # Compile `F_exprs` into an IET via recursive compilation + irs, _ = self.rcompile( + F_exprs, options={'mpi': False}, sregistry=self.sregistry, + concretize_mapper=self.concretize_mapper + ) + body = self._whole_formfunc_body(List(body=irs.uiet.body)) + + parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr']) + cb = self._make_petsc_callable( + 'WholeFormFunc', body, parameters=parameters + ) + + self._F_efunc = cb + self._efuncs[cb.name] = cb + + def _whole_formfunc_body(self, body): + linsolve_expr = self.inject_solve.expr.rhs + objs = self.objs + sobjs = self.solver_objs + + dmda = sobjs['callbackdm'] + ctx = objs['dummyctx'] + + body = self.time_dependence.uxreplace_time(body) + + fields = get_user_struct_fields(body) + self._struct_params.extend(fields) + + # Process body with bundles for residual callback + bundles = sobjs['bundles'] + fbundle = bundles['f'] + xbundle = bundles['x'] + + body = self.residual_bundle(body, bundles) + + dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True) + + dm_get_app_context = petsc_call( + 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] + ) + + zero_f_memory = zero_vector(objs['F']) + + dm_get_local_xvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['xloc'])] + ) + + global_to_local_begin = petsc_call('DMGlobalToLocalBegin', [ + dmda, objs['X'], insert_values, objs['xloc'] + ]) + + global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ + dmda, objs['X'], insert_values, objs['xloc'] + ]) + + dm_get_local_yvec = petsc_call( + 'DMGetLocalVector', [dmda, Byref(objs['floc'])] + ) + + vec_get_array_f = petsc_call( + 'VecGetArray', [objs['floc'], Byref(fbundle.vector._C_symbol)] + ) + + vec_get_array_x = petsc_call( + 'VecGetArray', [objs['xloc'], Byref(xbundle.vector._C_symbol)] + ) + + dm_get_local_info = petsc_call( + 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] + ) + + vec_restore_array_f = petsc_call( + 'VecRestoreArray', [objs['floc'], Byref(fbundle.vector._C_symbol)] + ) + + vec_restore_array_x = petsc_call( + 'VecRestoreArray', [objs['xloc'], Byref(xbundle.vector._C_symbol)] + ) + + dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ + dmda, objs['floc'], add_values, objs['F'] + ]) + + dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ + dmda, objs['floc'], add_values, objs['F'] + ]) + + dm_restore_local_xvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['xloc'])] + ) + + dm_restore_local_yvec = petsc_call( + 'DMRestoreLocalVector', [dmda, Byref(objs['floc'])] + ) + + body = body._rebuild( + body=body.body + + (vec_restore_array_f, + vec_restore_array_x, + dm_local_to_global_begin, + dm_local_to_global_end, + dm_restore_local_xvec, + dm_restore_local_yvec) + ) + + stacks = ( + zero_f_memory, + dm_get_local_xvec, + global_to_local_begin, + global_to_local_end, + dm_get_local_yvec, + vec_get_array_f, + vec_get_array_x, + dm_get_local_info + ) + + # Dereference function data in struct + derefs = dereference_funcs(ctx, fields) + + # Force the struct definition to appear at the very start, since + # stacks, allocs etc may rely on its information + struct_definition = [Definition(ctx), dm_cast, dm_get_app_context] + + f_soa = PointerCast(fbundle) + x_soa = PointerCast(xbundle) + + formfunc_body = self._make_callable_body( + body, + standalones=struct_definition, + stacks=stacks+derefs, + casts=(f_soa, x_soa), + ) + # Replace non-function data with pointer to data in struct + subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} + + return Uxreplace(subs).visit(formfunc_body) + + def _create_submatrices(self): + body = self._submat_callback_body() + objs = self.objs + params = ( + objs['J'], + objs['nfields'], + objs['irow'], + objs['icol'], + objs['matreuse'], + objs['Submats'], + ) + cb = self._make_petsc_callable( + 'MatCreateSubMatrices', body, parameters=params) + + self._submatrices_callback = cb + self._efuncs[cb.name] = cb + + def _submat_callback_body(self): + objs = self.objs + sobjs = self.solver_objs + + n_submats = DummyExpr( + objs['nsubmats'], Mul(objs['nfields'], objs['nfields']) + ) + + malloc_submats = petsc_call( + 'PetscCalloc1', [objs['nsubmats'], objs['Submats']._C_symbol] + ) + + mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(sobjs['callbackdm'])]) + + dm_get_app = petsc_call( + 'DMGetApplicationContext', [sobjs['callbackdm'], Byref(objs['dummyctx'])] + ) + + get_ctx = petsc_call('MatShellGetContext', [objs['J'], Byref(objs['ljacctx'])]) + + dm_get_info = petsc_call( + 'DMDAGetInfo', [ + sobjs['callbackdm'], Null, Byref(sobjs['M']), Byref(sobjs['N']), + Null, Null, Null, Null, Byref(objs['dof']), Null, Null, Null, Null, Null + ] + ) + subblock_rows = DummyExpr(objs['subblockrows'], Mul(sobjs['M'], sobjs['N'])) + subblock_cols = DummyExpr(objs['subblockcols'], Mul(sobjs['M'], sobjs['N'])) + + ptr = DummyExpr( + objs['submat_arr']._C_symbol, Deref(objs['Submats']._C_symbol), init=True + ) + + mat_create = petsc_call('MatCreate', [sobjs['comm'], Byref(objs['block'])]) + + mat_set_sizes = petsc_call( + 'MatSetSizes', [ + objs['block'], 'PETSC_DECIDE', 'PETSC_DECIDE', + objs['subblockrows'], objs['subblockcols'] + ] + ) + + mat_set_type = petsc_call('MatSetType', [objs['block'], 'MATSHELL']) + + malloc = petsc_call('PetscMalloc1', [1, Byref(objs['subctx'])]) + i = Dimension(name='i') + + row_idx = DummyExpr(objs['rowidx'], IntDiv(i, objs['dof'])) + col_idx = DummyExpr(objs['colidx'], Mod(i, objs['dof'])) + + deref_subdm = Dereference(objs['Subdms'], objs['ljacctx']) + + set_rows = DummyExpr( + FieldFromPointer(objs['rows'].base, objs['subctx']), + Byref(objs['irow'].indexed[objs['rowidx']]) + ) + set_cols = DummyExpr( + FieldFromPointer(objs['cols'].base, objs['subctx']), + Byref(objs['icol'].indexed[objs['colidx']]) + ) + dm_set_ctx = petsc_call( + 'DMSetApplicationContext', [ + objs['Subdms'].indexed[objs['rowidx']], objs['dummyctx'] + ] + ) + matset_dm = petsc_call('MatSetDM', [ + objs['block'], objs['Subdms'].indexed[objs['rowidx']] + ]) + + set_ctx = petsc_call('MatShellSetContext', [objs['block'], objs['subctx']]) + + mat_setup = petsc_call('MatSetUp', [objs['block']]) + + assign_block = DummyExpr(objs['submat_arr'].indexed[i], objs['block']) + + iter_body = ( + mat_create, + mat_set_sizes, + mat_set_type, + malloc, + row_idx, + col_idx, + set_rows, + set_cols, + dm_set_ctx, + matset_dm, + set_ctx, + mat_setup, + assign_block + ) + + upper_bound = objs['nsubmats'] - 1 + iteration = Iteration(List(body=iter_body), i, upper_bound) + + nonzero_submats = self.jacobian.nonzero_submatrices + matvec_lookup = {mv.name.split('_')[0]: mv for mv in self.J_efuncs} + + matmult_op = [ + petsc_call( + 'MatShellSetOperation', + [ + objs['submat_arr'].indexed[sb.linear_idx], + 'MATOP_MULT', + MatShellSetOp(matvec_lookup[sb.name].name, VOID._dtype, VOID._dtype), + ], + ) + for sb in nonzero_submats if sb.name in matvec_lookup + ] + + body = [ + n_submats, + malloc_submats, + mat_get_dm, + dm_get_app, + dm_get_info, + subblock_rows, + subblock_cols, + ptr, + BlankLine, + iteration, + ] + matmult_op + return self._make_callable_body(tuple(body), stacks=(get_ctx, deref_subdm)) + + def residual_bundle(self, body, bundles): + """ + Replaces PetscArrays in `body` with PetscBundle struct field accesses + (e.g., f_v[ix][iy] -> f_bundle[ix][iy].v). + Example: + f_v[ix][iy] = x_v[ix][iy]; + f_u[ix][iy] = x_u[ix][iy]; + becomes: + f_bundle[ix][iy].v = x_bundle[ix][iy].v; + f_bundle[ix][iy].u = x_bundle[ix][iy].u; + NOTE: This is used because the data is interleaved for + multi-component DMDAs in PETSc. + """ + mapper = bundles['bundle_mapper'] + indexeds = FindSymbols('indexeds').visit(body) + subs = {} + + for i in indexeds: + if i.base in mapper: + bundle = mapper[i.base] + index = bundles['target_indices'][i.function.target] + index = (index,) + i.indices + subs[i] = bundle.__getitem__(index) + + body = Uxreplace(subs).visit(body) + return body + + +def populate_matrix_context(efuncs): + if not objs['dummyefunc'] in efuncs.values(): + return + + subdms_expr = DummyExpr( + FieldFromPointer(objs['Subdms']._C_symbol, objs['ljacctx']), + objs['Subdms']._C_symbol + ) + fields_expr = DummyExpr( + FieldFromPointer(objs['Fields']._C_symbol, objs['ljacctx']), + objs['Fields']._C_symbol + ) + body = CallableBody( + List(body=[subdms_expr, fields_expr]), + init=(petsc_func_begin_user,), + retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) + ) + name = 'PopulateMatContext' + efuncs[name] = Callable( + name, body, objs['err'], + parameters=[objs['ljacctx'], objs['Subdms'], objs['Fields']] + ) + + +def dereference_funcs(struct, fields): + """ + Dereference AbstractFunctions from a struct. + """ + return tuple( + [Dereference(i, struct) for i in + fields if isinstance(i.function, AbstractFunction)] + ) + + +def zero_vector(vec): + """ + Set all entries of a PETSc vector to zero. + """ + return petsc_call('VecSet', [vec, 0.0]) + + +def get_user_struct_fields(iet): + fields = [f.function for f in FindSymbols('basics').visit(iet)] + from devito.types.basic import LocalType + avoid = (Temp, TempArray, LocalType) + fields = [f for f in fields if not isinstance(f.function, avoid)] + fields = [ + f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo)) + ] + return fields + + +insert_values = InsertMode.insert_values +add_values = InsertMode.add_values diff --git a/devito/petsc/iet/logging.py b/devito/petsc/iet/logging.py index 68f31cfa91..65e2ec2be8 100644 --- a/devito/petsc/iet/logging.py +++ b/devito/petsc/iet/logging.py @@ -5,7 +5,7 @@ from devito.logger import PERF from devito.tools import frozendict -from devito.petsc.iet.utils import petsc_call +from devito.petsc.iet.nodes import petsc_call from devito.petsc.logging import petsc_return_variable_dict, PetscInfo diff --git a/devito/petsc/iet/nodes.py b/devito/petsc/iet/nodes.py index abb5da3acd..70508970c3 100644 --- a/devito/petsc/iet/nodes.py +++ b/devito/petsc/iet/nodes.py @@ -29,3 +29,12 @@ def callback_form(self): class PETScCall(Call): pass + + +def petsc_call(specific_call, call_args): + return PETScCall('PetscCall', [PETScCall(specific_call, arguments=call_args)]) + + +# Mapping special Eq operations to their corresponding IET Expression subclass types. +# These operations correspond to subclasses of `Eq`` utilised within `petscsolve``. +petsc_iet_mapper = {OpPetsc: PetscMetaData} diff --git a/devito/petsc/iet/passes.py b/devito/petsc/iet/passes.py index eb6cf48a1f..5154afe43d 100644 --- a/devito/petsc/iet/passes.py +++ b/devito/petsc/iet/passes.py @@ -3,34 +3,35 @@ from functools import cached_property from devito.passes.iet.engine import iet_pass -from devito.ir.iet import (Transformer, MapNodes, Iteration, BlankLine, - DummyExpr, CallableBody, List, Call, Callable, - FindNodes, Section) -from devito.symbolics import Byref, FieldFromPointer, Macro, Null -from devito.types import Symbol, Scalar +from devito.ir.iet import ( + Transformer, MapNodes, Iteration, CallableBody, List, Call, FindNodes, Section, + FindSymbols, DummyExpr, Uxreplace, Dereference +) +from devito.symbolics import Byref, Macro, Null, FieldFromPointer from devito.types.basic import DataSymbol -from devito.tools import frozendict import devito.logger -from devito.petsc.types import (PetscMPIInt, PetscErrorCode, MultipleFieldData, - PointerIS, Mat, CallbackVec, Vec, CallbackMat, SNES, - DummyArg, PetscInt, PointerDM, PointerMat, MatReuse, - CallbackPointerIS, CallbackPointerDM, JacobianStruct, - SubMatrixStruct, Initialize, Finalize, ArgvSymbol) +from devito.petsc.types import ( + MultipleFieldData, Initialize, Finalize, ArgvSymbol, MainUserStruct, + CallbackUserStruct +) from devito.petsc.types.macros import petsc_func_begin_user -from devito.petsc.iet.nodes import PetscMetaData -from devito.petsc.utils import core_metadata, petsc_languages -from devito.petsc.iet.routines import (CBBuilder, CCBBuilder, BaseObjectBuilder, - CoupledObjectBuilder, BaseSetup, CoupledSetup, - Solver, CoupledSolver, TimeDependent, - NonTimeDependent) +from devito.petsc.iet.nodes import PetscMetaData, petsc_call +from devito.petsc.config import core_metadata, petsc_languages +from devito.petsc.iet.callbacks import ( + BaseCallbackBuilder, CoupledCallbackBuilder, populate_matrix_context, + get_user_struct_fields +) +from devito.petsc.iet.type_builder import BaseTypeBuilder, CoupledTypeBuilder, objs +from devito.petsc.iet.builder import BuilderBase, CoupledBuilder, make_core_petsc_calls +from devito.petsc.iet.solve import Solve, CoupledSolve +from devito.petsc.iet.time_dependence import TimeDependent, TimeIndependent from devito.petsc.iet.logging import PetscLogger -from devito.petsc.iet.utils import petsc_call, petsc_call_mpi @iet_pass def lower_petsc(iet, **kwargs): - # Check if PETScSolve was used + # Check if `petscsolve` was used inject_solve_mapper = MapNodes(Iteration, PetscMetaData, 'groupby').visit(iet) @@ -52,21 +53,24 @@ def lower_petsc(iet, **kwargs): return finalize(iet), core_metadata() unique_grids = {i.expr.rhs.grid for (i,) in inject_solve_mapper.values()} - # Assumption is that all solves are on the same grid + # Assumption is that all solves are on the same `Grid` if len(unique_grids) > 1: - raise ValueError("All PETScSolves must use the same Grid, but multiple found.") + raise ValueError( + "All `petscsolve` calls must use the same `Grid`, " + "but multiple `Grid`s were found." + ) grid = unique_grids.pop() devito_mpi = kwargs['options'].get('mpi', False) comm = grid.distributor._obj_comm if devito_mpi else 'PETSC_COMM_WORLD' - # Create core PETSc calls (not specific to each PETScSolve) + # Create core PETSc calls (not specific to each `petscsolve`) core = make_core_petsc_calls(objs, comm) setup = [] subs = {} efuncs = {} - # Map PETScSolve to its Section (for logging) + # Map each `PetscMetaData`` to its Section (for logging) section_mapper = MapNodes(Section, PetscMetaData, 'groupby').visit(iet) # Prefixes within the same `Operator` should not be duplicated @@ -76,8 +80,8 @@ def lower_petsc(iet, **kwargs): if duplicates: dup_list = ", ".join(repr(p) for p in sorted(duplicates)) raise ValueError( - f"The following `options_prefix` values are duplicated " - f"among your PETScSolves. Ensure each one is unique: {dup_list}" + "The following `options_prefix` values are duplicated " + f"among your `petscsolve` calls. Ensure each one is unique: {dup_list}" ) # List of `Call`s to clear options from the global PETSc options database, @@ -86,17 +90,17 @@ def lower_petsc(iet, **kwargs): for iters, (inject_solve,) in inject_solve_mapper.items(): - builder = Builder(inject_solve, iters, comm, section_mapper, **kwargs) + solver = BuildSolver(inject_solve, iters, comm, section_mapper, **kwargs) - setup.extend(builder.solver_setup.calls) + setup.extend(solver.builder.calls) # Transform the spatial iteration loop with the calls to execute the solver - subs.update({builder.solve.spatial_body: builder.calls}) + subs.update({solver.solve.spatial_body: solver.calls}) - efuncs.update(builder.cbbuilder.efuncs) + efuncs.update(solver.callback_builder.efuncs) clear_options.extend((petsc_call( - builder.cbbuilder._clear_options_efunc.name, [] + solver.callback_builder._clear_options_efunc.name, [] ),)) populate_matrix_context(efuncs) @@ -108,6 +112,98 @@ def lower_petsc(iet, **kwargs): return iet, metadata +def lower_petsc_symbols(iet, **kwargs): + """ + The `place_definitions` and `place_casts` passes may introduce new + symbols, which must be incorporated into + the relevant PETSc structs. To update the structs, this method then + applies two additional passes: `rebuild_child_user_struct` and + `rebuild_parent_user_struct`. + """ + callback_struct_mapper = {} + # Rebuild `CallbackUserStruct` and update iet accordingly + rebuild_child_user_struct(iet, mapper=callback_struct_mapper) + # Rebuild `MainUserStruct` and update iet accordingly + rebuild_parent_user_struct(iet, mapper=callback_struct_mapper) + + +@iet_pass +def rebuild_child_user_struct(iet, mapper, **kwargs): + """ + Rebuild each `CallbackUserStruct` (the child struct) to include any + new fields introduced by the `place_definitions` and `place_casts` passes. + Also, update the iet accordingly (e.g., dereference the new fields). + - `CallbackUserStruct` is used to access information + in PETSc callback functions via `DMGetApplicationContext`. + """ + old_struct = set([ + i for i in FindSymbols().visit(iet) if isinstance(i, CallbackUserStruct) + ]) + + if not old_struct: + return iet, {} + + # There is a unique `CallbackUserStruct` in each callback + assert len(old_struct) == 1 + old_struct = old_struct.pop() + + # Collect any new fields that have been introduced since the struct was + # previously built + new_fields = [ + f for f in get_user_struct_fields(iet) if f not in old_struct.fields + ] + all_fields = old_struct.fields + new_fields + + # Rebuild the struct + new_struct = old_struct._rebuild(fields=all_fields) + mapper[old_struct] = new_struct + + # Replace old struct with the new one + new_body = Uxreplace(mapper).visit(iet.body) + + # Dereference the new fields and insert them as `standalones` at the top of + # the body. This ensures they are defined before any casts/allocs etc introduced + # by the `place_definitions` and `place_casts` passes. + derefs = tuple([Dereference(i, new_struct) for i in new_fields]) + new_body = new_body._rebuild(standalones=new_body.standalones + derefs) + + return iet._rebuild(body=new_body), {} + + +@iet_pass +def rebuild_parent_user_struct(iet, mapper, **kwargs): + """ + Rebuild each `MainUserStruct` (the parent struct) so that it stays in sync + with its corresponding `CallbackUserStruct` (the child struct). Any IET that + references a parent struct is also updated — either the `PopulateUserContext` + callback or the main Kernel, where the parent struct is registered + via `DMSetApplicationContext`. + """ + if not mapper: + return iet, {} + + parent_struct_mapper = { + v.parent: v.parent._rebuild(fields=v.fields) for v in mapper.values() + } + + if not iet.name.startswith("PopulateUserContext"): + new_body = Uxreplace(parent_struct_mapper).visit(iet.body) + return iet._rebuild(body=new_body), {} + + old_struct = [i for i in iet.parameters if isinstance(i, MainUserStruct)] + assert len(old_struct) == 1 + old_struct = old_struct.pop() + + new_struct = parent_struct_mapper[old_struct] + + new_body = [ + DummyExpr(FieldFromPointer(i._C_symbol, new_struct), i._C_symbol) + for i in new_struct.callback_fields + ] + new_body = iet.body._rebuild(body=new_body) + return iet._rebuild(body=new_body, parameters=(new_struct,)), {} + + def initialize(iet): # should be int because the correct type for argc is a C int # and not a int32 @@ -134,12 +230,7 @@ def finalize(iet): return iet._rebuild(body=finalize_body) -def make_core_petsc_calls(objs, comm): - call_mpi = petsc_call_mpi('MPI_Comm_size', [comm, Byref(objs['size'])]) - return call_mpi, BlankLine - - -class Builder: +class BuildSolver: """ This class is designed to support future extensions, enabling different combinations of solver types, preconditioning methods, @@ -165,39 +256,39 @@ def __init__(self, inject_solve, iters, comm, section_mapper, **kwargs): 'section_mapper': self.section_mapper, **self.kwargs } - self.common_kwargs['solver_objs'] = self.object_builder.solver_objs + self.common_kwargs['solver_objs'] = self.type_builder.solver_objs self.common_kwargs['time_dependence'] = self.time_dependence - self.common_kwargs['cbbuilder'] = self.cbbuilder + self.common_kwargs['callback_builder'] = self.callback_builder self.common_kwargs['logger'] = self.logger @cached_property - def object_builder(self): + def type_builder(self): return ( - CoupledObjectBuilder(**self.common_kwargs) + CoupledTypeBuilder(**self.common_kwargs) if self.coupled else - BaseObjectBuilder(**self.common_kwargs) + BaseTypeBuilder(**self.common_kwargs) ) @cached_property def time_dependence(self): mapper = self.inject_solve.expr.rhs.time_mapper - time_class = TimeDependent if mapper else NonTimeDependent + time_class = TimeDependent if mapper else TimeIndependent return time_class(**self.common_kwargs) @cached_property - def cbbuilder(self): - return CCBBuilder(**self.common_kwargs) \ - if self.coupled else CBBuilder(**self.common_kwargs) + def callback_builder(self): + return CoupledCallbackBuilder(**self.common_kwargs) \ + if self.coupled else BaseCallbackBuilder(**self.common_kwargs) @cached_property - def solver_setup(self): - return CoupledSetup(**self.common_kwargs) \ - if self.coupled else BaseSetup(**self.common_kwargs) + def builder(self): + return CoupledBuilder(**self.common_kwargs) \ + if self.coupled else BuilderBase(**self.common_kwargs) @cached_property def solve(self): - return CoupledSolver(**self.common_kwargs) \ - if self.coupled else Solver(**self.common_kwargs) + return CoupledSolve(**self.common_kwargs) \ + if self.coupled else Solve(**self.common_kwargs) @cached_property def logger(self): @@ -209,81 +300,3 @@ def logger(self): @cached_property def calls(self): return List(body=self.solve.calls+self.logger.calls) - - -def populate_matrix_context(efuncs): - if not objs['dummyefunc'] in efuncs.values(): - return - - subdms_expr = DummyExpr( - FieldFromPointer(objs['Subdms']._C_symbol, objs['ljacctx']), - objs['Subdms']._C_symbol - ) - fields_expr = DummyExpr( - FieldFromPointer(objs['Fields']._C_symbol, objs['ljacctx']), - objs['Fields']._C_symbol - ) - body = CallableBody( - List(body=[subdms_expr, fields_expr]), - init=(petsc_func_begin_user,), - retstmt=tuple([Call('PetscFunctionReturn', arguments=[0])]) - ) - name = 'PopulateMatContext' - efuncs[name] = Callable( - name, body, objs['err'], - parameters=[objs['ljacctx'], objs['Subdms'], objs['Fields']] - ) - - -subdms = PointerDM(name='subdms') -fields = PointerIS(name='fields') -submats = PointerMat(name='submats') -rows = PointerIS(name='rows') -cols = PointerIS(name='cols') - - -# A static dict containing shared symbols and objects that are not -# unique to each PETScSolve. -# Many of these objects are used as arguments in callback functions to make -# the C code cleaner and more modular. This is also a step toward leveraging -# Devito's `reuse_efuncs` functionality, allowing reuse of efuncs when -# they are semantically identical. -objs = frozendict({ - 'size': PetscMPIInt(name='size'), - 'err': PetscErrorCode(name='err'), - 'block': CallbackMat('block'), - 'submat_arr': PointerMat(name='submat_arr'), - 'subblockrows': PetscInt('subblockrows'), - 'subblockcols': PetscInt('subblockcols'), - 'rowidx': PetscInt('rowidx'), - 'colidx': PetscInt('colidx'), - 'J': Mat('J'), - 'X': Vec('X'), - 'xloc': CallbackVec('xloc'), - 'Y': Vec('Y'), - 'yloc': CallbackVec('yloc'), - 'F': Vec('F'), - 'floc': CallbackVec('floc'), - 'B': Vec('B'), - 'nfields': PetscInt('nfields'), - 'irow': PointerIS(name='irow'), - 'icol': PointerIS(name='icol'), - 'nsubmats': Scalar('nsubmats', dtype=np.int32), - 'matreuse': MatReuse('scall'), - 'snes': SNES('snes'), - 'rows': rows, - 'cols': cols, - 'Subdms': subdms, - 'LocalSubdms': CallbackPointerDM(name='subdms'), - 'Fields': fields, - 'LocalFields': CallbackPointerIS(name='fields'), - 'Submats': submats, - 'ljacctx': JacobianStruct( - fields=[subdms, fields, submats], modifier=' *' - ), - 'subctx': SubMatrixStruct(fields=[rows, cols]), - 'dummyctx': Symbol('lctx'), - 'dummyptr': DummyArg('dummy'), - 'dummyefunc': Symbol('dummyefunc'), - 'dof': PetscInt('dof'), -}) diff --git a/devito/petsc/iet/routines.py b/devito/petsc/iet/routines.py deleted file mode 100644 index 94952b3842..0000000000 --- a/devito/petsc/iet/routines.py +++ /dev/null @@ -1,1877 +0,0 @@ -from collections import OrderedDict -from functools import cached_property -import math - -from devito.ir.iet import (Call, FindSymbols, List, Uxreplace, CallableBody, - Dereference, DummyExpr, BlankLine, Callable, FindNodes, - retrieve_iteration_tree, filter_iterations, Iteration, - PointerCast) -from devito.symbolics import (Byref, FieldFromPointer, cast, VOID, - FieldFromComposite, IntDiv, Deref, Mod, String, Null) -from devito.symbolics.unevaluation import Mul -from devito.types.basic import AbstractFunction -from devito.types import Temp, Dimension -from devito.tools import filter_ordered - -from devito.petsc.iet.nodes import (PETScCallable, FormFunctionCallback, - MatShellSetOp, PetscMetaData) -from devito.petsc.iet.utils import (petsc_call, petsc_struct, zero_vector, - dereference_funcs, residual_bundle) -from devito.petsc.types import (PETScArray, PetscBundle, DM, Mat, CallbackVec, Vec, - KSP, PC, SNES, PetscInt, StartPtr, PointerIS, PointerDM, - VecScatter, DMCast, JacobianStruct, SubMatrixStruct, - CallbackDM) -from devito.petsc.types.macros import petsc_func_begin_user - - -class CBBuilder: - """ - Build IET routines to generate PETSc callback functions. - """ - def __init__(self, **kwargs): - - self.rcompile = kwargs.get('rcompile', None) - self.sregistry = kwargs.get('sregistry', None) - self.concretize_mapper = kwargs.get('concretize_mapper', {}) - self.time_dependence = kwargs.get('time_dependence') - self.objs = kwargs.get('objs') - self.solver_objs = kwargs.get('solver_objs') - self.inject_solve = kwargs.get('inject_solve') - self.solve_expr = self.inject_solve.expr.rhs - - self._efuncs = OrderedDict() - self._struct_params = [] - - self._set_options_efunc = None - self._clear_options_efunc = None - self._main_matvec_callback = None - self._user_struct_callback = None - self._F_efunc = None - self._b_efunc = None - - self._J_efuncs = [] - self._initial_guesses = [] - - self._make_core() - self._efuncs = self._uxreplace_efuncs() - - @property - def efuncs(self): - return self._efuncs - - @property - def struct_params(self): - return self._struct_params - - @property - def filtered_struct_params(self): - return filter_ordered(self.struct_params) - - @property - def main_matvec_callback(self): - """ - The matrix-vector callback for the full Jacobian. - This is the function set in the main Kernel via: - PetscCall(MatShellSetOperation(J, MATOP_MULT, (void (*)(void))...)); - The callback has the signature `(Mat, Vec, Vec)`. - """ - return self._J_efuncs[0] - - @property - def J_efuncs(self): - """ - List of matrix-vector callbacks. - Each callback has the signature `(Mat, Vec, Vec)`. Typically, this list - contains a single element, but in mixed systems it can include multiple - callbacks, one for each subblock. - """ - return self._J_efuncs - - @property - def initial_guesses(self): - return self._initial_guesses - - @property - def user_struct_callback(self): - return self._user_struct_callback - - @property - def solver_parameters(self): - return self.solve_expr.solver_parameters - - @property - def field_data(self): - return self.solve_expr.field_data - - @property - def formatted_prefix(self): - return self.solve_expr.formatted_prefix - - @property - def arrays(self): - return self.field_data.arrays - - @property - def target(self): - return self.field_data.target - - def _make_core(self): - self._make_options_callback() - self._make_matvec(self.field_data.jacobian) - self._make_formfunc() - self._make_formrhs() - if self.field_data.initial_guess.exprs: - self._make_initial_guess() - self._make_user_struct_callback() - - def _make_petsc_callable(self, prefix, body, parameters=()): - return PETScCallable( - self.sregistry.make_name(prefix=prefix), - body, - retval=self.objs['err'], - parameters=parameters - ) - - def _make_callable_body(self, body, stacks=(), casts=()): - return CallableBody( - List(body=body), - init=(petsc_func_begin_user,), - stacks=stacks, - casts=casts, - retstmt=(Call('PetscFunctionReturn', arguments=[0]),) - ) - - def _make_options_callback(self): - """ - Create two callbacks: one to set PETSc options and one - to clear them. - - Options are only set/cleared if they were not specifed via - command line arguments. - """ - params = self.solver_parameters - prefix = self.formatted_prefix - - set_body, clear_body = [], [] - - for k, v in params.items(): - option = f'-{prefix}{k}' - - # TODO: Revisit use of a global variable here. - # Consider replacing this with a call to `PetscGetArgs`, though - # initial attempts failed, possibly because the argv pointer is - # created in Python?.. - import devito.petsc.initialize - if option in devito.petsc.initialize._petsc_clargs: - # Ensures that the command line args take priority - continue - - option_name = String(option) - # For options without a value e.g `ksp_view`, pass Null - option_value = Null if v is None else String(str(v)) - set_body.append( - petsc_call('PetscOptionsSetValue', [Null, option_name, option_value]) - ) - clear_body.append( - petsc_call('PetscOptionsClearValue', [Null, option_name]) - ) - - set_body = self._make_callable_body(set_body) - clear_body = self._make_callable_body(clear_body) - - set_callback = self._make_petsc_callable('SetPetscOptions', set_body) - clear_callback = self._make_petsc_callable('ClearPetscOptions', clear_body) - - self._set_options_efunc = set_callback - self._efuncs[set_callback.name] = set_callback - self._clear_options_efunc = clear_callback - self._efuncs[clear_callback.name] = clear_callback - - def _make_matvec(self, jacobian, prefix='MatMult'): - # Compile `matvecs` into an IET via recursive compilation - matvecs = jacobian.matvecs - irs, _ = self.rcompile( - matvecs, options={'mpi': False}, sregistry=self.sregistry, - concretize_mapper=self.concretize_mapper - ) - body = self._create_matvec_body( - List(body=irs.uiet.body), jacobian - ) - objs = self.objs - cb = self._make_petsc_callable( - prefix, body, parameters=(objs['J'], objs['X'], objs['Y']) - ) - self._J_efuncs.append(cb) - self._efuncs[cb.name] = cb - - def _create_matvec_body(self, body, jacobian): - linsolve_expr = self.inject_solve.expr.rhs - objs = self.objs - sobjs = self.solver_objs - - dmda = sobjs['callbackdm'] - ctx = objs['dummyctx'] - xlocal = objs['xloc'] - ylocal = objs['yloc'] - y_matvec = self.arrays[jacobian.row_target]['y'] - x_matvec = self.arrays[jacobian.col_target]['x'] - - body = self.time_dependence.uxreplace_time(body) - - fields = self._dummy_fields(body) - - mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(dmda)]) - - dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] - ) - - zero_y_memory = zero_vector(objs['Y']) if jacobian.zero_memory else None - - dm_get_local_xvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(xlocal)] - ) - - global_to_local_begin = petsc_call( - 'DMGlobalToLocalBegin', [dmda, objs['X'], - insert_vals, xlocal] - ) - - global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ - dmda, objs['X'], insert_vals, xlocal - ]) - - dm_get_local_yvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(ylocal)] - ) - - zero_ylocal_memory = zero_vector(ylocal) - - vec_get_array_y = petsc_call( - 'VecGetArray', [ylocal, Byref(y_matvec._C_symbol)] - ) - - vec_get_array_x = petsc_call( - 'VecGetArray', [xlocal, Byref(x_matvec._C_symbol)] - ) - - dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] - ) - - vec_restore_array_y = petsc_call( - 'VecRestoreArray', [ylocal, Byref(y_matvec._C_symbol)] - ) - - vec_restore_array_x = petsc_call( - 'VecRestoreArray', [xlocal, Byref(x_matvec._C_symbol)] - ) - - dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ - dmda, ylocal, add_vals, objs['Y'] - ]) - - dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ - dmda, ylocal, add_vals, objs['Y'] - ]) - - dm_restore_local_xvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(xlocal)] - ) - - dm_restore_local_yvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(ylocal)] - ) - - # TODO: Some of the calls are placed in the `stacks` argument of the - # `CallableBody` to ensure that they precede the `cast` statements. The - # 'casts' depend on the calls, so this order is necessary. By doing this, - # you avoid having to manually construct the `casts` and can allow - # Devito to handle their construction. This is a temporary solution and - # should be revisited - - body = body._rebuild( - body=body.body + - (vec_restore_array_y, - vec_restore_array_x, - dm_local_to_global_begin, - dm_local_to_global_end, - dm_restore_local_xvec, - dm_restore_local_yvec) - ) - - stacks = ( - mat_get_dm, - dm_get_app_context, - zero_y_memory, - dm_get_local_xvec, - global_to_local_begin, - global_to_local_end, - dm_get_local_yvec, - zero_ylocal_memory, - vec_get_array_y, - vec_get_array_x, - dm_get_local_info - ) - - # Dereference function data in struct - derefs = dereference_funcs(ctx, fields) - - body = self._make_callable_body(body, stacks=stacks+derefs) - - # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} - body = Uxreplace(subs).visit(body) - - self._struct_params.extend(fields) - return body - - def _make_formfunc(self): - objs = self.objs - F_exprs = self.field_data.residual.F_exprs - # Compile `F_exprs` into an IET via recursive compilation - irs, _ = self.rcompile( - F_exprs, options={'mpi': False}, sregistry=self.sregistry, - concretize_mapper=self.concretize_mapper - ) - body_formfunc = self._create_formfunc_body( - List(body=irs.uiet.body) - ) - parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr']) - cb = self._make_petsc_callable('FormFunction', body_formfunc, parameters) - - self._F_efunc = cb - self._efuncs[cb.name] = cb - - def _create_formfunc_body(self, body): - linsolve_expr = self.inject_solve.expr.rhs - objs = self.objs - sobjs = self.solver_objs - arrays = self.arrays - target = self.target - - dmda = sobjs['callbackdm'] - ctx = objs['dummyctx'] - - body = self.time_dependence.uxreplace_time(body) - - fields = self._dummy_fields(body) - self._struct_params.extend(fields) - - f_formfunc = arrays[target]['f'] - x_formfunc = arrays[target]['x'] - - dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True) - - dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] - ) - - zero_f_memory = zero_vector(objs['F']) - - dm_get_local_xvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(objs['xloc'])] - ) - - global_to_local_begin = petsc_call( - 'DMGlobalToLocalBegin', [dmda, objs['X'], - insert_vals, objs['xloc']] - ) - - global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ - dmda, objs['X'], insert_vals, objs['xloc'] - ]) - - dm_get_local_yvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(objs['floc'])] - ) - - vec_get_array_y = petsc_call( - 'VecGetArray', [objs['floc'], Byref(f_formfunc._C_symbol)] - ) - - vec_get_array_x = petsc_call( - 'VecGetArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] - ) - - dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] - ) - - vec_restore_array_y = petsc_call( - 'VecRestoreArray', [objs['floc'], Byref(f_formfunc._C_symbol)] - ) - - vec_restore_array_x = petsc_call( - 'VecRestoreArray', [objs['xloc'], Byref(x_formfunc._C_symbol)] - ) - - dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ - dmda, objs['floc'], add_vals, objs['F'] - ]) - - dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ - dmda, objs['floc'], add_vals, objs['F'] - ]) - - dm_restore_local_xvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(objs['xloc'])] - ) - - dm_restore_local_yvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(objs['floc'])] - ) - - body = body._rebuild( - body=body.body + - (vec_restore_array_y, - vec_restore_array_x, - dm_local_to_global_begin, - dm_local_to_global_end, - dm_restore_local_xvec, - dm_restore_local_yvec) - ) - - stacks = ( - dm_cast, - dm_get_app_context, - zero_f_memory, - dm_get_local_xvec, - global_to_local_begin, - global_to_local_end, - dm_get_local_yvec, - vec_get_array_y, - vec_get_array_x, - dm_get_local_info - ) - - # Dereference function data in struct - derefs = dereference_funcs(ctx, fields) - - body = self._make_callable_body(body, stacks=stacks+derefs) - # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} - - return Uxreplace(subs).visit(body) - - def _make_formrhs(self): - b_exprs = self.field_data.residual.b_exprs - sobjs = self.solver_objs - - # Compile `b_exprs` into an IET via recursive compilation - irs, _ = self.rcompile( - b_exprs, options={'mpi': False}, sregistry=self.sregistry, - concretize_mapper=self.concretize_mapper - ) - body = self._create_form_rhs_body( - List(body=irs.uiet.body) - ) - objs = self.objs - cb = self._make_petsc_callable( - 'FormRHS', body, parameters=(sobjs['callbackdm'], objs['B']) - ) - self._b_efunc = cb - self._efuncs[cb.name] = cb - - def _create_form_rhs_body(self, body): - linsolve_expr = self.inject_solve.expr.rhs - objs = self.objs - sobjs = self.solver_objs - target = self.target - - dmda = sobjs['callbackdm'] - ctx = objs['dummyctx'] - - dm_get_local = petsc_call( - 'DMGetLocalVector', [dmda, Byref(sobjs['blocal'])] - ) - - dm_global_to_local_begin = petsc_call( - 'DMGlobalToLocalBegin', [dmda, objs['B'], - insert_vals, sobjs['blocal']] - ) - - dm_global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ - dmda, objs['B'], insert_vals, - sobjs['blocal'] - ]) - - b_arr = self.field_data.arrays[target]['b'] - - vec_get_array = petsc_call( - 'VecGetArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] - ) - - dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] - ) - - body = self.time_dependence.uxreplace_time(body) - - fields = self._dummy_fields(body) - self._struct_params.extend(fields) - - dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] - ) - - dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ - dmda, sobjs['blocal'], insert_vals, - objs['B'] - ]) - - dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ - dmda, sobjs['blocal'], insert_vals, - objs['B'] - ]) - - vec_restore_array = petsc_call( - 'VecRestoreArray', [sobjs['blocal'], Byref(b_arr._C_symbol)] - ) - - dm_restore_local_bvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(sobjs['blocal'])] - ) - - body = body._rebuild(body=body.body + ( - dm_local_to_global_begin, dm_local_to_global_end, vec_restore_array, - dm_restore_local_bvec - )) - - stacks = ( - dm_get_local, - dm_global_to_local_begin, - dm_global_to_local_end, - vec_get_array, - dm_get_app_context, - dm_get_local_info - ) - - # Dereference function data in struct - derefs = dereference_funcs(ctx, fields) - - body = self._make_callable_body([body], stacks=stacks+derefs) - - # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for - i in fields if not isinstance(i.function, AbstractFunction)} - - return Uxreplace(subs).visit(body) - - def _make_initial_guess(self): - exprs = self.field_data.initial_guess.exprs - sobjs = self.solver_objs - objs = self.objs - - # Compile initital guess `eqns` into an IET via recursive compilation - irs, _ = self.rcompile( - exprs, options={'mpi': False}, sregistry=self.sregistry, - concretize_mapper=self.concretize_mapper - ) - body = self._create_initial_guess_body( - List(body=irs.uiet.body) - ) - cb = self._make_petsc_callable( - 'FormInitialGuess', body, parameters=(sobjs['callbackdm'], objs['xloc']) - ) - self._initial_guesses.append(cb) - self._efuncs[cb.name] = cb - - def _create_initial_guess_body(self, body): - linsolve_expr = self.inject_solve.expr.rhs - objs = self.objs - sobjs = self.solver_objs - target = self.target - - dmda = sobjs['callbackdm'] - ctx = objs['dummyctx'] - - x_arr = self.field_data.arrays[target]['x'] - - vec_get_array = petsc_call( - 'VecGetArray', [objs['xloc'], Byref(x_arr._C_symbol)] - ) - - dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] - ) - - body = self.time_dependence.uxreplace_time(body) - - fields = self._dummy_fields(body) - self._struct_params.extend(fields) - - dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] - ) - - vec_restore_array = petsc_call( - 'VecRestoreArray', [objs['xloc'], Byref(x_arr._C_symbol)] - ) - - body = body._rebuild(body=body.body + (vec_restore_array,)) - - stacks = ( - vec_get_array, - dm_get_app_context, - dm_get_local_info - ) - - # Dereference function data in struct - derefs = dereference_funcs(ctx, fields) - body = self._make_callable_body(body, stacks=stacks+derefs) - - # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for - i in fields if not isinstance(i.function, AbstractFunction)} - - return Uxreplace(subs).visit(body) - - def _make_user_struct_callback(self): - """ - This is the struct initialised inside the main kernel and - attached to the DM via DMSetApplicationContext. - # TODO: this could be common between all PETScSolves instead? - """ - mainctx = self.solver_objs['userctx'] = petsc_struct( - self.sregistry.make_name(prefix='ctx'), - self.filtered_struct_params, - self.sregistry.make_name(prefix='UserCtx'), - ) - body = [ - DummyExpr(FieldFromPointer(i._C_symbol, mainctx), i._C_symbol) - for i in mainctx.callback_fields - ] - struct_callback_body = self._make_callable_body(body) - cb = Callable( - self.sregistry.make_name(prefix='PopulateUserContext'), - struct_callback_body, self.objs['err'], - parameters=[mainctx] - ) - self._efuncs[cb.name] = cb - self._user_struct_callback = cb - - def _dummy_fields(self, iet): - # Place all context data required by the shell routines into a struct - fields = [f.function for f in FindSymbols('basics').visit(iet)] - fields = [f for f in fields if not isinstance(f.function, (PETScArray, Temp))] - fields = [ - f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo)) - ] - return fields - - def _uxreplace_efuncs(self): - sobjs = self.solver_objs - luserctx = petsc_struct( - sobjs['userctx'].name, - self.filtered_struct_params, - sobjs['userctx'].pname, - modifier=' *' - ) - mapper = {} - visitor = Uxreplace({self.objs['dummyctx']: luserctx}) - for k, v in self._efuncs.items(): - mapper.update({k: visitor.visit(v)}) - return mapper - - -class CCBBuilder(CBBuilder): - def __init__(self, **kwargs): - self._submatrices_callback = None - super().__init__(**kwargs) - - @property - def submatrices_callback(self): - return self._submatrices_callback - - @property - def jacobian(self): - return self.inject_solve.expr.rhs.field_data.jacobian - - @property - def main_matvec_callback(self): - """ - This is the matrix-vector callback associated with the whole Jacobian i.e - is set in the main kernel via - `PetscCall(MatShellSetOperation(J,MATOP_MULT,(void (*)(void))MyMatShellMult));` - """ - return self._main_matvec_callback - - def _make_core(self): - for sm in self.field_data.jacobian.nonzero_submatrices: - self._make_matvec(sm, prefix=f'{sm.name}_MatMult') - - self._make_options_callback() - self._make_whole_matvec() - self._make_whole_formfunc() - self._make_user_struct_callback() - self._create_submatrices() - self._efuncs['PopulateMatContext'] = self.objs['dummyefunc'] - - def _make_whole_matvec(self): - objs = self.objs - body = self._whole_matvec_body() - - parameters = (objs['J'], objs['X'], objs['Y']) - cb = self._make_petsc_callable( - 'WholeMatMult', List(body=body), parameters=parameters - ) - self._main_matvec_callback = cb - self._efuncs[cb.name] = cb - - def _whole_matvec_body(self): - objs = self.objs - sobjs = self.solver_objs - - jctx = objs['ljacctx'] - ctx_main = petsc_call('MatShellGetContext', [objs['J'], Byref(jctx)]) - - nonzero_submats = self.jacobian.nonzero_submatrices - - zero_y_memory = zero_vector(objs['Y']) - - calls = () - for sm in nonzero_submats: - name = sm.name - ctx = sobjs[f'{name}ctx'] - X = sobjs[f'{name}X'] - Y = sobjs[f'{name}Y'] - rows = objs['rows'].base - cols = objs['cols'].base - sm_indexed = objs['Submats'].indexed[sm.linear_idx] - - calls += ( - DummyExpr(sobjs[name], FieldFromPointer(sm_indexed, jctx)), - petsc_call('MatShellGetContext', [sobjs[name], Byref(ctx)]), - petsc_call( - 'VecGetSubVector', - [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] - ), - petsc_call( - 'VecGetSubVector', - [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] - ), - petsc_call('MatMult', [sobjs[name], X, Y]), - petsc_call( - 'VecRestoreSubVector', - [objs['X'], Deref(FieldFromPointer(cols, ctx)), Byref(X)] - ), - petsc_call( - 'VecRestoreSubVector', - [objs['Y'], Deref(FieldFromPointer(rows, ctx)), Byref(Y)] - ), - ) - body = (ctx_main, zero_y_memory, BlankLine) + calls - return self._make_callable_body(body) - - def _make_whole_formfunc(self): - objs = self.objs - F_exprs = self.field_data.residual.F_exprs - # Compile `F_exprs` into an IET via recursive compilation - irs, _ = self.rcompile( - F_exprs, options={'mpi': False}, sregistry=self.sregistry, - concretize_mapper=self.concretize_mapper - ) - body = self._whole_formfunc_body(List(body=irs.uiet.body)) - - parameters = (objs['snes'], objs['X'], objs['F'], objs['dummyptr']) - cb = self._make_petsc_callable( - 'WholeFormFunc', body, parameters=parameters - ) - - self._F_efunc = cb - self._efuncs[cb.name] = cb - - def _whole_formfunc_body(self, body): - linsolve_expr = self.inject_solve.expr.rhs - objs = self.objs - sobjs = self.solver_objs - - dmda = sobjs['callbackdm'] - ctx = objs['dummyctx'] - - body = self.time_dependence.uxreplace_time(body) - - fields = self._dummy_fields(body) - self._struct_params.extend(fields) - - # Process body with bundles for residual callback - bundles = sobjs['bundles'] - fbundle = bundles['f'] - xbundle = bundles['x'] - - body = residual_bundle(body, bundles) - - dm_cast = DummyExpr(dmda, DMCast(objs['dummyptr']), init=True) - - dm_get_app_context = petsc_call( - 'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)] - ) - - zero_f_memory = zero_vector(objs['F']) - - dm_get_local_xvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(objs['xloc'])] - ) - - global_to_local_begin = petsc_call( - 'DMGlobalToLocalBegin', [dmda, objs['X'], - insert_vals, objs['xloc']] - ) - - global_to_local_end = petsc_call('DMGlobalToLocalEnd', [ - dmda, objs['X'], insert_vals, objs['xloc'] - ]) - - dm_get_local_yvec = petsc_call( - 'DMGetLocalVector', [dmda, Byref(objs['floc'])] - ) - - vec_get_array_f = petsc_call( - 'VecGetArray', [objs['floc'], Byref(fbundle.vector._C_symbol)] - ) - - vec_get_array_x = petsc_call( - 'VecGetArray', [objs['xloc'], Byref(xbundle.vector._C_symbol)] - ) - - dm_get_local_info = petsc_call( - 'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)] - ) - - vec_restore_array_f = petsc_call( - 'VecRestoreArray', [objs['floc'], Byref(fbundle.vector._C_symbol)] - ) - - vec_restore_array_x = petsc_call( - 'VecRestoreArray', [objs['xloc'], Byref(xbundle.vector._C_symbol)] - ) - - dm_local_to_global_begin = petsc_call('DMLocalToGlobalBegin', [ - dmda, objs['floc'], add_vals, objs['F'] - ]) - - dm_local_to_global_end = petsc_call('DMLocalToGlobalEnd', [ - dmda, objs['floc'], add_vals, objs['F'] - ]) - - dm_restore_local_xvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(objs['xloc'])] - ) - - dm_restore_local_yvec = petsc_call( - 'DMRestoreLocalVector', [dmda, Byref(objs['floc'])] - ) - - body = body._rebuild( - body=body.body + - (vec_restore_array_f, - vec_restore_array_x, - dm_local_to_global_begin, - dm_local_to_global_end, - dm_restore_local_xvec, - dm_restore_local_yvec) - ) - - stacks = ( - dm_cast, - dm_get_app_context, - zero_f_memory, - dm_get_local_xvec, - global_to_local_begin, - global_to_local_end, - dm_get_local_yvec, - vec_get_array_f, - vec_get_array_x, - dm_get_local_info - ) - - # Dereference function data in struct - derefs = dereference_funcs(ctx, fields) - - f_soa = PointerCast(fbundle) - x_soa = PointerCast(xbundle) - - formfunc_body = self._make_callable_body( - body, stacks=stacks+derefs, - casts=(f_soa, x_soa), - ) - # Replace non-function data with pointer to data in struct - subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for i in fields} - - return Uxreplace(subs).visit(formfunc_body) - - def _create_submatrices(self): - body = self._submat_callback_body() - objs = self.objs - params = ( - objs['J'], - objs['nfields'], - objs['irow'], - objs['icol'], - objs['matreuse'], - objs['Submats'], - ) - cb = self._make_petsc_callable( - 'MatCreateSubMatrices', body, parameters=params) - - self._submatrices_callback = cb - self._efuncs[cb.name] = cb - - def _submat_callback_body(self): - objs = self.objs - sobjs = self.solver_objs - - n_submats = DummyExpr( - objs['nsubmats'], Mul(objs['nfields'], objs['nfields']) - ) - - malloc_submats = petsc_call('PetscCalloc1', [objs['nsubmats'], objs['Submats']]) - - mat_get_dm = petsc_call('MatGetDM', [objs['J'], Byref(sobjs['callbackdm'])]) - - dm_get_app = petsc_call( - 'DMGetApplicationContext', [sobjs['callbackdm'], Byref(objs['dummyctx'])] - ) - - get_ctx = petsc_call('MatShellGetContext', [objs['J'], Byref(objs['ljacctx'])]) - - dm_get_info = petsc_call( - 'DMDAGetInfo', [ - sobjs['callbackdm'], Null, Byref(sobjs['M']), Byref(sobjs['N']), - Null, Null, Null, Null, Byref(objs['dof']), Null, Null, Null, Null, Null - ] - ) - subblock_rows = DummyExpr(objs['subblockrows'], Mul(sobjs['M'], sobjs['N'])) - subblock_cols = DummyExpr(objs['subblockcols'], Mul(sobjs['M'], sobjs['N'])) - - ptr = DummyExpr(objs['submat_arr']._C_symbol, Deref(objs['Submats']), init=True) - - mat_create = petsc_call('MatCreate', [sobjs['comm'], Byref(objs['block'])]) - - mat_set_sizes = petsc_call( - 'MatSetSizes', [ - objs['block'], 'PETSC_DECIDE', 'PETSC_DECIDE', - objs['subblockrows'], objs['subblockcols'] - ] - ) - - mat_set_type = petsc_call('MatSetType', [objs['block'], 'MATSHELL']) - - malloc = petsc_call('PetscMalloc1', [1, Byref(objs['subctx'])]) - i = Dimension(name='i') - - row_idx = DummyExpr(objs['rowidx'], IntDiv(i, objs['dof'])) - col_idx = DummyExpr(objs['colidx'], Mod(i, objs['dof'])) - - deref_subdm = Dereference(objs['Subdms'], objs['ljacctx']) - - set_rows = DummyExpr( - FieldFromPointer(objs['rows'].base, objs['subctx']), - Byref(objs['irow'].indexed[objs['rowidx']]) - ) - set_cols = DummyExpr( - FieldFromPointer(objs['cols'].base, objs['subctx']), - Byref(objs['icol'].indexed[objs['colidx']]) - ) - dm_set_ctx = petsc_call( - 'DMSetApplicationContext', [ - objs['Subdms'].indexed[objs['rowidx']], objs['dummyctx'] - ] - ) - matset_dm = petsc_call('MatSetDM', [ - objs['block'], objs['Subdms'].indexed[objs['rowidx']] - ]) - - set_ctx = petsc_call('MatShellSetContext', [objs['block'], objs['subctx']]) - - mat_setup = petsc_call('MatSetUp', [objs['block']]) - - assign_block = DummyExpr(objs['submat_arr'].indexed[i], objs['block']) - - iter_body = ( - mat_create, - mat_set_sizes, - mat_set_type, - malloc, - row_idx, - col_idx, - set_rows, - set_cols, - dm_set_ctx, - matset_dm, - set_ctx, - mat_setup, - assign_block - ) - - upper_bound = objs['nsubmats'] - 1 - iteration = Iteration(List(body=iter_body), i, upper_bound) - - nonzero_submats = self.jacobian.nonzero_submatrices - matvec_lookup = {mv.name.split('_')[0]: mv for mv in self.J_efuncs} - - matmult_op = [ - petsc_call( - 'MatShellSetOperation', - [ - objs['submat_arr'].indexed[sb.linear_idx], - 'MATOP_MULT', - MatShellSetOp(matvec_lookup[sb.name].name, void, void), - ], - ) - for sb in nonzero_submats if sb.name in matvec_lookup - ] - - body = [ - n_submats, - malloc_submats, - mat_get_dm, - dm_get_app, - dm_get_info, - subblock_rows, - subblock_cols, - ptr, - BlankLine, - iteration, - ] + matmult_op - - return self._make_callable_body(tuple(body), stacks=(get_ctx, deref_subdm)) - - -class BaseObjectBuilder: - """ - A base class for constructing objects needed for a PETSc solver. - Designed to be extended by subclasses, which can override the `_extend_build` - method to support specific use cases. - """ - def __init__(self, **kwargs): - self.inject_solve = kwargs.get('inject_solve') - self.objs = kwargs.get('objs') - self.sregistry = kwargs.get('sregistry') - self.comm = kwargs.get('comm') - self.field_data = self.inject_solve.expr.rhs.field_data - self.solver_objs = self._build() - - def _build(self): - """ - # TODO: update docs - Constructs the core dictionary of solver objects and allows - subclasses to extend or modify it via `_extend_build`. - Returns: - dict: A dictionary containing the following objects: - - 'Jac' (Mat): A matrix representing the jacobian. - - 'xglobal' (GlobalVec): The global solution vector. - - 'xlocal' (LocalVec): The local solution vector. - - 'bglobal': (GlobalVec) Global RHS vector `b`, where `F(x) = b`. - - 'blocal': (LocalVec) Local RHS vector `b`, where `F(x) = b`. - - 'ksp': (KSP) Krylov solver object that manages the linear solver. - - 'pc': (PC) Preconditioner object. - - 'snes': (SNES) Nonlinear solver object. - - 'localsize' (PetscInt): The local length of the solution vector. - - 'dmda' (DM): The DMDA object associated with this solve, linked to - the SNES object via `SNESSetDM`. - - 'callbackdm' (CallbackDM): The DM object accessed within callback - functions via `SNESGetDM`. - """ - sreg = self.sregistry - targets = self.field_data.targets - - snes_name = sreg.make_name(prefix='snes') - formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix - - base_dict = { - 'Jac': Mat(sreg.make_name(prefix='J')), - 'xglobal': Vec(sreg.make_name(prefix='xglobal')), - 'xlocal': Vec(sreg.make_name(prefix='xlocal')), - 'bglobal': Vec(sreg.make_name(prefix='bglobal')), - 'blocal': CallbackVec(sreg.make_name(prefix='blocal')), - 'ksp': KSP(sreg.make_name(prefix='ksp')), - 'pc': PC(sreg.make_name(prefix='pc')), - 'snes': SNES(snes_name), - 'localsize': PetscInt(sreg.make_name(prefix='localsize')), - 'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)), - 'callbackdm': CallbackDM(sreg.make_name(prefix='dm')), - 'snes_prefix': String(formatted_prefix), - } - - base_dict['comm'] = self.comm - self._target_dependent(base_dict) - return self._extend_build(base_dict) - - def _target_dependent(self, base_dict): - """ - '_ptr' (StartPtr): A pointer to the beginning of the solution array - that will be updated at each time step. - """ - sreg = self.sregistry - target = self.field_data.target - base_dict[f'{target.name}_ptr'] = StartPtr( - sreg.make_name(prefix=f'{target.name}_ptr'), target.dtype - ) - - def _extend_build(self, base_dict): - """ - Subclasses can override this method to extend or modify the - base dictionary of solver objects. - """ - return base_dict - - -class CoupledObjectBuilder(BaseObjectBuilder): - def _extend_build(self, base_dict): - sreg = self.sregistry - objs = self.objs - targets = self.field_data.targets - arrays = self.field_data.arrays - - base_dict['fields'] = PointerIS( - name=sreg.make_name(prefix='fields'), nindices=len(targets) - ) - base_dict['subdms'] = PointerDM( - name=sreg.make_name(prefix='subdms'), nindices=len(targets) - ) - base_dict['nfields'] = PetscInt(sreg.make_name(prefix='nfields')) - - space_dims = len(self.field_data.grid.dimensions) - - dim_labels = ["M", "N", "P"] - base_dict.update({ - dim_labels[i]: PetscInt(dim_labels[i]) for i in range(space_dims) - }) - - submatrices = self.field_data.jacobian.nonzero_submatrices - - base_dict['jacctx'] = JacobianStruct( - name=sreg.make_name(prefix=objs['ljacctx'].name), - fields=objs['ljacctx'].fields, - ) - - for sm in submatrices: - name = sm.name - base_dict[name] = Mat(name=name) - base_dict[f'{name}ctx'] = SubMatrixStruct( - name=f'{name}ctx', - fields=objs['subctx'].fields, - ) - base_dict[f'{name}X'] = CallbackVec(f'{name}X') - base_dict[f'{name}Y'] = CallbackVec(f'{name}Y') - base_dict[f'{name}F'] = CallbackVec(f'{name}F') - - # Bundle objects/metadata required by the coupled residual callback - f_components, x_components = [], [] - bundle_mapper = {} - pname = sreg.make_name(prefix='Field') - - target_indices = {t: i for i, t in enumerate(targets)} - - for t in targets: - f_arr = arrays[t]['f'] - x_arr = arrays[t]['x'] - f_components.append(f_arr) - x_components.append(x_arr) - - fbundle = PetscBundle( - name='f_bundle', components=f_components, pname=pname - ) - xbundle = PetscBundle( - name='x_bundle', components=x_components, pname=pname - ) - - # Build the bundle mapper - for f_arr, x_arr in zip(f_components, x_components): - bundle_mapper[f_arr.base] = fbundle - bundle_mapper[x_arr.base] = xbundle - - base_dict['bundles'] = { - 'f': fbundle, - 'x': xbundle, - 'bundle_mapper': bundle_mapper, - 'target_indices': target_indices - } - - return base_dict - - def _target_dependent(self, base_dict): - sreg = self.sregistry - targets = self.field_data.targets - for t in targets: - name = t.name - base_dict[f'{name}_ptr'] = StartPtr( - sreg.make_name(prefix=f'{name}_ptr'), t.dtype - ) - base_dict[f'xlocal{name}'] = CallbackVec( - sreg.make_name(prefix=f'xlocal{name}'), liveness='eager' - ) - base_dict[f'Fglobal{name}'] = CallbackVec( - sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager' - ) - base_dict[f'Xglobal{name}'] = CallbackVec( - sreg.make_name(prefix=f'Xglobal{name}') - ) - base_dict[f'xglobal{name}'] = Vec( - sreg.make_name(prefix=f'xglobal{name}') - ) - base_dict[f'blocal{name}'] = CallbackVec( - sreg.make_name(prefix=f'blocal{name}'), liveness='eager' - ) - base_dict[f'bglobal{name}'] = Vec( - sreg.make_name(prefix=f'bglobal{name}') - ) - base_dict[f'da{name}'] = DM( - sreg.make_name(prefix=f'da{name}'), liveness='eager' - ) - base_dict[f'scatter{name}'] = VecScatter( - sreg.make_name(prefix=f'scatter{name}') - ) - - -class BaseSetup: - def __init__(self, **kwargs): - self.inject_solve = kwargs.get('inject_solve') - self.objs = kwargs.get('objs') - self.solver_objs = kwargs.get('solver_objs') - self.cbbuilder = kwargs.get('cbbuilder') - self.field_data = self.inject_solve.expr.rhs.field_data - self.formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix - self.calls = self._setup() - - @property - def snes_ctx(self): - """ - The [optional] context for private data for the function evaluation routine. - https://petsc.org/main/manualpages/SNES/SNESSetFunction/ - """ - return VOID(self.solver_objs['dmda'], stars='*') - - def _setup(self): - sobjs = self.solver_objs - dmda = sobjs['dmda'] - - snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])]) - - snes_options_prefix = petsc_call( - 'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']] - ) if self.formatted_prefix else None - - set_options = petsc_call( - self.cbbuilder._set_options_efunc.name, [] - ) - - snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda]) - - create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])]) - - snes_set_jac = petsc_call( - 'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'], - sobjs['Jac'], 'MatMFFDComputeJacobian', Null] - ) - - global_x = petsc_call('DMCreateGlobalVector', - [dmda, Byref(sobjs['xglobal'])]) - - target = self.field_data.target - field_from_ptr = FieldFromPointer( - target.function._C_field_data, target.function._C_symbol - ) - - local_size = math.prod( - v for v, dim in zip(target.shape_allocated, target.dimensions) if dim.is_Space - ) - # TODO: Check - VecCreateSeqWithArray - local_x = petsc_call('VecCreateMPIWithArray', - [sobjs['comm'], 1, local_size, 'PETSC_DECIDE', - field_from_ptr, Byref(sobjs['xlocal'])]) - - # TODO: potentially also need to set the DM and local/global map to xlocal - - get_local_size = petsc_call('VecGetSize', - [sobjs['xlocal'], Byref(sobjs['localsize'])]) - - global_b = petsc_call('DMCreateGlobalVector', - [dmda, Byref(sobjs['bglobal'])]) - - snes_get_ksp = petsc_call('SNESGetKSP', - [sobjs['snes'], Byref(sobjs['ksp'])]) - - matvec = self.cbbuilder.main_matvec_callback - matvec_operation = petsc_call( - 'MatShellSetOperation', - [sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)] - ) - formfunc = self.cbbuilder._F_efunc - formfunc_operation = petsc_call( - 'SNESSetFunction', - [sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void), - self.snes_ctx] - ) - - snes_set_options = petsc_call( - 'SNESSetFromOptions', [sobjs['snes']] - ) - - dmda_calls = self._create_dmda_calls(dmda) - - mainctx = sobjs['userctx'] - - call_struct_callback = petsc_call( - self.cbbuilder.user_struct_callback.name, [Byref(mainctx)] - ) - - # TODO: maybe don't need to explictly set this - mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda]) - - calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) - - base_setup = dmda_calls + ( - snes_create, - snes_options_prefix, - set_options, - snes_set_dm, - create_matrix, - snes_set_jac, - global_x, - local_x, - get_local_size, - global_b, - snes_get_ksp, - matvec_operation, - formfunc_operation, - snes_set_options, - call_struct_callback, - mat_set_dm, - calls_set_app_ctx, - BlankLine - ) - extended_setup = self._extend_setup() - return base_setup + extended_setup - - def _extend_setup(self): - """ - Hook for subclasses to add additional setup calls. - """ - return () - - def _create_dmda_calls(self, dmda): - dmda_create = self._create_dmda(dmda) - dm_setup = petsc_call('DMSetUp', [dmda]) - dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL']) - return dmda_create, dm_setup, dm_mat_type - - def _create_dmda(self, dmda): - sobjs = self.solver_objs - grid = self.field_data.grid - nspace_dims = len(grid.dimensions) - - # MPI communicator - args = [sobjs['comm']] - - # Type of ghost nodes - args.extend(['DM_BOUNDARY_GHOSTED' for _ in range(nspace_dims)]) - - # Stencil type - if nspace_dims > 1: - args.append('DMDA_STENCIL_BOX') - - # Global dimensions - args.extend(list(grid.shape)[::-1]) - # No.of processors in each dimension - if nspace_dims > 1: - args.extend(list(grid.distributor.topology)[::-1]) - - # Number of degrees of freedom per node - args.append(dmda.dofs) - # "Stencil width" -> size of overlap - # TODO: Instead, this probably should be - # extracted from field_data.target._size_outhalo? - stencil_width = self.field_data.space_order - - args.append(stencil_width) - args.extend([Null]*nspace_dims) - - # The distributed array object - args.append(Byref(dmda)) - - # The PETSc call used to create the DMDA - dmda = petsc_call(f'DMDACreate{nspace_dims}d', args) - - return dmda - - -class CoupledSetup(BaseSetup): - def _setup(self): - # TODO: minimise code duplication with superclass - objs = self.objs - sobjs = self.solver_objs - dmda = sobjs['dmda'] - - snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])]) - - snes_options_prefix = petsc_call( - 'SNESSetOptionsPrefix', [sobjs['snes'], sobjs['snes_prefix']] - ) if self.formatted_prefix else None - - set_options = petsc_call( - self.cbbuilder._set_options_efunc.name, [] - ) - - snes_set_dm = petsc_call('SNESSetDM', [sobjs['snes'], dmda]) - - create_matrix = petsc_call('DMCreateMatrix', [dmda, Byref(sobjs['Jac'])]) - - snes_set_jac = petsc_call( - 'SNESSetJacobian', [sobjs['snes'], sobjs['Jac'], - sobjs['Jac'], 'MatMFFDComputeJacobian', Null] - ) - - global_x = petsc_call('DMCreateGlobalVector', - [dmda, Byref(sobjs['xglobal'])]) - - local_x = petsc_call('DMCreateLocalVector', [dmda, Byref(sobjs['xlocal'])]) - - get_local_size = petsc_call('VecGetSize', - [sobjs['xlocal'], Byref(sobjs['localsize'])]) - - snes_get_ksp = petsc_call('SNESGetKSP', - [sobjs['snes'], Byref(sobjs['ksp'])]) - - matvec = self.cbbuilder.main_matvec_callback - matvec_operation = petsc_call( - 'MatShellSetOperation', - [sobjs['Jac'], 'MATOP_MULT', MatShellSetOp(matvec.name, void, void)] - ) - formfunc = self.cbbuilder._F_efunc - formfunc_operation = petsc_call( - 'SNESSetFunction', - [sobjs['snes'], Null, FormFunctionCallback(formfunc.name, void, void), - self.snes_ctx] - ) - - snes_set_options = petsc_call( - 'SNESSetFromOptions', [sobjs['snes']] - ) - - dmda_calls = self._create_dmda_calls(dmda) - - mainctx = sobjs['userctx'] - - call_struct_callback = petsc_call( - self.cbbuilder.user_struct_callback.name, [Byref(mainctx)] - ) - - # TODO: maybe don't need to explictly set this - mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda]) - - calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)]) - - create_field_decomp = petsc_call( - 'DMCreateFieldDecomposition', - [dmda, Byref(sobjs['nfields']), Null, Byref(sobjs['fields']), - Byref(sobjs['subdms'])] - ) - submat_cb = self.cbbuilder.submatrices_callback - matop_create_submats_op = petsc_call( - 'MatShellSetOperation', - [sobjs['Jac'], 'MATOP_CREATE_SUBMATRICES', - MatShellSetOp(submat_cb.name, void, void)] - ) - - call_coupled_struct_callback = petsc_call( - 'PopulateMatContext', - [Byref(sobjs['jacctx']), sobjs['subdms'], sobjs['fields']] - ) - - shell_set_ctx = petsc_call( - 'MatShellSetContext', [sobjs['Jac'], Byref(sobjs['jacctx']._C_symbol)] - ) - - create_submats = petsc_call( - 'MatCreateSubMatrices', - [sobjs['Jac'], sobjs['nfields'], sobjs['fields'], - sobjs['fields'], 'MAT_INITIAL_MATRIX', - Byref(FieldFromComposite(objs['Submats'].base, sobjs['jacctx']))] - ) - - targets = self.field_data.targets - - deref_dms = [ - DummyExpr(sobjs[f'da{t.name}'], sobjs['subdms'].indexed[i]) - for i, t in enumerate(targets) - ] - - xglobals = [petsc_call( - 'DMCreateGlobalVector', - [sobjs[f'da{t.name}'], Byref(sobjs[f'xglobal{t.name}'])] - ) for t in targets] - - coupled_setup = dmda_calls + ( - snes_create, - snes_options_prefix, - set_options, - snes_set_dm, - create_matrix, - snes_set_jac, - global_x, - local_x, - get_local_size, - snes_get_ksp, - matvec_operation, - formfunc_operation, - snes_set_options, - call_struct_callback, - mat_set_dm, - calls_set_app_ctx, - create_field_decomp, - matop_create_submats_op, - call_coupled_struct_callback, - shell_set_ctx, - create_submats) + \ - tuple(deref_dms) + tuple(xglobals) - return coupled_setup - - -class Solver: - def __init__(self, **kwargs): - self.inject_solve = kwargs.get('inject_solve') - self.objs = kwargs.get('objs') - self.solver_objs = kwargs.get('solver_objs') - self.iters = kwargs.get('iters') - self.cbbuilder = kwargs.get('cbbuilder') - self.time_dependence = kwargs.get('time_dependence') - self.calls = self._execute_solve() - - def _execute_solve(self): - """ - Assigns the required time iterators to the struct and executes - the necessary calls to execute the SNES solver. - """ - sobjs = self.solver_objs - target = self.inject_solve.expr.rhs.field_data.target - - struct_assignment = self.time_dependence.assign_time_iters(sobjs['userctx']) - - b_efunc = self.cbbuilder._b_efunc - - dmda = sobjs['dmda'] - - rhs_call = petsc_call(b_efunc.name, [sobjs['dmda'], sobjs['bglobal']]) - - vec_place_array = self.time_dependence.place_array(target) - - if self.cbbuilder.initial_guesses: - initguess = self.cbbuilder.initial_guesses[0] - initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']]) - else: - initguess_call = None - - dm_local_to_global_x = petsc_call( - 'DMLocalToGlobal', [dmda, sobjs['xlocal'], insert_vals, - sobjs['xglobal']] - ) - - snes_solve = petsc_call('SNESSolve', [ - sobjs['snes'], sobjs['bglobal'], sobjs['xglobal']] - ) - - dm_global_to_local_x = petsc_call('DMGlobalToLocal', [ - dmda, sobjs['xglobal'], insert_vals, sobjs['xlocal']] - ) - - vec_reset_array = self.time_dependence.reset_array(target) - - run_solver_calls = (struct_assignment,) + ( - rhs_call, - ) + vec_place_array + ( - initguess_call, - dm_local_to_global_x, - snes_solve, - dm_global_to_local_x, - vec_reset_array, - BlankLine, - ) - return run_solver_calls - - @cached_property - def spatial_body(self): - spatial_body = [] - # TODO: remove the iters[0] - for tree in retrieve_iteration_tree(self.iters[0]): - root = filter_iterations(tree, key=lambda i: i.dim.is_Space) - if root: - root = root[0] - if self.inject_solve in FindNodes(PetscMetaData).visit(root): - spatial_body.append(root) - spatial_body, = spatial_body - return spatial_body - - -class CoupledSolver(Solver): - def _execute_solve(self): - """ - Assigns the required time iterators to the struct and executes - the necessary calls to execute the SNES solver. - """ - sobjs = self.solver_objs - xglob = sobjs['xglobal'] - - struct_assignment = self.time_dependence.assign_time_iters(sobjs['userctx']) - targets = self.inject_solve.expr.rhs.field_data.targets - - # TODO: optimise the ccode generated here - pre_solve = () - post_solve = () - - for i, t in enumerate(targets): - name = t.name - dm = sobjs[f'da{name}'] - target_xloc = sobjs[f'xlocal{name}'] - target_xglob = sobjs[f'xglobal{name}'] - field = sobjs['fields'].indexed[i] - s = sobjs[f'scatter{name}'] - - pre_solve += ( - # TODO: Switch to createwitharray and move to setup - petsc_call('DMCreateLocalVector', [dm, Byref(target_xloc)]), - - # TODO: Need to call reset array - self.time_dependence.place_array(t), - petsc_call( - 'DMLocalToGlobal', - [dm, target_xloc, insert_vals, target_xglob] - ), - petsc_call( - 'VecScatterCreate', - [xglob, field, target_xglob, Null, Byref(s)] - ), - petsc_call( - 'VecScatterBegin', - [s, target_xglob, xglob, insert_vals, sreverse] - ), - petsc_call( - 'VecScatterEnd', - [s, target_xglob, xglob, insert_vals, sreverse] - ), - BlankLine, - ) - - post_solve += ( - petsc_call( - 'VecScatterBegin', - [s, xglob, target_xglob, insert_vals, sforward] - ), - petsc_call( - 'VecScatterEnd', - [s, xglob, target_xglob, insert_vals, sforward] - ), - petsc_call( - 'DMGlobalToLocal', - [dm, target_xglob, insert_vals, target_xloc] - ) - ) - - snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], Null, xglob]),) - - return (struct_assignment,) + pre_solve + snes_solve + post_solve + (BlankLine,) - - -class NonTimeDependent: - def __init__(self, **kwargs): - self.inject_solve = kwargs.get('inject_solve') - self.iters = kwargs.get('iters') - self.sobjs = kwargs.get('solver_objs') - self.kwargs = kwargs - self.origin_to_moddim = self._origin_to_moddim_mapper(self.iters) - self.time_idx_to_symb = self.inject_solve.expr.rhs.time_mapper - - def _origin_to_moddim_mapper(self, iters): - return {} - - def uxreplace_time(self, body): - return body - - def place_array(self, target): - sobjs = self.sobjs - - field_from_ptr = FieldFromPointer( - target.function._C_field_data, target.function._C_symbol - ) - xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) - return (petsc_call('VecPlaceArray', [xlocal, field_from_ptr]),) - - def reset_array(self, target): - """ - """ - sobjs = self.sobjs - xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) - return ( - petsc_call('VecResetArray', [xlocal]) - ) - - def assign_time_iters(self, struct): - return [] - - -class TimeDependent(NonTimeDependent): - """ - A class for managing time-dependent solvers. - - This includes scenarios where the target is not directly a `TimeFunction`, - but depends on other functions that are. - - Outline of time loop abstraction with PETSc: - - - At PETScSolve, time indices are replaced with temporary `Symbol` objects - via a mapper (e.g., {t: tau0, t + dt: tau1}) to prevent the time loop - from being generated in the callback functions. These callbacks, needed - for each `SNESSolve` at every time step, don't require the time loop, but - may still need access to data from other time steps. - - All `Function` objects are passed through the initial lowering via the - `SolverMetaData` object, ensuring the correct time loop is generated - in the main kernel. - - Another mapper is created based on the modulo dimensions - generated by the `SolverMetaData` object in the main kernel - (e.g., {time: time, t: t0, t + 1: t1}). - - These two mappers are used to generate a final mapper `symb_to_moddim` - (e.g. {tau0: t0, tau1: t1}) which is used at the IET level to - replace the temporary `Symbol` objects in the callback functions with - the correct modulo dimensions. - - Modulo dimensions are updated in the matrix context struct at each time - step and can be accessed in the callback functions where needed. - """ - @property - def time_spacing(self): - return self.inject_solve.expr.rhs.grid.stepping_dim.spacing - - @cached_property - def symb_to_moddim(self): - """ - Maps temporary `Symbol` objects created during `PETScSolve` to their - corresponding modulo dimensions (e.g. creates {tau0: t0, tau1: t1}). - """ - mapper = { - v: k.xreplace({self.time_spacing: 1, -self.time_spacing: -1}) - for k, v in self.time_idx_to_symb.items() - } - return {symb: self.origin_to_moddim[mapper[symb]] for symb in mapper} - - def is_target_time(self, target): - return any(i.is_Time for i in target.dimensions) - - def target_time(self, target): - target_time = [ - i for i, d in zip(target.indices, target.dimensions) - if d.is_Time - ] - assert len(target_time) == 1 - target_time = target_time.pop() - return target_time - - def uxreplace_time(self, body): - return Uxreplace(self.symb_to_moddim).visit(body) - - def _origin_to_moddim_mapper(self, iters): - """ - Creates a mapper of the origin of the time dimensions to their corresponding - modulo dimensions from a list of `Iteration` objects. - - Examples - -------- - >>> iters - (, - ) - >>> _origin_to_moddim_mapper(iters) - {time: time, t: t0, t + 1: t1} - """ - time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)] - mapper = {} - - if not time_iter: - return mapper - - for i in time_iter: - for d in i.dimensions: - if d.is_Modulo: - mapper[d.origin] = d - elif d.is_Time: - mapper[d] = d - return mapper - - def place_array(self, target): - """ - In the case that the actual target is time-dependent e.g a `TimeFunction`, - a pointer to the first element in the array that will be updated during - the time step is passed to VecPlaceArray(). - - Examples - -------- - >>> target - f1(time + dt, x, y) - >>> calls = place_array(target) - >>> print(List(body=calls)) - float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data); - PetscCall(VecPlaceArray(xlocal0,f1_ptr0)); - - >>> target - f1(t + dt, x, y) - >>> calls = place_array(target) - >>> print(List(body=calls)) - float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data); - PetscCall(VecPlaceArray(xlocal0,f1_ptr0)); - """ - sobjs = self.sobjs - - if self.is_target_time(target): - mapper = {self.time_spacing: 1, -self.time_spacing: -1} - - target_time = self.target_time(target).xreplace(mapper) - target_time = self.origin_to_moddim.get(target_time, target_time) - - xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) - start_ptr = sobjs[f'{target.name}_ptr'] - - caster = cast(target.dtype, '*') - return ( - DummyExpr( - start_ptr, - caster( - FieldFromPointer(target._C_field_data, target._C_symbol) - ) + Mul(target_time, sobjs['localsize']), - init=True - ), - petsc_call('VecPlaceArray', [xlocal, start_ptr]) - ) - return super().place_array(target) - - def assign_time_iters(self, struct): - """ - Assign required time iterators to the struct. - These iterators are updated at each timestep in the main kernel - for use in callback functions. - - Examples - -------- - >>> struct - ctx - >>> struct.fields - [h_x, x_M, x_m, f1(t, x), t0, t1] - >>> assigned = assign_time_iters(struct) - >>> print(assigned[0]) - ctx.t0 = t0; - >>> print(assigned[1]) - ctx.t1 = t1; - """ - to_assign = [ - f for f in struct.fields if (f.is_Dimension and (f.is_Time or f.is_Modulo)) - ] - time_iter_assignments = [ - DummyExpr(FieldFromComposite(field, struct), field) - for field in to_assign - ] - return time_iter_assignments - - -void = 'void' -insert_vals = 'INSERT_VALUES' -add_vals = 'ADD_VALUES' -sreverse = 'SCATTER_REVERSE' -sforward = 'SCATTER_FORWARD' diff --git a/devito/petsc/iet/solve.py b/devito/petsc/iet/solve.py new file mode 100644 index 0000000000..f6c1fa22d5 --- /dev/null +++ b/devito/petsc/iet/solve.py @@ -0,0 +1,156 @@ +from functools import cached_property + +from devito.ir.iet import ( + BlankLine, FindNodes, retrieve_iteration_tree, filter_iterations +) +from devito.symbolics import Byref, Null + +from devito.petsc.iet.nodes import PetscMetaData, petsc_call +from devito.petsc.types.modes import InsertMode, ScatterMode + + +class Solve: + def __init__(self, **kwargs): + self.inject_solve = kwargs.get('inject_solve') + self.objs = kwargs.get('objs') + self.solver_objs = kwargs.get('solver_objs') + self.iters = kwargs.get('iters') + self.callback_builder = kwargs.get('callback_builder') + self.time_dependence = kwargs.get('time_dependence') + self.calls = self._execute_solve() + + def _execute_solve(self): + """ + Assigns the required time iterators to the struct and executes + the necessary calls to execute the SNES solver. + """ + sobjs = self.solver_objs + target = self.inject_solve.expr.rhs.field_data.target + + struct_assignment = self.time_dependence.assign_time_iters(sobjs['userctx']) + + b_efunc = self.callback_builder._b_efunc + + dmda = sobjs['dmda'] + + rhs_call = petsc_call(b_efunc.name, [sobjs['dmda'], sobjs['bglobal']]) + + vec_place_array = self.time_dependence.place_array(target) + + if self.callback_builder.initial_guesses: + initguess = self.callback_builder.initial_guesses[0] + initguess_call = petsc_call(initguess.name, [dmda, sobjs['xlocal']]) + else: + initguess_call = None + + dm_local_to_global_x = petsc_call( + 'DMLocalToGlobal', [dmda, sobjs['xlocal'], insert_values, + sobjs['xglobal']] + ) + + snes_solve = petsc_call('SNESSolve', [ + sobjs['snes'], sobjs['bglobal'], sobjs['xglobal']] + ) + + dm_global_to_local_x = petsc_call('DMGlobalToLocal', [ + dmda, sobjs['xglobal'], insert_values, sobjs['xlocal']] + ) + + vec_reset_array = self.time_dependence.reset_array(target) + + run_solver_calls = (struct_assignment,) + ( + rhs_call, + ) + vec_place_array + ( + initguess_call, + dm_local_to_global_x, + snes_solve, + dm_global_to_local_x, + vec_reset_array, + BlankLine, + ) + return run_solver_calls + + @cached_property + def spatial_body(self): + spatial_body = [] + # TODO: remove the iters[0] + for tree in retrieve_iteration_tree(self.iters[0]): + root = filter_iterations(tree, key=lambda i: i.dim.is_Space) + if root: + root = root[0] + if self.inject_solve in FindNodes(PetscMetaData).visit(root): + spatial_body.append(root) + spatial_body, = spatial_body + return spatial_body + + +class CoupledSolve(Solve): + def _execute_solve(self): + """ + Assigns the required time iterators to the struct and executes + the necessary calls to execute the SNES solver. + """ + sobjs = self.solver_objs + xglob = sobjs['xglobal'] + + struct_assignment = self.time_dependence.assign_time_iters(sobjs['userctx']) + targets = self.inject_solve.expr.rhs.field_data.targets + + # TODO: optimise the ccode generated here + pre_solve = () + post_solve = () + + for i, t in enumerate(targets): + name = t.name + dm = sobjs[f'da{name}'] + target_xloc = sobjs[f'xlocal{name}'] + target_xglob = sobjs[f'xglobal{name}'] + field = sobjs['fields'].indexed[i] + s = sobjs[f'scatter{name}'] + + pre_solve += ( + # TODO: Need to call reset array + self.time_dependence.place_array(t), + petsc_call( + 'DMLocalToGlobal', + [dm, target_xloc, insert_values, target_xglob] + ), + petsc_call( + 'VecScatterCreate', + [xglob, field, target_xglob, Null, Byref(s)] + ), + petsc_call( + 'VecScatterBegin', + [s, target_xglob, xglob, insert_values, scatter_reverse] + ), + petsc_call( + 'VecScatterEnd', + [s, target_xglob, xglob, insert_values, scatter_reverse] + ), + BlankLine, + ) + + post_solve += ( + petsc_call( + 'VecScatterBegin', + [s, xglob, target_xglob, insert_values, scatter_forward] + ), + petsc_call( + 'VecScatterEnd', + [s, xglob, target_xglob, insert_values, scatter_forward] + ), + petsc_call( + 'DMGlobalToLocal', + [dm, target_xglob, insert_values, target_xloc] + ) + ) + + snes_solve = (petsc_call('SNESSolve', [sobjs['snes'], Null, xglob]),) + + return (struct_assignment,) + pre_solve + snes_solve + post_solve + (BlankLine,) + + +insert_values = InsertMode.insert_values +add_values = InsertMode.add_values +scatter_reverse = ScatterMode.scatter_reverse +scatter_forward = ScatterMode.scatter_forward diff --git a/devito/petsc/iet/time_dependence.py b/devito/petsc/iet/time_dependence.py new file mode 100644 index 0000000000..abc1b7d69e --- /dev/null +++ b/devito/petsc/iet/time_dependence.py @@ -0,0 +1,197 @@ +from functools import cached_property + +from devito.ir.iet import Uxreplace, DummyExpr +from devito.symbolics import FieldFromPointer, cast, FieldFromComposite +from devito.symbolics.unevaluation import Mul + +from devito.petsc.iet.nodes import petsc_call + + +class TimeBase: + def __init__(self, **kwargs): + self.inject_solve = kwargs.get('inject_solve') + self.iters = kwargs.get('iters') + self.sobjs = kwargs.get('solver_objs') + self.kwargs = kwargs + self.origin_to_moddim = self._origin_to_moddim_mapper(self.iters) + self.time_idx_to_symb = self.inject_solve.expr.rhs.time_mapper + + def _origin_to_moddim_mapper(self, iters): + return {} + + def uxreplace_time(self, body): + return body + + def place_array(self, target): + return () + + def reset_array(self, target): + return () + + def assign_time_iters(self, struct): + return [] + + +class TimeIndependent(TimeBase): + pass + + +class TimeDependent(TimeBase): + """ + A class for managing time-dependent solvers. + This includes scenarios where the target is not directly a `TimeFunction`, + but depends on other functions that are. + Outline of time loop abstraction with PETSc: + - At `petscsolve`, time indices are replaced with temporary `Symbol` objects + via a mapper (e.g., {t: tau0, t + dt: tau1}) to prevent the time loop + from being generated in the callback functions. These callbacks, needed + for each `SNESSolve` at every time step, don't require the time loop, but + may still need access to data from other time steps. + - All `Function` objects are passed through the initial lowering via the + `SolverMetaData` object, ensuring the correct time loop is generated + in the main kernel. + - Another mapper is created based on the modulo dimensions + generated by the `SolverMetaData` object in the main kernel + (e.g., {time: time, t: t0, t + 1: t1}). + - These two mappers are used to generate a final mapper `symb_to_moddim` + (e.g. {tau0: t0, tau1: t1}) which is used at the IET level to + replace the temporary `Symbol` objects in the callback functions with + the correct modulo dimensions. + - Modulo dimensions are updated in the matrix context struct at each time + step and can be accessed in the callback functions where needed. + """ + @property + def time_spacing(self): + return self.inject_solve.expr.rhs.grid.stepping_dim.spacing + + @cached_property + def symb_to_moddim(self): + """ + Maps temporary `Symbol` objects created during `petscsolve` to their + corresponding modulo dimensions (e.g. creates {tau0: t0, tau1: t1}). + """ + mapper = { + v: k.xreplace({self.time_spacing: 1, -self.time_spacing: -1}) + for k, v in self.time_idx_to_symb.items() + } + return {symb: self.origin_to_moddim[mapper[symb]] for symb in mapper} + + def is_target_time(self, target): + return any(i.is_Time for i in target.dimensions) + + def target_time(self, target): + target_time = [ + i for i, d in zip(target.indices, target.dimensions) + if d.is_Time + ] + assert len(target_time) == 1 + target_time = target_time.pop() + return target_time + + def uxreplace_time(self, body): + return Uxreplace(self.symb_to_moddim).visit(body) + + def _origin_to_moddim_mapper(self, iters): + """ + Creates a mapper of the origin of the time dimensions to their corresponding + modulo dimensions from a list of `Iteration` objects. + Examples + -------- + >>> iters + (, + ) + >>> _origin_to_moddim_mapper(iters) + {time: time, t: t0, t + 1: t1} + """ + time_iter = [i for i in iters if any(d.is_Time for d in i.dimensions)] + mapper = {} + + if not time_iter: + return mapper + + for i in time_iter: + for d in i.dimensions: + if d.is_Modulo: + mapper[d.origin] = d + elif d.is_Time: + mapper[d] = d + return mapper + + def place_array(self, target): + """ + In the case that the actual target is time-dependent e.g a `TimeFunction`, + a pointer to the first element in the array that will be updated during + the time step is passed to VecPlaceArray(). + Examples + -------- + >>> target + f1(time + dt, x, y) + >>> calls = place_array(target) + >>> print(List(body=calls)) + float * f1_ptr0 = (time + 1)*localsize0 + (float*)(f1_vec->data); + PetscCall(VecPlaceArray(xlocal0,f1_ptr0)); + >>> target + f1(t + dt, x, y) + >>> calls = place_array(target) + >>> print(List(body=calls)) + float * f1_ptr0 = t1*localsize0 + (float*)(f1_vec->data); + PetscCall(VecPlaceArray(xlocal0,f1_ptr0)); + """ + sobjs = self.sobjs + + if self.is_target_time(target): + mapper = {self.time_spacing: 1, -self.time_spacing: -1} + + target_time = self.target_time(target).xreplace(mapper) + target_time = self.origin_to_moddim.get(target_time, target_time) + + xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) + start_ptr = sobjs[f'{target.name}_ptr'] + + caster = cast(target.dtype, '*') + return ( + DummyExpr( + start_ptr, + caster( + FieldFromPointer(target._C_field_data, target._C_symbol) + ) + Mul(target_time, sobjs['localsize']), + init=True + ), + petsc_call('VecPlaceArray', [xlocal, start_ptr]) + ) + return super().place_array(target) + + def reset_array(self, target): + if self.is_target_time(target): + sobjs = self.sobjs + xlocal = sobjs.get(f'xlocal{target.name}', sobjs['xlocal']) + return ( + petsc_call('VecResetArray', [xlocal]) + ) + return super().reset_array(target) + + def assign_time_iters(self, struct): + """ + Assign required time iterators to the struct. + These iterators are updated at each timestep in the main kernel + for use in callback functions. + Examples + -------- + >>> struct + ctx + >>> struct.fields + [h_x, x_M, x_m, f1(t, x), t0, t1] + >>> assigned = assign_time_iters(struct) + >>> print(assigned[0]) + ctx.t0 = t0; + >>> print(assigned[1]) + ctx.t1 = t1; + """ + to_assign = [ + f for f in struct.fields if (f.is_Dimension and (f.is_Time or f.is_Modulo)) + ] + time_iter_assignments = [ + DummyExpr(FieldFromComposite(field, struct), field) + for field in to_assign + ] + return time_iter_assignments diff --git a/devito/petsc/iet/type_builder.py b/devito/petsc/iet/type_builder.py new file mode 100644 index 0000000000..8462ebb916 --- /dev/null +++ b/devito/petsc/iet/type_builder.py @@ -0,0 +1,252 @@ +import numpy as np + +from devito.symbolics import String +from devito.types import Symbol +from devito.tools import frozendict + +from devito.petsc.types import ( + PetscBundle, DM, Mat, CallbackVec, Vec, KSP, PC, SNES, PetscInt, StartPtr, + PointerIS, PointerDM, VecScatter, JacobianStruct, SubMatrixStruct, CallbackDM, + PetscMPIInt, PetscErrorCode, PointerMat, MatReuse, CallbackPointerDM, + CallbackPointerIS, CallbackMat, DummyArg, NofSubMats +) + + +class BaseTypeBuilder: + """ + A base class for constructing objects needed for a PETSc solver. + Designed to be extended by subclasses, which can override the `_extend_build` + method to support specific use cases. + """ + def __init__(self, **kwargs): + self.inject_solve = kwargs.get('inject_solve') + self.objs = kwargs.get('objs') + self.sregistry = kwargs.get('sregistry') + self.comm = kwargs.get('comm') + self.field_data = self.inject_solve.expr.rhs.field_data + self.solver_objs = self._build() + + def _build(self): + """ + # TODO: update docs + Constructs the core dictionary of solver objects and allows + subclasses to extend or modify it via `_extend_build`. + Returns: + dict: A dictionary containing the following objects: + - 'Jac' (Mat): A matrix representing the jacobian. + - 'xglobal' (GlobalVec): The global solution vector. + - 'xlocal' (LocalVec): The local solution vector. + - 'bglobal': (GlobalVec) Global RHS vector `b`, where `F(x) = b`. + - 'blocal': (LocalVec) Local RHS vector `b`, where `F(x) = b`. + - 'ksp': (KSP) Krylov solver object that manages the linear solver. + - 'pc': (PC) Preconditioner object. + - 'snes': (SNES) Nonlinear solver object. + - 'localsize' (PetscInt): The local length of the solution vector. + - 'dmda' (DM): The DMDA object associated with this solve, linked to + the SNES object via `SNESSetDM`. + - 'callbackdm' (CallbackDM): The DM object accessed within callback + functions via `SNESGetDM`. + """ + sreg = self.sregistry + targets = self.field_data.targets + + snes_name = sreg.make_name(prefix='snes') + formatted_prefix = self.inject_solve.expr.rhs.formatted_prefix + + base_dict = { + 'Jac': Mat(sreg.make_name(prefix='J')), + 'xglobal': Vec(sreg.make_name(prefix='xglobal')), + 'xlocal': Vec(sreg.make_name(prefix='xlocal')), + 'bglobal': Vec(sreg.make_name(prefix='bglobal')), + 'blocal': CallbackVec(sreg.make_name(prefix='blocal')), + 'ksp': KSP(sreg.make_name(prefix='ksp')), + 'pc': PC(sreg.make_name(prefix='pc')), + 'snes': SNES(snes_name), + 'localsize': PetscInt(sreg.make_name(prefix='localsize')), + 'dmda': DM(sreg.make_name(prefix='da'), dofs=len(targets)), + 'callbackdm': CallbackDM(sreg.make_name(prefix='dm')), + 'snes_prefix': String(formatted_prefix), + } + + base_dict['comm'] = self.comm + self._target_dependent(base_dict) + return self._extend_build(base_dict) + + def _target_dependent(self, base_dict): + """ + '_ptr' (StartPtr): A pointer to the beginning of the solution array + that will be updated at each time step. + """ + sreg = self.sregistry + target = self.field_data.target + base_dict[f'{target.name}_ptr'] = StartPtr( + sreg.make_name(prefix=f'{target.name}_ptr'), target.dtype + ) + + def _extend_build(self, base_dict): + """ + Subclasses can override this method to extend or modify the + base dictionary of solver objects. + """ + return base_dict + + +class CoupledTypeBuilder(BaseTypeBuilder): + def _extend_build(self, base_dict): + sreg = self.sregistry + objs = self.objs + targets = self.field_data.targets + arrays = self.field_data.arrays + + base_dict['fields'] = PointerIS( + name=sreg.make_name(prefix='fields'), nindices=len(targets) + ) + base_dict['subdms'] = PointerDM( + name=sreg.make_name(prefix='subdms'), nindices=len(targets) + ) + base_dict['nfields'] = PetscInt(sreg.make_name(prefix='nfields')) + + space_dims = len(self.field_data.grid.dimensions) + + dim_labels = ["M", "N", "P"] + base_dict.update({ + dim_labels[i]: PetscInt(dim_labels[i]) for i in range(space_dims) + }) + + submatrices = self.field_data.jacobian.nonzero_submatrices + + base_dict['jacctx'] = JacobianStruct( + name=sreg.make_name(prefix=objs['ljacctx'].name), + fields=objs['ljacctx'].fields, + ) + + for sm in submatrices: + name = sm.name + base_dict[name] = Mat(name=name) + base_dict[f'{name}ctx'] = SubMatrixStruct( + name=f'{name}ctx', + fields=objs['subctx'].fields, + ) + base_dict[f'{name}X'] = CallbackVec(f'{name}X') + base_dict[f'{name}Y'] = CallbackVec(f'{name}Y') + base_dict[f'{name}F'] = CallbackVec(f'{name}F') + + # Bundle objects/metadata required by the coupled residual callback + f_components, x_components = [], [] + bundle_mapper = {} + pname = sreg.make_name(prefix='Field') + + target_indices = {t: i for i, t in enumerate(targets)} + + for t in targets: + f_arr = arrays[t]['f'] + x_arr = arrays[t]['x'] + f_components.append(f_arr) + x_components.append(x_arr) + + fbundle = PetscBundle( + name='f_bundle', components=f_components, pname=pname + ) + xbundle = PetscBundle( + name='x_bundle', components=x_components, pname=pname + ) + + # Build the bundle mapper + for f_arr, x_arr in zip(f_components, x_components): + bundle_mapper[f_arr.base] = fbundle + bundle_mapper[x_arr.base] = xbundle + + base_dict['bundles'] = { + 'f': fbundle, + 'x': xbundle, + 'bundle_mapper': bundle_mapper, + 'target_indices': target_indices + } + + return base_dict + + def _target_dependent(self, base_dict): + sreg = self.sregistry + targets = self.field_data.targets + for t in targets: + name = t.name + base_dict[f'{name}_ptr'] = StartPtr( + sreg.make_name(prefix=f'{name}_ptr'), t.dtype + ) + base_dict[f'xlocal{name}'] = CallbackVec( + sreg.make_name(prefix=f'xlocal{name}'), liveness='eager' + ) + base_dict[f'Fglobal{name}'] = CallbackVec( + sreg.make_name(prefix=f'Fglobal{name}'), liveness='eager' + ) + base_dict[f'Xglobal{name}'] = CallbackVec( + sreg.make_name(prefix=f'Xglobal{name}') + ) + base_dict[f'xglobal{name}'] = Vec( + sreg.make_name(prefix=f'xglobal{name}') + ) + base_dict[f'blocal{name}'] = CallbackVec( + sreg.make_name(prefix=f'blocal{name}'), liveness='eager' + ) + base_dict[f'bglobal{name}'] = Vec( + sreg.make_name(prefix=f'bglobal{name}') + ) + base_dict[f'da{name}'] = DM( + sreg.make_name(prefix=f'da{name}'), liveness='eager' + ) + base_dict[f'scatter{name}'] = VecScatter( + sreg.make_name(prefix=f'scatter{name}') + ) + + +subdms = PointerDM(name='subdms') +fields = PointerIS(name='fields') +submats = PointerMat(name='submats') +rows = PointerIS(name='rows') +cols = PointerIS(name='cols') + + +# A static dict containing shared symbols and objects that are not +# unique to each `petscsolve` call. +# Many of these objects are used as arguments in callback functions to make +# the C code cleaner and more modular. +objs = frozendict({ + 'size': PetscMPIInt(name='size'), + 'err': PetscErrorCode(name='err'), + 'block': CallbackMat('block'), + 'submat_arr': PointerMat(name='submat_arr'), + 'subblockrows': PetscInt('subblockrows'), + 'subblockcols': PetscInt('subblockcols'), + 'rowidx': PetscInt('rowidx'), + 'colidx': PetscInt('colidx'), + 'J': Mat('J'), + 'X': Vec('X'), + 'xloc': CallbackVec('xloc'), + 'Y': Vec('Y'), + 'yloc': CallbackVec('yloc'), + 'F': Vec('F'), + 'floc': CallbackVec('floc'), + 'B': Vec('B'), + 'nfields': PetscInt('nfields'), + 'irow': PointerIS(name='irow'), + 'icol': PointerIS(name='icol'), + 'nsubmats': NofSubMats('nsubmats', dtype=np.int32), + # 'nsubmats': PetscInt('nsubmats'), + 'matreuse': MatReuse('scall'), + 'snes': SNES('snes'), + 'rows': rows, + 'cols': cols, + 'Subdms': subdms, + 'LocalSubdms': CallbackPointerDM(name='subdms'), + 'Fields': fields, + 'LocalFields': CallbackPointerIS(name='fields'), + 'Submats': submats, + 'ljacctx': JacobianStruct( + fields=[subdms, fields, submats], modifier=' *' + ), + 'subctx': SubMatrixStruct(fields=[rows, cols]), + 'dummyctx': Symbol('lctx'), + 'dummyptr': DummyArg('dummy'), + 'dummyefunc': Symbol('dummyefunc'), + 'dof': PetscInt('dof'), +}) diff --git a/devito/petsc/iet/utils.py b/devito/petsc/iet/utils.py deleted file mode 100644 index d7ccfb2b4e..0000000000 --- a/devito/petsc/iet/utils.py +++ /dev/null @@ -1,73 +0,0 @@ -from devito.ir.equations import OpPetsc -from devito.ir.iet import Dereference, FindSymbols, Uxreplace -from devito.types.basic import AbstractFunction - -from devito.petsc.iet.nodes import PetscMetaData, PETScCall - - -def petsc_call(specific_call, call_args): - return PETScCall('PetscCall', [PETScCall(specific_call, arguments=call_args)]) - - -def petsc_call_mpi(specific_call, call_args): - return PETScCall('PetscCallMPI', [PETScCall(specific_call, arguments=call_args)]) - - -def petsc_struct(name, fields, pname, liveness='lazy', modifier=None): - # TODO: Fix this circular import - from devito.petsc.types.object import PETScStruct - return PETScStruct(name=name, pname=pname, - fields=fields, liveness=liveness, - modifier=modifier) - - -def zero_vector(vec): - """ - Set all entries of a PETSc vector to zero. - """ - return petsc_call('VecSet', [vec, 0.0]) - - -def dereference_funcs(struct, fields): - """ - Dereference AbstractFunctions from a struct. - """ - return tuple( - [Dereference(i, struct) for i in - fields if isinstance(i.function, AbstractFunction)] - ) - - -def residual_bundle(body, bundles): - """ - Replaces PetscArrays in `body` with PetscBundle struct field accesses - (e.g., f_v[ix][iy] -> f_bundle[ix][iy].v). - - Example: - f_v[ix][iy] = x_v[ix][iy]; - f_u[ix][iy] = x_u[ix][iy]; - becomes: - f_bundle[ix][iy].v = x_bundle[ix][iy].v; - f_bundle[ix][iy].u = x_bundle[ix][iy].u; - - NOTE: This is used because the data is interleaved for - multi-component DMDAs in PETSc. - """ - mapper = bundles['bundle_mapper'] - indexeds = FindSymbols('indexeds').visit(body) - subs = {} - - for i in indexeds: - if i.base in mapper: - bundle = mapper[i.base] - index = bundles['target_indices'][i.function.target] - index = (index,) + i.indices - subs[i] = bundle.__getitem__(index) - - body = Uxreplace(subs).visit(body) - return body - - -# Mapping special Eq operations to their corresponding IET Expression subclass types. -# These operations correspond to subclasses of Eq utilised within PETScSolve. -petsc_iet_mapper = {OpPetsc: PetscMetaData} diff --git a/devito/petsc/initialize.py b/devito/petsc/initialize.py index a4c136f71a..f225caa4b7 100644 --- a/devito/petsc/initialize.py +++ b/devito/petsc/initialize.py @@ -6,6 +6,7 @@ from devito import Operator, switchconfig from devito.types import Symbol from devito.types.equation import PetscEq + from devito.petsc.types import Initialize, Finalize global _petsc_initialized diff --git a/devito/petsc/logging.py b/devito/petsc/logging.py index 979b4b582a..acbf4cc86f 100644 --- a/devito/petsc/logging.py +++ b/devito/petsc/logging.py @@ -7,7 +7,7 @@ from devito.petsc.types import ( PetscInt, PetscScalar, KSPType, KSPConvergedReason, KSPNormType ) -from devito.petsc.utils import petsc_type_to_ctype +from devito.petsc.config import petsc_type_to_ctype class PetscEntry: @@ -29,6 +29,7 @@ def __repr__(self): class PetscSummary(dict): """ + # TODO: Actually print to screen when DEBUG of PERF is enabled A summary of PETSc statistics collected for all solver runs associated with a single operator during execution. """ @@ -49,8 +50,8 @@ def __init__(self, params, *args, **kwargs): # Dynamically create a property on this class for each PETSc function self._add_properties() - # Initialize the summary by adding PETSc information from each PetscInfo - # object (each corresponding to an individual PETScSolve) + # Initialize the summary with PETSc information from each `PetscInfo` + # object (each corresponding to a `petscsolve` call) for i in self.petscinfos: self.add_info(i) @@ -68,8 +69,8 @@ def petsc_entry(self, petscinfo): Create a named tuple entry for the given PetscInfo object, containing the values for each PETSc function call. """ - # Collect the function names associated with this PetscInfo - # instance (i.e., for a single PETScSolve). + # Collect the function names from this `PetscInfo` + # instance (specific to its `petscsolve` call). funcs = [ petsc_return_variable_dict[f].name for f in petscinfo.query_functions ] diff --git a/devito/petsc/solve.py b/devito/petsc/solve.py index 86b6d624e4..3856392436 100644 --- a/devito/petsc/solve.py +++ b/devito/petsc/solve.py @@ -1,19 +1,23 @@ from devito.types.equation import PetscEq -from devito.tools import as_tuple -from devito.petsc.types import (LinearSolverMetaData, PETScArray, DMDALocalInfo, - FieldData, MultipleFieldData, Jacobian, Residual, - MixedResidual, MixedJacobian, InitialGuess) +from devito.tools import filter_ordered, as_tuple +from devito.types import Symbol, SteppingDimension, TimeDimension +from devito.operations.solve import eval_time_derivatives +from devito.symbolics import retrieve_functions, retrieve_dimensions + +from devito.petsc.types import ( + LinearSolverMetaData, PETScArray, DMDALocalInfo, FieldData, MultipleFieldData, + Jacobian, Residual, MixedResidual, MixedJacobian, InitialGuess +) from devito.petsc.types.equation import EssentialBC -from devito.petsc.solver_parameters import (linear_solver_parameters, - format_options_prefix) -from devito.petsc.utils import get_funcs, generate_time_mapper +from devito.petsc.solver_parameters import ( + linear_solver_parameters, format_options_prefix +) -__all__ = ['PETScSolve'] +__all__ = ['petscsolve'] -# TODO: Rename this to petsc_solve, petscsolve? -def PETScSolve(target_exprs, target=None, solver_parameters=None, +def petscsolve(target_exprs, target=None, solver_parameters=None, options_prefix=None, get_info=[]): """ Returns a symbolic expression representing a linear PETSc solver, @@ -31,13 +35,13 @@ def PETScSolve(target_exprs, target=None, solver_parameters=None, - Single-field problem: Pass a single Eq or list of Eq, and specify `target` separately: - PETScSolve(Eq1, target) - PETScSolve([Eq1, Eq2], target) + petscsolve(Eq1, target) + petscsolve([Eq1, Eq2], target) - Multi-field (mixed) problem: Pass a dictionary mapping each target field to its Eq(s): - PETScSolve({u: Eq1, v: Eq2}) - PETScSolve({u: [Eq1, Eq2], v: [Eq3, Eq4]}) + petscsolve({u: Eq1, v: Eq2}) + petscsolve({u: [Eq1, Eq2], v: [Eq3, Eq4]}) target : Function-like The function (e.g., `Function`, `TimeFunction`) into which the linear @@ -120,14 +124,14 @@ def linear_solve_args(self): exprs = as_tuple(exprs) funcs = get_funcs(exprs) - self.time_mapper = generate_time_mapper(funcs) + self.time_mapper = generate_time_mapper(exprs) arrays = self.generate_arrays(target) exprs = sorted(exprs, key=lambda e: not isinstance(e, EssentialBC)) jacobian = Jacobian(target, exprs, arrays, self.time_mapper) residual = Residual(target, exprs, arrays, self.time_mapper, jacobian.scdiag) - initial_guess = InitialGuess(target, exprs, arrays) + initial_guess = InitialGuess(target, exprs, arrays, self.time_mapper) field_data = FieldData( target=target, @@ -137,7 +141,7 @@ def linear_solve_args(self): arrays=arrays ) - return target, tuple(funcs), field_data + return target, funcs, field_data def generate_arrays(self, *targets): return { @@ -162,7 +166,7 @@ def linear_solve_args(self): exprs.extend(e) funcs = get_funcs(exprs) - self.time_mapper = generate_time_mapper(funcs) + self.time_mapper = generate_time_mapper(exprs) targets = list(self.target_exprs.keys()) arrays = self.generate_arrays(*targets) @@ -183,7 +187,47 @@ def linear_solve_args(self): residual=residual ) - return targets[0], tuple(funcs), all_data + return targets[0], funcs, all_data + + +def get_funcs(exprs): + funcs = [ + f for e in exprs + for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs)) + ] + return as_tuple(filter_ordered(funcs)) + + +def generate_time_mapper(exprs): + """ + Replace time indices with `Symbols` in expressions used within + PETSc callback functions. These symbols are Uxreplaced at the IET + level to align with the `TimeDimension` and `ModuloDimension` objects + present in the initial lowering. + NOTE: All functions used in PETSc callback functions are attached to + the `SolverMetaData` object, which is passed through the initial lowering + (and subsequently dropped and replaced with calls to run the solver). + Therefore, the appropriate time loop will always be correctly generated inside + the main kernel. + Examples + -------- + >>> exprs = (Eq(f1(t + dt, x, y), g1(t + dt, x, y) + g2(t, x, y)*f1(t, x, y)),) + >>> generate_time_mapper(exprs) + {t + dt: tau0, t: tau1} + """ + # First, map any actual TimeDimensions + time_indices = [d for d in retrieve_dimensions(exprs) if isinstance(d, TimeDimension)] + + funcs = get_funcs(exprs) + + time_indices.extend(list({ + i if isinstance(d, SteppingDimension) else d + for f in funcs + for i, d in zip(f.indices, f.dimensions) + if d.is_Time + })) + tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))] + return dict(zip(time_indices, tau_symbs)) localinfo = DMDALocalInfo(name='info', liveness='eager') diff --git a/devito/petsc/solver_parameters.py b/devito/petsc/solver_parameters.py index 7173ec9745..63ea80265b 100644 --- a/devito/petsc/solver_parameters.py +++ b/devito/petsc/solver_parameters.py @@ -1,6 +1,7 @@ -from petsctools import flatten_parameters import itertools +from petsctools import flatten_parameters + # NOTE: Will be extended, the default preconditioner is not going to be 'none' base_solve_defaults = { diff --git a/devito/petsc/types/__init__.py b/devito/petsc/types/__init__.py index f2305a8352..c40b4acf91 100644 --- a/devito/petsc/types/__init__.py +++ b/devito/petsc/types/__init__.py @@ -1,5 +1,6 @@ from .array import * # noqa -from .types import * # noqa +from .metadata import * # noqa from .object import * # noqa from .equation import * # noqa from .macros import * # noqa +from .modes import * # noqa diff --git a/devito/petsc/types/array.py b/devito/petsc/types/array.py index 1bed71ec50..a56490f3af 100644 --- a/devito/petsc/types/array.py +++ b/devito/petsc/types/array.py @@ -20,7 +20,7 @@ class PETScArray(ArrayBasic, Differentiable): PETScArray objects represent vector objects within PETSc. They correspond to the spatial domain of a Function-like object - provided by the user, which is passed to PETScSolve as the target. + provided by the user, which is passed to `petscsolve` as the target. TODO: Potentially re-evaluate and separate into PETScFunction(Differentiable) and then PETScArray(ArrayBasic). diff --git a/devito/petsc/types/equation.py b/devito/petsc/types/equation.py index e819b48a22..fe9611c1fb 100644 --- a/devito/petsc/types/equation.py +++ b/devito/petsc/types/equation.py @@ -6,7 +6,7 @@ class EssentialBC(Eq): """ - Represents an essential boundary condition for use with PETScSolve. + Represents an essential boundary condition for use with `petscsolve`. Due to ongoing work on PetscSection and DMDA integration (WIP), these conditions are imposed as trivial equations. The compiler @@ -16,8 +16,8 @@ class EssentialBC(Eq): Note: - To define an essential boundary condition, use: Eq(target, boundary_value, subdomain=...), - where `target` is the Function-like object passed to PETScSolve. - - SubDomains used for multiple EssentialBCs must not overlap. + where `target` is the Function-like object passed to `petscsolve`. + - SubDomains used for multiple `EssentialBC`s must not overlap. """ pass diff --git a/devito/petsc/types/macros.py b/devito/petsc/types/macros.py index 4355535e64..94d9368b5e 100644 --- a/devito/petsc/types/macros.py +++ b/devito/petsc/types/macros.py @@ -1,5 +1,4 @@ import cgen as c - # TODO: Don't use c.Line here? petsc_func_begin_user = c.Line('PetscFunctionBeginUser;') diff --git a/devito/petsc/types/types.py b/devito/petsc/types/metadata.py similarity index 96% rename from devito/petsc/types/types.py rename to devito/petsc/types/metadata.py index 598fc658fe..d36e088a36 100644 --- a/devito/petsc/types/types.py +++ b/devito/petsc/types/metadata.py @@ -8,7 +8,7 @@ from devito.types.equation import Eq from devito.operations.solve import eval_time_derivatives -from devito.petsc.utils import petsc_variables +from devito.petsc.config import petsc_variables from devito.petsc.types.equation import EssentialBC, ZeroRow, ZeroColumn @@ -212,7 +212,8 @@ def space_dimensions(self): if len(space_dims) > 1: # TODO: This may not actually have to be the case, but enforcing it for now raise ValueError( - "All targets within a PETScSolve have to have the same space dimensions." + "All targets within a `petscsolve` call must have the" + " same space dimensions." ) return space_dims.pop() @@ -222,7 +223,8 @@ def grid(self): grids = [t.grid for t in self.targets] if len(set(grids)) > 1: raise ValueError( - "All targets within a PETScSolve have to have the same grid." + "Multiple `Grid`s detected in `petscsolve`;" + " all targets must share one `Grid`." ) return grids.pop() @@ -236,7 +238,7 @@ def space_order(self): space_orders = [t.space_order for t in self.targets] if len(set(space_orders)) > 1: raise ValueError( - "All targets within a PETScSolve have to have the same space order." + "All targets within a `petscsolve` call must have the same space order." ) return space_orders.pop() @@ -571,9 +573,13 @@ def _make_F_target(self, eq, F_target, targets): # The initial guess satisfies the essential BCs, so this term is zero. # Still included to support Jacobian testing via finite differences. rhs = arrays['x'] - eq.rhs - zero_row = ZeroRow(arrays['f'], rhs, subdomain=eq.subdomain) + zero_row = ZeroRow( + arrays['f'], rhs.subs(self.time_mapper), subdomain=eq.subdomain + ) # Move essential boundary condition to the right-hand side - zero_col = ZeroColumn(arrays['x'], eq.rhs, subdomain=eq.subdomain) + zero_col = ZeroColumn( + arrays['x'], eq.rhs.subs(self.time_mapper), subdomain=eq.subdomain + ) return (zero_row, zero_col) else: @@ -670,9 +676,10 @@ class InitialGuess: symbolic expressions, enforcing the initial guess to satisfy essential boundary conditions. """ - def __init__(self, target, exprs, arrays): + def __init__(self, target, exprs, arrays, time_mapper): self.target = target self.arrays = arrays + self.time_mapper = time_mapper self._build_exprs(as_tuple(exprs)) @property @@ -694,7 +701,7 @@ def _make_initial_guess(self, expr): if isinstance(expr, EssentialBC): assert expr.lhs == self.target return Eq( - self.arrays[self.target]['x'], expr.rhs, + self.arrays[self.target]['x'], expr.rhs.subs(self.time_mapper), subdomain=expr.subdomain ) else: diff --git a/devito/petsc/types/modes.py b/devito/petsc/types/modes.py new file mode 100644 index 0000000000..0850dc38e7 --- /dev/null +++ b/devito/petsc/types/modes.py @@ -0,0 +1,16 @@ +class InsertMode: + """ + How the entries are combined with the current values in the vectors or matrices. + Reference - https://petsc.org/main/manualpages/Sys/InsertMode/ + """ + insert_values = 'INSERT_VALUES' + add_values = 'ADD_VALUES' + + +class ScatterMode: + """ + Determines the direction of a scatter in `VecScatterBegin()` and `VecScatterEnd()`. + Reference - https://petsc.org/release/manualpages/Vec/ScatterMode/ + """ + scatter_reverse = 'SCATTER_REVERSE' + scatter_forward = 'SCATTER_FORWARD' diff --git a/devito/petsc/types/object.py b/devito/petsc/types/object.py index e674a3d014..8db82be365 100644 --- a/devito/petsc/types/object.py +++ b/devito/petsc/types/object.py @@ -1,12 +1,14 @@ from ctypes import POINTER, c_char from devito.tools import CustomDtype, dtype_to_ctype, as_tuple, CustomIntType -from devito.types import (LocalObject, LocalCompositeObject, ModuloDimension, - TimeDimension, ArrayObject, CustomDimension) +from devito.types import ( + LocalObject, LocalCompositeObject, ModuloDimension, TimeDimension, ArrayObject, + CustomDimension, Scalar +) from devito.symbolics import Byref, cast -from devito.types.basic import DataSymbol +from devito.types.basic import DataSymbol, LocalType -from devito.petsc.iet.utils import petsc_call +from devito.petsc.iet.nodes import petsc_call class PetscMixin: @@ -197,7 +199,7 @@ class PETScStruct(LocalCompositeObject): def time_dim_fields(self): """ Fields within the struct that are updated during the time loop. - These are not set in the `PopulateMatContext` callback. + These are not set in the `PopulateUserContext` callback. """ return [f for f in self.fields if isinstance(f, (ModuloDimension, TimeDimension))] @@ -205,7 +207,7 @@ def time_dim_fields(self): @property def callback_fields(self): """ - Fields within the struct that are initialized in the `PopulateMatContext` + Fields within the struct that are initialized in the `PopulateUserContext` callback. These fields are not updated in the time loop. """ return [f for f in self.fields if f not in self.time_dim_fields] @@ -213,6 +215,22 @@ def callback_fields(self): _C_modifier = ' *' +class MainUserStruct(PETScStruct): + pass + + +class CallbackUserStruct(PETScStruct): + __rkwargs__ = PETScStruct.__rkwargs__ + ('parent',) + + def __init__(self, *args, parent=None, **kwargs): + super().__init__(*args, **kwargs) + self._parent = parent + + @property + def parent(self): + return self._parent + + class JacobianStruct(PETScStruct): def __init__(self, name='jctx', pname='JacobianCtx', fields=None, modifier='', liveness='lazy'): @@ -227,7 +245,7 @@ def __init__(self, name='subctx', pname='SubMatrixCtx', fields=None, _C_modifier = None -class PETScArrayObject(PetscMixin, ArrayObject): +class PETScArrayObject(PetscMixin, ArrayObject, LocalType): _data_alignment = False def __init_finalize__(self, *args, **kwargs): @@ -313,6 +331,10 @@ def _C_ctype(self): return POINTER(POINTER(c_char)) +class NofSubMats(Scalar, LocalType): + pass + + FREE_PRIORITY = { PETScArrayObject: 0, Vec: 1, diff --git a/devito/petsc/utils.py b/devito/petsc/utils.py deleted file mode 100644 index a0b5753255..0000000000 --- a/devito/petsc/utils.py +++ /dev/null @@ -1,148 +0,0 @@ -import os -import ctypes -from pathlib import Path - -from devito.tools import memoized_func, filter_ordered -from devito.types import Symbol, SteppingDimension -from devito.operations.solve import eval_time_derivatives -from devito.symbolics import retrieve_functions - - -class PetscOSError(OSError): - pass - - -@memoized_func -def get_petsc_dir(): - petsc_dir = os.environ.get('PETSC_DIR') - if petsc_dir is None: - raise PetscOSError("PETSC_DIR environment variable not set") - else: - petsc_dir = (Path(petsc_dir),) - - petsc_arch = os.environ.get('PETSC_ARCH') - if petsc_arch is not None: - petsc_dir += (petsc_dir[0] / petsc_arch,) - - petsc_installed = petsc_dir[-1] / 'include' / 'petscconf.h' - if not petsc_installed.is_file(): - raise PetscOSError("PETSc is not installed") - - return petsc_dir - - -@memoized_func -def core_metadata(): - petsc_dir = get_petsc_dir() - - petsc_include = tuple([arch / 'include' for arch in petsc_dir]) - petsc_lib = tuple([arch / 'lib' for arch in petsc_dir]) - - return { - 'includes': ('petscsnes.h', 'petscdmda.h'), - 'include_dirs': petsc_include, - 'libs': ('petsc'), - 'lib_dirs': petsc_lib, - 'ldflags': tuple([f"-Wl,-rpath,{lib}" for lib in petsc_lib]) - } - - -@memoized_func -def get_petsc_variables(): - """ - Taken from https://www.firedrakeproject.org/_modules/firedrake/petsc.html - Get a dict of PETSc environment variables from the file: - $PETSC_DIR/$PETSC_ARCH/lib/petsc/conf/petscvariables - """ - try: - petsc_dir = get_petsc_dir() - except PetscOSError: - petsc_variables = {} - else: - path = [petsc_dir[-1], 'lib', 'petsc', 'conf', 'petscvariables'] - variables_path = Path(*path) - - with open(variables_path) as fh: - # Split lines on first '=' (assignment) - splitlines = (line.split("=", maxsplit=1) for line in fh.readlines()) - petsc_variables = {k.strip(): v.strip() for k, v in splitlines} - - return petsc_variables - - -petsc_variables = get_petsc_variables() -# TODO: Use petsctools get_petscvariables() instead? - - -def get_petsc_type_mappings(): - try: - petsc_precision = petsc_variables['PETSC_PRECISION'] - except KeyError: - printer_mapper = {} - petsc_type_to_ctype = {} - else: - petsc_scalar = 'PetscScalar' - # TODO: Check to see whether Petsc is compiled with - # 32-bit or 64-bit integers - printer_mapper = {ctypes.c_int: 'PetscInt'} - - if petsc_precision == 'single': - printer_mapper[ctypes.c_float] = petsc_scalar - elif petsc_precision == 'double': - printer_mapper[ctypes.c_double] = petsc_scalar - - # Used to construct ctypes.Structures that wrap PETSc objects - petsc_type_to_ctype = {v: k for k, v in printer_mapper.items()} - # Add other PETSc types - petsc_type_to_ctype.update({ - 'KSPType': ctypes.c_char_p, - 'KSPConvergedReason': petsc_type_to_ctype['PetscInt'], - 'KSPNormType': petsc_type_to_ctype['PetscInt'], - }) - return printer_mapper, petsc_type_to_ctype - - -petsc_type_mappings, petsc_type_to_ctype = get_petsc_type_mappings() - - -petsc_languages = ['petsc'] - - -def get_funcs(exprs): - funcs = [ - f for e in exprs - for f in retrieve_functions(eval_time_derivatives(e.lhs - e.rhs)) - ] - return filter_ordered(funcs) - - -def generate_time_mapper(funcs): - """ - Replace time indices with `Symbols` in expressions used within - PETSc callback functions. These symbols are Uxreplaced at the IET - level to align with the `TimeDimension` and `ModuloDimension` objects - present in the initial lowering. - NOTE: All functions used in PETSc callback functions are attached to - the `SolverMetaData` object, which is passed through the initial lowering - (and subsequently dropped and replaced with calls to run the solver). - Therefore, the appropriate time loop will always be correctly generated inside - the main kernel. - Examples - -------- - >>> funcs = [ - >>> f1(t + dt, x, y), - >>> g1(t + dt, x, y), - >>> g2(t, x, y), - >>> f1(t, x, y) - >>> ] - >>> generate_time_mapper(funcs) - {t + dt: tau0, t: tau1} - """ - time_indices = list({ - i if isinstance(d, SteppingDimension) else d - for f in funcs - for i, d in zip(f.indices, f.dimensions) - if d.is_Time - }) - tau_symbs = [Symbol('tau%d' % i) for i in range(len(time_indices))] - return dict(zip(time_indices, tau_symbs)) diff --git a/examples/petsc/Poisson/01_poisson.py b/examples/petsc/Poisson/01_poisson.py index 7ed32e8bbd..318e77f000 100644 --- a/examples/petsc/Poisson/01_poisson.py +++ b/examples/petsc/Poisson/01_poisson.py @@ -4,7 +4,7 @@ from devito import (Grid, Function, Eq, Operator, switchconfig, configuration, SubDomain) -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -97,7 +97,7 @@ def analytical(x, y): bcs += [EssentialBC(phi, bc, subdomain=sub4)] exprs = [eqn] + bcs - petsc = PETScSolve(exprs, target=phi, solver_parameters={'ksp_rtol': 1e-8}) + petsc = petscsolve(exprs, target=phi, solver_parameters={'ksp_rtol': 1e-8}) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/Poisson/02_laplace.py b/examples/petsc/Poisson/02_laplace.py index 9df68f9ab9..7ac621d1db 100644 --- a/examples/petsc/Poisson/02_laplace.py +++ b/examples/petsc/Poisson/02_laplace.py @@ -3,7 +3,7 @@ from devito import (Grid, Function, Eq, Operator, SubDomain, configuration, switchconfig) -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -107,7 +107,7 @@ def analytical(x, y, Lx, Ly): bcs += [EssentialBC(phi, bc_func, subdomain=sub4)] # right exprs = [eqn] + bcs - petsc = PETScSolve(exprs, target=phi, solver_parameters={'ksp_rtol': 1e-8}) + petsc = petscsolve(exprs, target=phi, solver_parameters={'ksp_rtol': 1e-8}) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/Poisson/03_poisson.py b/examples/petsc/Poisson/03_poisson.py index 9fa9a9e68a..dd72f265ff 100644 --- a/examples/petsc/Poisson/03_poisson.py +++ b/examples/petsc/Poisson/03_poisson.py @@ -4,7 +4,7 @@ from devito import (Grid, Function, Eq, Operator, switchconfig, configuration, SubDomain) -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -75,7 +75,7 @@ def analytical(x): bcs += [EssentialBC(u, np.float64(0.), subdomain=sub2)] exprs = [eqn] + bcs - petsc = PETScSolve(exprs, target=u, solver_parameters={'ksp_rtol': 1e-7}) + petsc = petscsolve(exprs, target=u, solver_parameters={'ksp_rtol': 1e-7}) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/Poisson/04_poisson.py b/examples/petsc/Poisson/04_poisson.py index 637ce44076..f5f618e5b2 100644 --- a/examples/petsc/Poisson/04_poisson.py +++ b/examples/petsc/Poisson/04_poisson.py @@ -4,7 +4,7 @@ from devito import (Grid, Function, Eq, Operator, switchconfig, configuration, SubDomain) -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -104,7 +104,7 @@ def analytical(x, y): bc_eqns += [EssentialBC(u, bcs, subdomain=sub4)] exprs = [eqn]+bc_eqns - petsc = PETScSolve(exprs, target=u, solver_parameters={'ksp_rtol': 1e-6}) + petsc = petscsolve(exprs, target=u, solver_parameters={'ksp_rtol': 1e-6}) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/cfd/01_navierstokes.py b/examples/petsc/cfd/01_navierstokes.py index 1c678d977b..13b9e5d450 100644 --- a/examples/petsc/cfd/01_navierstokes.py +++ b/examples/petsc/cfd/01_navierstokes.py @@ -5,7 +5,7 @@ Operator, SubDomain, switchconfig, configuration) from devito.symbolics import retrieve_functions, INT -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -248,7 +248,7 @@ def neumann_right(eq, subdomain): bc_pn1 += [neumann_right(neumann_top(eq_pn1, sub8), sub8)] -eqn_p = PETScSolve([eq_pn1]+bc_pn1, pn1.forward) +eqn_p = petscsolve([eq_pn1]+bc_pn1, pn1.forward) eq_u1 = Eq(u1.dt + u1*u1.dxc + v1*u1.dyc, nu*u1.laplace, subdomain=grid.interior) eq_v1 = Eq(v1.dt + u1*v1.dxc + v1*v1.dyc, nu*v1.laplace, subdomain=grid.interior) @@ -285,8 +285,8 @@ def neumann_right(eq, subdomain): bc_petsc_v1 += [EssentialBC(v1.forward, 0., subdomain=sub1)] # top bc_petsc_v1 += [EssentialBC(v1.forward, 0., subdomain=sub2)] # bottom -tentu = PETScSolve([eq_u1]+bc_petsc_u1, u1.forward) -tentv = PETScSolve([eq_v1]+bc_petsc_v1, v1.forward) +tentu = petscsolve([eq_u1]+bc_petsc_u1, u1.forward) +tentv = petscsolve([eq_v1]+bc_petsc_v1, v1.forward) exprs = [tentu, tentv, eqn_p, update_u, update_v] + bc_u1 + bc_v1 diff --git a/examples/petsc/petsc_test.py b/examples/petsc/petsc_test.py index 5d93669d5f..76b0aac957 100644 --- a/examples/petsc/petsc_test.py +++ b/examples/petsc/petsc_test.py @@ -3,7 +3,7 @@ from devito import (Grid, Function, Eq, Operator, configuration, switchconfig) -from devito.petsc import PETScSolve +from devito.petsc import petscsolve from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -22,7 +22,7 @@ eq = Eq(v, u.laplace, subdomain=grid.interior) -petsc = PETScSolve([eq], u) +petsc = petscsolve([eq], u) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/random/01_helmholtz.py b/examples/petsc/random/01_helmholtz.py index 8702fbf298..dc498a0dae 100644 --- a/examples/petsc/random/01_helmholtz.py +++ b/examples/petsc/random/01_helmholtz.py @@ -4,7 +4,7 @@ from devito.symbolics import retrieve_functions, INT from devito import (configuration, Operator, Eq, Grid, Function, SubDomain, switchconfig) -from devito.petsc import PETScSolve +from devito.petsc import petscsolve from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -241,7 +241,7 @@ def analytical_solution(x, y): bcs += [neumann_left(neumann_top(eqn, sub7), sub7)] bcs += [neumann_right(neumann_top(eqn, sub8), sub8)] - solver = PETScSolve([eqn]+bcs, target=u, solver_parameters={'rtol': 1e-8}) + solver = petscsolve([eqn]+bcs, target=u, solver_parameters={'rtol': 1e-8}) with switchconfig(openmp=False, language='petsc'): op = Operator(solver) diff --git a/examples/petsc/random/02_biharmonic.py b/examples/petsc/random/02_biharmonic.py index f08ffc07de..635c4ff42b 100644 --- a/examples/petsc/random/02_biharmonic.py +++ b/examples/petsc/random/02_biharmonic.py @@ -7,7 +7,7 @@ from devito import (Grid, Function, Eq, Operator, switchconfig, configuration, SubDomain) -from devito.petsc import PETScSolve, EssentialBC +from devito.petsc import petscsolve, EssentialBC from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -124,7 +124,7 @@ def f_fcn(x, y): # T (see ref) is nonsymmetric so need to set default KSP type to GMRES params = {'ksp_rtol': 1e-10} - petsc = PETScSolve({v: [eqn1]+bc_v, u: [eqn2]+bc_u}, solver_parameters=params) + petsc = petscsolve({v: [eqn1]+bc_v, u: [eqn2]+bc_u}, solver_parameters=params) with switchconfig(language='petsc'): op = Operator(petsc) diff --git a/examples/petsc/seismic/01_staggered_acoustic.py b/examples/petsc/seismic/01_staggered_acoustic.py index fc9e75938d..2352083236 100644 --- a/examples/petsc/seismic/01_staggered_acoustic.py +++ b/examples/petsc/seismic/01_staggered_acoustic.py @@ -2,7 +2,7 @@ import os import numpy as np from examples.seismic.source import DGaussSource, TimeAxis -from devito.petsc import PETScSolve +from devito.petsc import petscsolve from devito.petsc.initialize import PetscInitialize configuration['compiler'] = 'custom' os.environ['CC'] = 'mpicc' @@ -57,12 +57,12 @@ v_x_2 = Eq(vx2.dt, ro * p2.dx) v_z_2 = Eq(vz2.dt, ro * p2.dz) -petsc_v_x_2 = PETScSolve(v_x_2, target=vx2.forward) -petsc_v_z_2 = PETScSolve(v_z_2, target=vz2.forward) +petsc_v_x_2 = petscsolve(v_x_2, target=vx2.forward) +petsc_v_z_2 = petscsolve(v_z_2, target=vz2.forward) p_2 = Eq(p2.dt, l2m * (vx2.forward.dx + vz2.forward.dz)) -petsc_p_2 = PETScSolve(p_2, target=p2.forward, solver_parameters={'ksp_rtol': 1e-7}) +petsc_p_2 = petscsolve(p_2, target=p2.forward, solver_parameters={'ksp_rtol': 1e-7}) with switchconfig(language='petsc'): op_2 = Operator([petsc_v_x_2, petsc_v_z_2, petsc_p_2, src_p_2], opt='noop') @@ -82,12 +82,12 @@ v_x_4 = Eq(vx4.dt, ro * p4.dx) v_z_4 = Eq(vz4.dt, ro * p4.dz) -petsc_v_x_4 = PETScSolve(v_x_4, target=vx4.forward) -petsc_v_z_4 = PETScSolve(v_z_4, target=vz4.forward) +petsc_v_x_4 = petscsolve(v_x_4, target=vx4.forward) +petsc_v_z_4 = petscsolve(v_z_4, target=vz4.forward) p_4 = Eq(p4.dt, l2m * (vx4.forward.dx + vz4.forward.dz)) -petsc_p_4 = PETScSolve(p_4, target=p4.forward, solver_parameters={'ksp_rtol': 1e-7}) +petsc_p_4 = petscsolve(p_4, target=p4.forward, solver_parameters={'ksp_rtol': 1e-7}) with switchconfig(language='petsc'): op_4 = Operator([petsc_v_x_4, petsc_v_z_4, petsc_p_4, src_p_4], opt='noop') diff --git a/tests/test_petsc.py b/tests/test_petsc.py index 28f7c04025..6926ae76cd 100644 --- a/tests/test_petsc.py +++ b/tests/test_petsc.py @@ -3,10 +3,11 @@ import numpy as np import os import re +import sympy as sp from conftest import skipif from devito import (Grid, Function, TimeFunction, Eq, Operator, - configuration, norm, switchconfig, SubDomain) + configuration, norm, switchconfig, SubDomain, sin) from devito.operator.profiling import PerformanceSummary from devito.ir.iet import (Call, ElementalFunction, FindNodes, retrieve_iteration_tree) @@ -16,7 +17,7 @@ PC, KSPConvergedReason, PETScArray, FieldData, MultipleFieldData, SubMatrixBlock) -from devito.petsc.solve import PETScSolve, EssentialBC +from devito.petsc.solve import petscsolve, EssentialBC from devito.petsc.iet.nodes import Expression from devito.petsc.initialize import PetscInitialize from devito.petsc.logging import PetscSummary @@ -129,7 +130,7 @@ def test_petsc_subs(): @skipif('petsc') def test_petsc_solve(): """ - Test PETScSolve. + Test `petscsolve`. """ grid = Grid((2, 2), dtype=np.float64) @@ -138,7 +139,7 @@ def test_petsc_solve(): eqn = Eq(f.laplace, g) - petsc = PETScSolve(eqn, f) + petsc = petscsolve(eqn, f) with switchconfig(language='petsc'): op = Operator(petsc, opt='noop') @@ -178,7 +179,7 @@ def test_petsc_solve(): @skipif('petsc') def test_multiple_petsc_solves(): """ - Test multiple PETScSolves. + Test multiple `petscsolve` calls, passed to a single `Operator`. """ grid = Grid((2, 2), dtype=np.float64) @@ -191,8 +192,8 @@ def test_multiple_petsc_solves(): eqn1 = Eq(f1.laplace, g1) eqn2 = Eq(f2.laplace, g2) - petsc1 = PETScSolve(eqn1, f1, options_prefix='pde1') - petsc2 = PETScSolve(eqn2, f2, options_prefix='pde2') + petsc1 = petscsolve(eqn1, f1, options_prefix='pde1') + petsc2 = petscsolve(eqn2, f2, options_prefix='pde2') with switchconfig(language='petsc'): op = Operator([petsc1, petsc2], opt='noop') @@ -222,9 +223,9 @@ def test_petsc_cast(): eqn2 = Eq(f2.laplace, 10) eqn3 = Eq(f3.laplace, 10) - petsc1 = PETScSolve(eqn1, f1) - petsc2 = PETScSolve(eqn2, f2) - petsc3 = PETScSolve(eqn3, f3) + petsc1 = petscsolve(eqn1, f1) + petsc2 = petscsolve(eqn2, f2) + petsc3 = petscsolve(eqn3, f3) with switchconfig(language='petsc'): op1 = Operator(petsc1) @@ -254,9 +255,9 @@ def test_dmda_create(): eqn2 = Eq(f2.laplace, 10) eqn3 = Eq(f3.laplace, 10) - petsc1 = PETScSolve(eqn1, f1) - petsc2 = PETScSolve(eqn2, f2) - petsc3 = PETScSolve(eqn3, f3) + petsc1 = petscsolve(eqn1, f1) + petsc2 = petscsolve(eqn2, f2) + petsc3 = petscsolve(eqn3, f3) with switchconfig(language='petsc'): op1 = Operator(petsc1, opt='noop') @@ -275,34 +276,113 @@ def test_dmda_create(): ',1,1,1,1,6,NULL,NULL,NULL,&da0));' in str(op3) -@skipif('petsc') -def test_cinterface_petsc_struct(): +class TestStruct: + @skipif('petsc') + def test_cinterface_petsc_struct(self): - grid = Grid(shape=(11, 11), dtype=np.float64) - f = Function(name='f', grid=grid, space_order=2) - eq = Eq(f.laplace, 10) - petsc = PETScSolve(eq, f) + grid = Grid(shape=(11, 11), dtype=np.float64) + f = Function(name='f', grid=grid, space_order=2) + eq = Eq(f.laplace, 10) + petsc = petscsolve(eq, f) - name = "foo" + name = "foo" - with switchconfig(language='petsc'): - op = Operator(petsc, name=name) + with switchconfig(language='petsc'): + op = Operator(petsc, name=name) + + # Trigger the generation of a .c and a .h files + ccode, hcode = op.cinterface(force=True) + + dirname = op._compiler.get_jit_dir() + assert os.path.isfile(os.path.join(dirname, "%s.c" % name)) + assert os.path.isfile(os.path.join(dirname, "%s.h" % name)) + + ccode = str(ccode) + hcode = str(hcode) + + assert 'include "%s.h"' % name in ccode + + # The public `struct UserCtx` only appears in the header file + assert 'struct UserCtx0\n{' not in ccode + assert 'struct UserCtx0\n{' in hcode + + @skipif('petsc') + def test_temp_arrays_in_struct(self): + + grid = Grid(shape=(11, 11, 11), dtype=np.float64) + + u = TimeFunction(name='u', grid=grid, space_order=2) + x, y, _ = grid.dimensions + + eqn = Eq(u.forward, sin(sp.pi*(x+y)/3.), subdomain=grid.interior) + petsc = petscsolve(eqn, target=u.forward) + + with switchconfig(log_level='DEBUG', language='petsc'): + op = Operator(petsc) + # Check that it runs + op.apply(time_M=3) + + assert 'ctx0->x_size = x_size;' in str(op.ccode) + assert 'ctx0->y_size = y_size;' in str(op.ccode) + + assert 'const PetscInt y_size = ctx0->y_size;' in str(op.ccode) + assert 'const PetscInt x_size = ctx0->x_size;' in str(op.ccode) + + @skipif('petsc') + def test_parameters(self): + + grid = Grid((2, 2), dtype=np.float64) + + f1 = Function(name='f1', grid=grid, space_order=2) + g1 = Function(name='g1', grid=grid, space_order=2) + + mu1 = Constant(name='mu1', value=2.0) + mu2 = Constant(name='mu2', value=2.0) + + eqn1 = Eq(f1.laplace, g1*mu1) + petsc1 = petscsolve(eqn1, f1) - # Trigger the generation of a .c and a .h files - ccode, hcode = op.cinterface(force=True) + eqn2 = Eq(f1, g1*mu2) - dirname = op._compiler.get_jit_dir() - assert os.path.isfile(os.path.join(dirname, "%s.c" % name)) - assert os.path.isfile(os.path.join(dirname, "%s.h" % name)) + with switchconfig(language='petsc'): + op = Operator([eqn2, petsc1]) + + arguments = op.arguments() + + # Check mu1 and mu2 in arguments + assert 'mu1' in arguments + assert 'mu2' in arguments + + # Check mu1 and mu2 in op.parameters + assert mu1 in op.parameters + assert mu2 in op.parameters + + # Check PETSc struct not in op.parameters + assert all(not isinstance(i, LocalCompositeObject) for i in op.parameters) - ccode = str(ccode) - hcode = str(hcode) + @skipif('petsc') + def test_field_order(self): + """Verify that the order of fields in the user struct is fixed for + `identical` Operator instances. + """ + grid = Grid(shape=(11, 11, 11), dtype=np.float64) + f = TimeFunction(name='f', grid=grid, space_order=2) + x, y, _ = grid.dimensions + t = grid.time_dim + eq = Eq(f.dt, f.laplace + t*0.005 + sin(sp.pi*(x+y)/3.), subdomain=grid.interior) + petsc = petscsolve(eq, f.forward) - assert 'include "%s.h"' % name in ccode + with switchconfig(language='petsc'): + op1 = Operator(petsc, name="foo1") + op2 = Operator(petsc, name="foo2") - # The public `struct UserCtx` only appears in the header file - assert 'struct UserCtx0\n{' not in ccode - assert 'struct UserCtx0\n{' in hcode + op1_user_struct = op1._func_table['PopulateUserContext0'].root.parameters[0] + op2_user_struct = op2._func_table['PopulateUserContext0'].root.parameters[0] + + assert len(op1_user_struct.fields) == len(op2_user_struct.fields) + assert len(op1_user_struct.callback_fields) == \ + len(op1_user_struct.callback_fields) + assert str(op1_user_struct.fields) == str(op2_user_struct.fields) @skipif('petsc') @@ -317,7 +397,7 @@ def test_callback_arguments(): eqn1 = Eq(f1.laplace, g1) - petsc1 = PETScSolve(eqn1, f1) + petsc1 = petscsolve(eqn1, f1) with switchconfig(language='petsc'): op = Operator(petsc1) @@ -332,39 +412,6 @@ def test_callback_arguments(): assert str(ff.parameters) == '(snes, X, F, dummy)' -@skipif('petsc') -def test_petsc_struct(): - - grid = Grid((2, 2), dtype=np.float64) - - f1 = Function(name='f1', grid=grid, space_order=2) - g1 = Function(name='g1', grid=grid, space_order=2) - - mu1 = Constant(name='mu1', value=2.0) - mu2 = Constant(name='mu2', value=2.0) - - eqn1 = Eq(f1.laplace, g1*mu1) - petsc1 = PETScSolve(eqn1, f1) - - eqn2 = Eq(f1, g1*mu2) - - with switchconfig(language='petsc'): - op = Operator([eqn2, petsc1]) - - arguments = op.arguments() - - # Check mu1 and mu2 in arguments - assert 'mu1' in arguments - assert 'mu2' in arguments - - # Check mu1 and mu2 in op.parameters - assert mu1 in op.parameters - assert mu2 in op.parameters - - # Check PETSc struct not in op.parameters - assert all(not isinstance(i, LocalCompositeObject) for i in op.parameters) - - @skipif('petsc') def test_apply(): @@ -376,7 +423,7 @@ def test_apply(): eqn = Eq(pn.laplace*mu, rhs, subdomain=grid.interior) - petsc = PETScSolve(eqn, pn) + petsc = petscsolve(eqn, pn) with switchconfig(language='petsc'): # Build the op @@ -399,7 +446,7 @@ def test_petsc_frees(): g = Function(name='g', grid=grid, space_order=2) eqn = Eq(f.laplace, g) - petsc = PETScSolve(eqn, f) + petsc = petscsolve(eqn, f) with switchconfig(language='petsc'): op = Operator(petsc) @@ -424,7 +471,7 @@ def test_calls_to_callbacks(): g = Function(name='g', grid=grid, space_order=2) eqn = Eq(f.laplace, g) - petsc = PETScSolve(eqn, f) + petsc = petscsolve(eqn, f) with switchconfig(language='petsc'): op = Operator(petsc) @@ -448,7 +495,7 @@ def test_start_ptr(): grid = Grid((11, 11), dtype=np.float64) u1 = TimeFunction(name='u1', grid=grid, space_order=2) eq1 = Eq(u1.dt, u1.laplace, subdomain=grid.interior) - petsc1 = PETScSolve(eq1, u1.forward) + petsc1 = petscsolve(eq1, u1.forward) with switchconfig(language='petsc'): op1 = Operator(petsc1) @@ -460,7 +507,7 @@ def test_start_ptr(): # Verify the case with no modulo time stepping u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) eq2 = Eq(u2.dt, u2.laplace, subdomain=grid.interior) - petsc2 = PETScSolve(eq2, u2.forward) + petsc2 = petscsolve(eq2, u2.forward) with switchconfig(language='petsc'): op2 = Operator(petsc2) @@ -469,86 +516,146 @@ def test_start_ptr(): '(PetscScalar*)(u2_vec->data);') in str(op2) -@skipif('petsc') -def test_time_loop(): - """ - Verify the following: - - Modulo dimensions are correctly assigned and updated in the PETSc struct - at each time step. - - Only assign/update the modulo dimensions required by any of the - PETSc callback functions. - """ - grid = Grid((11, 11), dtype=np.float64) +class TestTimeLoop: + @skipif('petsc') + @pytest.mark.parametrize('dim', [1, 2, 3]) + def test_time_dimensions(self, dim): + """ + Verify the following: + - Modulo dimensions are correctly assigned and updated in the PETSc struct + at each time step. + - Only assign/update the modulo dimensions required by any of the + PETSc callback functions. + """ + shape = tuple(11 for _ in range(dim)) + grid = Grid(shape=shape, dtype=np.float64) - # Modulo time stepping - u1 = TimeFunction(name='u1', grid=grid, space_order=2) - v1 = Function(name='v1', grid=grid, space_order=2) - eq1 = Eq(v1.laplace, u1) - petsc1 = PETScSolve(eq1, v1) + # Modulo time stepping + u1 = TimeFunction(name='u1', grid=grid, space_order=2) + v1 = Function(name='v1', grid=grid, space_order=2) + eq1 = Eq(v1.laplace, u1) + petsc1 = petscsolve(eq1, v1) - with switchconfig(language='petsc'): - op1 = Operator(petsc1) - op1.apply(time_M=3) - body1 = str(op1.body) - rhs1 = str(op1._func_table['FormRHS0'].root.ccode) + with switchconfig(language='petsc'): + op1 = Operator(petsc1) + op1.apply(time_M=3) + body1 = str(op1.body) + rhs1 = str(op1._func_table['FormRHS0'].root.ccode) + + assert 'ctx0.t0 = t0' in body1 + assert 'ctx0.t1 = t1' not in body1 + assert 'ctx0->t0' in rhs1 + assert 'ctx0->t1' not in rhs1 + + # Non-modulo time stepping + u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) + v2 = Function(name='v2', grid=grid, space_order=2, save=5) + eq2 = Eq(v2.laplace, u2) + petsc2 = petscsolve(eq2, v2) - assert 'ctx0.t0 = t0' in body1 - assert 'ctx0.t1 = t1' not in body1 - assert 'ctx0->t0' in rhs1 - assert 'ctx0->t1' not in rhs1 + with switchconfig(language='petsc'): + op2 = Operator(petsc2) + op2.apply(time_M=3) + body2 = str(op2.body) + rhs2 = str(op2._func_table['FormRHS0'].root.ccode) - # Non-modulo time stepping - u2 = TimeFunction(name='u2', grid=grid, space_order=2, save=5) - v2 = Function(name='v2', grid=grid, space_order=2, save=5) - eq2 = Eq(v2.laplace, u2) - petsc2 = PETScSolve(eq2, v2) + assert 'ctx0.time = time' in body2 + assert 'ctx0->time' in rhs2 - with switchconfig(language='petsc'): - op2 = Operator(petsc2) - op2.apply(time_M=3) - body2 = str(op2.body) - rhs2 = str(op2._func_table['FormRHS0'].root.ccode) + # Modulo time stepping with more than one time step + # used in one of the callback functions + eq3 = Eq(v1.laplace, u1 + u1.forward) + petsc3 = petscsolve(eq3, v1) - assert 'ctx0.time = time' in body2 - assert 'ctx0->time' in rhs2 + with switchconfig(language='petsc'): + op3 = Operator(petsc3) + op3.apply(time_M=3) + body3 = str(op3.body) + rhs3 = str(op3._func_table['FormRHS0'].root.ccode) + + assert 'ctx0.t0 = t0' in body3 + assert 'ctx0.t1 = t1' in body3 + assert 'ctx0->t0' in rhs3 + assert 'ctx0->t1' in rhs3 + + # Multiple petsc solves within the same time loop + v2 = Function(name='v2', grid=grid, space_order=2) + eq4 = Eq(v1.laplace, u1) + petsc4 = petscsolve(eq4, v1) + eq5 = Eq(v2.laplace, u1) + petsc5 = petscsolve(eq5, v2) - # Modulo time stepping with more than one time step - # used in one of the callback functions - eq3 = Eq(v1.laplace, u1 + u1.forward) - petsc3 = PETScSolve(eq3, v1) + with switchconfig(language='petsc'): + op4 = Operator([petsc4, petsc5]) + op4.apply(time_M=3) + body4 = str(op4.body) - with switchconfig(language='petsc'): - op3 = Operator(petsc3) - op3.apply(time_M=3) - body3 = str(op3.body) - rhs3 = str(op3._func_table['FormRHS0'].root.ccode) - - assert 'ctx0.t0 = t0' in body3 - assert 'ctx0.t1 = t1' in body3 - assert 'ctx0->t0' in rhs3 - assert 'ctx0->t1' in rhs3 - - # Multiple petsc solves within the same time loop - v2 = Function(name='v2', grid=grid, space_order=2) - eq4 = Eq(v1.laplace, u1) - petsc4 = PETScSolve(eq4, v1) - eq5 = Eq(v2.laplace, u1) - petsc5 = PETScSolve(eq5, v2) + assert 'ctx0.t0 = t0' in body4 + assert body4.count('ctx0.t0 = t0') == 1 - with switchconfig(language='petsc'): - op4 = Operator([petsc4, petsc5]) - op4.apply(time_M=3) - body4 = str(op4.body) + @skipif('petsc') + @pytest.mark.parametrize('dim', [1, 2, 3]) + def test_trivial_operator(self, dim): + """ + Test trivial time-dependent problems with `petscsolve`. + """ + # create shape based on dimension + shape = tuple(4 for _ in range(dim)) + grid = Grid(shape=shape, dtype=np.float64) + u = TimeFunction(name='u', grid=grid, save=3) + + eqn = Eq(u.forward, u + 1) - assert 'ctx0.t0 = t0' in body4 - assert body4.count('ctx0.t0 = t0') == 1 + petsc = petscsolve(eqn, target=u.forward) + + with switchconfig(log_level='DEBUG'): + op = Operator(petsc, language='petsc') + op.apply() + + assert np.all(u.data[0] == 0.) + assert np.all(u.data[1] == 1.) + assert np.all(u.data[2] == 2.) + + @skipif('petsc') + @pytest.mark.parametrize('dim', [1, 2, 3]) + def test_time_dim(self, dim): + """ + Verify the time loop abstraction + when a mixture of TimeDimensions and time dependent + SteppingDimensions are used + """ + shape = tuple(4 for _ in range(dim)) + grid = Grid(shape=shape, dtype=np.float64) + # Use modoulo time stepping, i.e don't pass the save argument + u = TimeFunction(name='u', grid=grid) + # Use grid.time_dim in the equation, as well as the TimeFunction itself + petsc = petscsolve(Eq(u.forward, u + 1 + grid.time_dim), target=u.forward) + + with switchconfig(): + op = Operator(petsc, language='petsc') + op.apply(time_M=1) + + body = str(op.body) + rhs = str(op._func_table['FormRHS0'].root.ccode) + + # Check both ctx0.t0 and ctx0.time are assigned since they are both used + # in the callback functions, specifically in FormRHS0 + assert 'ctx0.t0 = t0' in body + assert 'ctx0.time = time' in body + assert 'ctx0->t0' in rhs + assert 'ctx0->time' in rhs + + # Check the ouput is as expected given two time steps have been + # executed (time_M=1) + assert np.all(u.data[1] == 1.) + assert np.all(u.data[0] == 3.) @skipif('petsc') def test_solve_output(): """ - Verify that PETScSolve returns the correct output for - simple cases e.g with the identity matrix. + Verify that `petscsolve` returns the correct output for + simple cases e.g. forming the identity matrix. """ grid = Grid(shape=(11, 11), dtype=np.float64) @@ -558,7 +665,7 @@ def test_solve_output(): # Solving Ax=b where A is the identity matrix v.data[:] = 5.0 eqn = Eq(u, v) - petsc = PETScSolve(eqn, target=u) + petsc = petscsolve(eqn, target=u) with switchconfig(language='petsc'): op = Operator(petsc) @@ -568,74 +675,75 @@ def test_solve_output(): assert np.allclose(u.data, v.data) -@skipif('petsc') -def test_essential_bcs(): - """ - Verify that PETScSolve returns the correct output with - essential boundary conditions. - """ - # SubDomains used for essential boundary conditions - # should not overlap. - class SubTop(SubDomain): - name = 'subtop' +class TestEssentialBCs: + @skipif('petsc') + def test_essential_bcs(self): + """ + Verify that `petscsolve` returns the correct output with + essential boundary conditions (`EssentialBC`). + """ + # SubDomains used for essential boundary conditions + # should not overlap. + class SubTop(SubDomain): + name = 'subtop' - def define(self, dimensions): - x, y = dimensions - return {x: x, y: ('right', 1)} - sub1 = SubTop() + def define(self, dimensions): + x, y = dimensions + return {x: x, y: ('right', 1)} + sub1 = SubTop() - class SubBottom(SubDomain): - name = 'subbottom' + class SubBottom(SubDomain): + name = 'subbottom' - def define(self, dimensions): - x, y = dimensions - return {x: x, y: ('left', 1)} - sub2 = SubBottom() + def define(self, dimensions): + x, y = dimensions + return {x: x, y: ('left', 1)} + sub2 = SubBottom() - class SubLeft(SubDomain): - name = 'subleft' + class SubLeft(SubDomain): + name = 'subleft' - def define(self, dimensions): - x, y = dimensions - return {x: ('left', 1), y: ('middle', 1, 1)} - sub3 = SubLeft() + def define(self, dimensions): + x, y = dimensions + return {x: ('left', 1), y: ('middle', 1, 1)} + sub3 = SubLeft() - class SubRight(SubDomain): - name = 'subright' + class SubRight(SubDomain): + name = 'subright' - def define(self, dimensions): - x, y = dimensions - return {x: ('right', 1), y: ('middle', 1, 1)} - sub4 = SubRight() + def define(self, dimensions): + x, y = dimensions + return {x: ('right', 1), y: ('middle', 1, 1)} + sub4 = SubRight() - subdomains = (sub1, sub2, sub3, sub4) - grid = Grid(shape=(11, 11), subdomains=subdomains, dtype=np.float64) + subdomains = (sub1, sub2, sub3, sub4) + grid = Grid(shape=(11, 11), subdomains=subdomains, dtype=np.float64) - u = Function(name='u', grid=grid, space_order=2) - v = Function(name='v', grid=grid, space_order=2) + u = Function(name='u', grid=grid, space_order=2) + v = Function(name='v', grid=grid, space_order=2) - # Solving Ax=b where A is the identity matrix - v.data[:] = 5.0 - eqn = Eq(u, v, subdomain=grid.interior) + # Solving Ax=b where A is the identity matrix + v.data[:] = 5.0 + eqn = Eq(u, v, subdomain=grid.interior) - bcs = [EssentialBC(u, 1., subdomain=sub1)] # top - bcs += [EssentialBC(u, 2., subdomain=sub2)] # bottom - bcs += [EssentialBC(u, 3., subdomain=sub3)] # left - bcs += [EssentialBC(u, 4., subdomain=sub4)] # right + bcs = [EssentialBC(u, 1., subdomain=sub1)] # top + bcs += [EssentialBC(u, 2., subdomain=sub2)] # bottom + bcs += [EssentialBC(u, 3., subdomain=sub3)] # left + bcs += [EssentialBC(u, 4., subdomain=sub4)] # right - petsc = PETScSolve([eqn]+bcs, target=u) + petsc = petscsolve([eqn]+bcs, target=u) - with switchconfig(language='petsc'): - op = Operator(petsc) - op.apply() + with switchconfig(language='petsc'): + op = Operator(petsc) + op.apply() - # Check u is equal to v on the interior - assert np.allclose(u.data[1:-1, 1:-1], v.data[1:-1, 1:-1]) - # Check u satisfies the boundary conditions - assert np.allclose(u.data[1:-1, -1], 1.0) # top - assert np.allclose(u.data[1:-1, 0], 2.0) # bottom - assert np.allclose(u.data[0, 1:-1], 3.0) # left - assert np.allclose(u.data[-1, 1:-1], 4.0) # right + # Check u is equal to v on the interior + assert np.allclose(u.data[1:-1, 1:-1], v.data[1:-1, 1:-1]) + # Check u satisfies the boundary conditions + assert np.allclose(u.data[1:-1, -1], 1.0) # top + assert np.allclose(u.data[1:-1, 0], 2.0) # bottom + assert np.allclose(u.data[0, 1:-1], 3.0) # left + assert np.allclose(u.data[-1, 1:-1], 4.0) # right @skipif('petsc') @@ -668,7 +776,7 @@ def define(self, dimensions): eq1 = Eq(e.laplace + e, f + 2.0) - petsc = PETScSolve([eq1, bc_1, bc_2], target=e) + petsc = petscsolve([eq1, bc_1, bc_2], target=e) jac = petsc.rhs.field_data.jacobian @@ -714,7 +822,7 @@ def define(self, dimensions): eq1 = Eq(e.laplace + e, f + 2.0) - petsc = PETScSolve([eq1, bc_1, bc_2], target=e) + petsc = petscsolve([eq1, bc_1, bc_2], target=e) res = petsc.rhs.field_data.residual @@ -753,7 +861,7 @@ class TestCoupledLinear: def test_coupled_vs_non_coupled(self, eq1, eq2, so): """ Test that solving multiple **uncoupled** equations separately - vs. together with `PETScSolve` yields the same result. + vs. together with `petscsolve` yields the same result. This test is non time-dependent. """ grid = Grid(shape=(11, 11), dtype=np.float64) @@ -769,8 +877,8 @@ def test_coupled_vs_non_coupled(self, eq1, eq2, so): eq2 = eval(eq2) # Non-coupled - petsc1 = PETScSolve(eq1, target=e) - petsc2 = PETScSolve(eq2, target=g) + petsc1 = petscsolve(eq1, target=e) + petsc2 = petscsolve(eq2, target=g) with switchconfig(language='petsc'): op1 = Operator([petsc1, petsc2], opt='noop') @@ -784,7 +892,7 @@ def test_coupled_vs_non_coupled(self, eq1, eq2, so): g.data[:] = 0 # Coupled - petsc3 = PETScSolve({e: [eq1], g: [eq2]}) + petsc3 = petscsolve({e: [eq1], g: [eq2]}) with switchconfig(language='petsc'): op2 = Operator(petsc3, opt='noop') @@ -828,7 +936,7 @@ def test_coupled_structs(self): eq1 = Eq(e + 5, f) eq2 = Eq(g + 10, h) - petsc = PETScSolve({f: [eq1], h: [eq2]}) + petsc = petscsolve({f: [eq1], h: [eq2]}) name = "foo" @@ -873,7 +981,7 @@ def test_coupled_frees(self, n_fields): *solved_funcs, h = functions equations = [Eq(func.laplace, h) for func in solved_funcs] - petsc = PETScSolve({func: [eq] for func, eq in zip(solved_funcs, equations)}) + petsc = petscsolve({func: [eq] for func, eq in zip(solved_funcs, equations)}) with switchconfig(language='petsc'): op = Operator(petsc, opt='noop') @@ -903,9 +1011,9 @@ def test_dmda_dofs(self): eq2 = Eq(f.laplace, h) eq3 = Eq(g.laplace, h) - petsc1 = PETScSolve({e: [eq1]}) - petsc2 = PETScSolve({e: [eq1], f: [eq2]}) - petsc3 = PETScSolve({e: [eq1], f: [eq2], g: [eq3]}) + petsc1 = petscsolve({e: [eq1]}) + petsc2 = petscsolve({e: [eq1], f: [eq2]}) + petsc3 = petscsolve({e: [eq1], f: [eq2], g: [eq3]}) with switchconfig(language='petsc'): op1 = Operator(petsc1, opt='noop') @@ -936,7 +1044,7 @@ def test_mixed_jacobian(self): eq1 = Eq(e.laplace, f) eq2 = Eq(g.laplace, h) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) jacobian = petsc.rhs.field_data.jacobian @@ -1048,7 +1156,7 @@ def test_coupling(self, eq1, eq2, j01_matvec, j10_matvec): eq1 = eval(eq1) eq2 = eval(eq2) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) jacobian = petsc.rhs.field_data.jacobian @@ -1097,7 +1205,7 @@ def test_jacobian_scaling_1D(self, eq1, eq2, so, scale): eq1 = eval(eq1) eq2 = eval(eq2) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) jacobian = petsc.rhs.field_data.jacobian @@ -1147,7 +1255,7 @@ def test_jacobian_scaling_2D(self, eq1, eq2, so, scale): eq1 = eval(eq1) eq2 = eval(eq2) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) jacobian = petsc.rhs.field_data.jacobian @@ -1200,7 +1308,7 @@ def test_jacobian_scaling_3D(self, eq1, eq2, so, scale): eq1 = eval(eq1) eq2 = eval(eq2) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) jacobian = petsc.rhs.field_data.jacobian @@ -1222,9 +1330,9 @@ def test_residual_bundle(self): eq2 = Eq(f.laplace, h) eq3 = Eq(g.laplace, h) - petsc1 = PETScSolve({e: [eq1]}) - petsc2 = PETScSolve({e: [eq1], f: [eq2]}) - petsc3 = PETScSolve({e: [eq1], f: [eq2], g: [eq3]}) + petsc1 = petscsolve({e: [eq1]}) + petsc2 = petscsolve({e: [eq1], f: [eq2]}) + petsc3 = petscsolve({e: [eq1], f: [eq2], g: [eq3]}) with switchconfig(language='petsc'): op1 = Operator(petsc1, opt='noop', name='op1') @@ -1265,7 +1373,7 @@ def test_residual_callback(self): eq1 = Eq(e.laplace, f) eq2 = Eq(g.laplace, h) - petsc = PETScSolve({e: [eq1], g: [eq2]}) + petsc = petscsolve({e: [eq1], g: [eq2]}) with switchconfig(language='petsc'): op = Operator(petsc) @@ -1312,7 +1420,7 @@ def define(self, dimensions): bc_u = [EssentialBC(u, 0., subdomain=sub1)] bc_v = [EssentialBC(v, 0., subdomain=sub1)] - petsc = PETScSolve({v: [eqn1]+bc_v, u: [eqn2]+bc_u}) + petsc = petscsolve({v: [eqn1]+bc_v, u: [eqn2]+bc_u}) with switchconfig(language='petsc'): op = Operator(petsc) @@ -1390,7 +1498,7 @@ def define(self, dimensions): bcs = [EssentialBC(u, u0, subdomain=sub1)] bcs += [EssentialBC(u, u1, subdomain=sub2)] - petsc = PETScSolve([eqn] + bcs, target=u, solver_parameters={'ksp_rtol': 1e-10}) + petsc = petscsolve([eqn] + bcs, target=u, solver_parameters={'ksp_rtol': 1e-10}) op = Operator(petsc, language='petsc') op.apply() @@ -1414,7 +1522,7 @@ def test_logging(self, log_level): f.data[:] = 5.0 eq = Eq(e.laplace, f) - petsc = PETScSolve(eq, target=e, options_prefix='poisson') + petsc = petscsolve(eq, target=e, options_prefix='poisson') with switchconfig(language='petsc', log_level=log_level): op = Operator(petsc) @@ -1473,8 +1581,8 @@ def test_logging_multiple_solves(self, log_level): eq1 = Eq(g.laplace, e) eq2 = Eq(h, f + 5.0) - solver1 = PETScSolve(eq1, target=g, options_prefix='poisson1') - solver2 = PETScSolve(eq2, target=h, options_prefix='poisson2') + solver1 = petscsolve(eq1, target=g, options_prefix='poisson1') + solver2 = petscsolve(eq2, target=h, options_prefix='poisson2') with switchconfig(language='petsc', log_level=log_level): op = Operator([solver1, solver2]) @@ -1516,8 +1624,8 @@ def test_logging_user_prefixes(self, log_level): pde1 = Eq(e.laplace, f) pde2 = Eq(g.laplace, h) - petsc1 = PETScSolve(pde1, target=e, options_prefix='pde1') - petsc2 = PETScSolve(pde2, target=g, options_prefix='pde2') + petsc1 = petscsolve(pde1, target=e, options_prefix='pde1') + petsc2 = petscsolve(pde2, target=g, options_prefix='pde2') with switchconfig(language='petsc', log_level=log_level): op = Operator([petsc1, petsc2]) @@ -1545,8 +1653,8 @@ def test_logging_default_prefixes(self, log_level): pde1 = Eq(e.laplace, f) pde2 = Eq(g.laplace, h) - petsc1 = PETScSolve(pde1, target=e) - petsc2 = PETScSolve(pde2, target=g) + petsc1 = petscsolve(pde1, target=e) + petsc2 = petscsolve(pde2, target=g) with switchconfig(language='petsc', log_level=log_level): op = Operator([petsc1, petsc2]) @@ -1578,11 +1686,11 @@ def setup_class(self): @skipif('petsc') def test_different_solver_params(self): # Explicitly set the solver parameters - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, solver_parameters={'ksp_rtol': '1e-10'} ) # Use solver parameter defaults - solver2 = PETScSolve(self.eq2, target=self.g) + solver2 = petscsolve(self.eq2, target=self.g) with switchconfig(language='petsc'): op = Operator([solver1, solver2]) @@ -1598,10 +1706,10 @@ def test_different_solver_params(self): @skipif('petsc') def test_options_prefix(self): - solver1 = PETScSolve(self.eq1, self.e, + solver1 = petscsolve(self.eq1, self.e, solver_parameters={'ksp_rtol': '1e-10'}, options_prefix='poisson1') - solver2 = PETScSolve(self.eq2, self.g, + solver2 = petscsolve(self.eq2, self.g, solver_parameters={'ksp_rtol': '1e-12'}, options_prefix='poisson2') @@ -1625,7 +1733,7 @@ def test_options_no_value(self): Test solver parameters that do not require a value, such as `snes_view` and `ksp_view`. """ - solver = PETScSolve( + solver = petscsolve( self.eq1, target=self.e, solver_parameters={'snes_view': None}, options_prefix='solver1' ) @@ -1645,7 +1753,7 @@ def test_tolerances(self, log_level): 'ksp_divtol': 1e3, 'ksp_max_it': 100 } - solver = PETScSolve( + solver = petscsolve( self.eq1, target=self.e, solver_parameters=params, options_prefix='solver' ) @@ -1668,11 +1776,11 @@ def test_tolerances(self, log_level): @skipif('petsc') def test_clearing_options(self): # Explicitly set the solver parameters - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, solver_parameters={'ksp_rtol': '1e-10'} ) # Use the solver parameter defaults - solver2 = PETScSolve(self.eq2, target=self.g) + solver2 = petscsolve(self.eq2, target=self.g) with switchconfig(language='petsc'): op = Operator([solver1, solver2]) @@ -1686,11 +1794,11 @@ def test_error_if_same_prefix(self): Test an error is raised if the same options prefix is used for two different solvers within the same Operator. """ - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, options_prefix='poisson', solver_parameters={'ksp_rtol': '1e-10'} ) - solver2 = PETScSolve( + solver2 = petscsolve( self.eq2, target=self.g, options_prefix='poisson', solver_parameters={'ksp_rtol': '1e-12'} ) @@ -1702,19 +1810,19 @@ def test_error_if_same_prefix(self): @pytest.mark.parametrize('log_level', ['PERF', 'DEBUG']) def test_multiple_operators(self, log_level): """ - Verify that solver parameters are set correctly when multiple Operators - are created with PETScSolve instances sharing the same options_prefix. + Verify that solver parameters are set correctly when multiple `Operator`s + are created with `petscsolve` calls sharing the same `options_prefix`. - Note: Using the same options_prefix within a single Operator is not allowed + Note: Using the same `options_prefix` within a single `Operator` is not allowed (see previous test), but the same prefix can be used across - different Operators (although not advised). + different `Operator`s (although not advised). """ - # Create two PETScSolve instances with the same options_prefix - solver1 = PETScSolve( + # Create two `petscsolve` calls with the same `options_prefix`` + solver1 = petscsolve( self.eq1, target=self.e, options_prefix='poisson', solver_parameters={'ksp_rtol': '1e-10'} ) - solver2 = PETScSolve( + solver2 = petscsolve( self.eq2, target=self.g, options_prefix='poisson', solver_parameters={'ksp_rtol': '1e-12'} ) @@ -1743,7 +1851,7 @@ def test_command_line_priority_tols_1(self, command_line, log_level): prefix = 'd17weqroeg' _, expected = command_line - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, options_prefix=prefix ) @@ -1762,7 +1870,7 @@ def test_command_line_priority_tols_2(self, command_line, log_level): prefix = 'riabfodkj5' _, expected = command_line - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, options_prefix=prefix ) @@ -1795,7 +1903,7 @@ def test_command_line_priority_tols3(self, command_line, log_level): 'ksp_max_it': 500 } - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, solver_parameters=params, options_prefix=prefix @@ -1824,7 +1932,7 @@ def test_command_line_priority_ksp_type(self, command_line, log_level): # see the `command_line` fixture). params = {'ksp_type': 'richardson'} - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, solver_parameters=params, options_prefix=prefix @@ -1850,7 +1958,7 @@ def test_command_line_priority_ccode(self, command_line): """ prefix = 'qtr2vfvwiu' - solver = PETScSolve( + solver = petscsolve( self.eq1, target=self.e, # Specify a solver parameter that is not set via the # command line (see the `command_line` fixture for this prefix). @@ -1899,21 +2007,21 @@ def test_solveexpr(self): e, f = functions eq = Eq(e.laplace, f) - # Two PETScSolve instances with different options_prefix values + # Two `petscsolve` calls with different `options_prefix` values # should hash differently. - petsc1 = PETScSolve(eq, target=e, options_prefix='poisson1') - petsc2 = PETScSolve(eq, target=e, options_prefix='poisson2') + petsc1 = petscsolve(eq, target=e, options_prefix='poisson1') + petsc2 = petscsolve(eq, target=e, options_prefix='poisson2') assert hash(petsc1.rhs) != hash(petsc2.rhs) assert petsc1.rhs != petsc2.rhs - # Two PETScSolve instances with the same options_prefix but - # different solver parameters should hash differently. - petsc3 = PETScSolve( + # Two `petscsolve` calls with the same `options_prefix` but + # different `solver_parameters` should hash differently. + petsc3 = petscsolve( eq, target=e, solver_parameters={'ksp_type': 'cg'}, options_prefix='poisson3' ) - petsc4 = PETScSolve( + petsc4 = petscsolve( eq, target=e, solver_parameters={'ksp_type': 'richardson'}, options_prefix='poisson3' ) @@ -1922,7 +2030,7 @@ def test_solveexpr(self): class TestGetInfo: """ - Test the `get_info` optional argument to `PETScSolve`. + Test the `get_info` (optional) argument to `petscsolve`. This argument can be used independently of the `log_level` to retrieve specific information about the solve, such as the number of KSP @@ -1945,7 +2053,7 @@ def setup_class(self): @skipif('petsc') def test_get_info(self): get_info = ['kspgetiterationnumber', 'snesgetiterationnumber'] - petsc = PETScSolve( + petsc = petscsolve( self.eq1, target=self.e, options_prefix='pde1', get_info=get_info ) with switchconfig(language='petsc'): @@ -1968,7 +2076,7 @@ def test_get_info_with_logging(self, log_level): Test that `get_info` works correctly when logging is enabled. """ get_info = ['kspgetiterationnumber'] - petsc = PETScSolve( + petsc = petscsolve( self.eq1, target=self.e, options_prefix='pde1', get_info=get_info ) with switchconfig(language='petsc', log_level=log_level): @@ -1991,15 +2099,15 @@ def test_different_solvers(self): Test that `get_info` works correctly when multiple solvers are used within the same Operator. """ - # Create two PETScSolve instances with different get_info arguments + # Create two `petscsolve` calls with different `get_info` arguments get_info_1 = ['kspgetiterationnumber'] get_info_2 = ['snesgetiterationnumber'] - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, options_prefix='pde1', get_info=get_info_1 ) - solver2 = PETScSolve( + solver2 = petscsolve( self.eq2, target=self.g, options_prefix='pde2', get_info=get_info_2 ) with switchconfig(language='petsc'): @@ -2028,7 +2136,7 @@ def test_case_insensitive(self): """ # Create a list with mixed cases get_info = ['KSPGetIterationNumber', 'snesgetiterationnumber'] - petsc = PETScSolve( + petsc = petscsolve( self.eq1, target=self.e, options_prefix='pde1', get_info=get_info ) with switchconfig(language='petsc'): @@ -2048,10 +2156,10 @@ def test_get_ksp_type(self): a string. """ get_info = ['kspgettype'] - solver1 = PETScSolve( + solver1 = petscsolve( self.eq1, target=self.e, options_prefix='poisson1', get_info=get_info ) - solver2 = PETScSolve( + solver2 = petscsolve( self.eq1, target=self.e, options_prefix='poisson2', solver_parameters={'ksp_type': 'cg'}, get_info=get_info ) @@ -2076,3 +2184,24 @@ def test_get_ksp_type(self): assert entry2.KSPGetType == 'cg' assert entry2['KSPGetType'] == 'cg' assert entry2['kspgettype'] == 'cg' + + +class TestPrinter: + + @skipif('petsc') + def test_petsc_pi(self): + """ + Test that sympy.pi is correctly translated to PETSC_PI in the + generated code. + """ + grid = Grid(shape=(11, 11), dtype=np.float64) + e = Function(name='e', grid=grid) + eq = Eq(e, sp.pi) + + petsc = petscsolve(eq, target=e) + + with switchconfig(language='petsc'): + op = Operator(petsc) + + assert 'PETSC_PI' in str(op.ccode) + assert 'M_PI' not in str(op.ccode)