Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from devito.ir.iet import (FindNodes, FindSymbols, Iteration, ParallelBlock,
retrieve_iteration_tree)
from devito.tools import as_tuple
from devito.petsc.utils import PetscOSError, get_petsc_dir
from devito.petsc.config import PetscOSError, get_petsc_dir

try:
from mpi4py import MPI # noqa
Expand Down
8 changes: 8 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
check_stability, PetscTarget)
from devito.tools import timed_pass

from devito.petsc.iet.passes import lower_petsc_symbols

__all__ = ['Cpu64NoopCOperator', 'Cpu64NoopOmpOperator', 'Cpu64AdvCOperator',
'Cpu64AdvOmpOperator', 'Cpu64FsgCOperator', 'Cpu64FsgOmpOperator',
'Cpu64CustomOperator', 'Cpu64CustomCXXOperator', 'Cpu64AdvCXXOperator',
Expand Down Expand Up @@ -143,6 +145,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

# Lower PETSc symbols
lower_petsc_symbols(graph, **kwargs)

return graph


Expand Down Expand Up @@ -222,6 +227,9 @@ def _specialize_iet(cls, graph, **kwargs):
# Symbol definitions
cls._Target.DataManager(**kwargs).process(graph)

# Lower PETSc symbols
lower_petsc_symbols(graph, **kwargs)

# Linearize n-dimensional Indexeds
linearize(graph, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion devito/ir/iet/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Section, HaloSpot, ExpressionBundle)
from devito.tools import timed_pass
from devito.petsc.types import MetaData
from devito.petsc.iet.utils import petsc_iet_mapper
from devito.petsc.iet.nodes import petsc_iet_mapper

__all__ = ['iet_build']

Expand Down
2 changes: 1 addition & 1 deletion devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,7 @@ class FindSymbols(LazyVisitor[Any, list[Any], None]):
Drive the search. Accepted:
- `symbolics`: Collect all AbstractFunction objects, default
- `basics`: Collect all Basic objects
- `abstractsymbols`: Collect all AbstractSymbol objects
- `symbols`: Collect all AbstractSymbol objects
- `dimensions`: Collect all Dimensions
- `indexeds`: Collect all Indexed objects
- `indexedbases`: Collect all IndexedBase objects
Expand Down
6 changes: 5 additions & 1 deletion devito/passes/iet/languages/C.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from devito.passes.iet.langbase import LangBB
from devito.symbolics import c_complex, c_double_complex
from devito.tools import dtype_to_cstr
from devito.petsc.utils import petsc_type_mappings

from devito.petsc.config import petsc_type_mappings

__all__ = ['CBB', 'CDataManager', 'COrchestrator']

Expand Down Expand Up @@ -82,3 +83,6 @@ class PetscCPrinter(CPrinter):
_restrict_keyword = ''

type_mappings = {**CPrinter.type_mappings, **petsc_type_mappings}

def _print_Pi(self, expr):
return 'PETSC_PI'
86 changes: 86 additions & 0 deletions devito/petsc/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import os
import ctypes
from pathlib import Path

from petsctools import get_petscvariables, MissingPetscException

from devito.tools import memoized_func


class PetscOSError(OSError):
pass


@memoized_func
def get_petsc_dir():
petsc_dir = os.environ.get('PETSC_DIR')
if petsc_dir is None:
raise PetscOSError("PETSC_DIR environment variable not set")
else:
petsc_dir = (Path(petsc_dir),)

petsc_arch = os.environ.get('PETSC_ARCH')
if petsc_arch is not None:
petsc_dir += (petsc_dir[0] / petsc_arch,)

petsc_installed = petsc_dir[-1] / 'include' / 'petscconf.h'
if not petsc_installed.is_file():
raise PetscOSError("PETSc is not installed")

return petsc_dir


@memoized_func
def core_metadata():
petsc_dir = get_petsc_dir()

petsc_include = tuple([arch / 'include' for arch in petsc_dir])
petsc_lib = tuple([arch / 'lib' for arch in petsc_dir])

return {
'includes': ('petscsnes.h', 'petscdmda.h'),
'include_dirs': petsc_include,
'libs': ('petsc'),
'lib_dirs': petsc_lib,
'ldflags': tuple([f"-Wl,-rpath,{lib}" for lib in petsc_lib])
}


try:
petsc_variables = get_petscvariables()
except MissingPetscException:
petsc_variables = {}


def get_petsc_type_mappings():
try:
petsc_precision = petsc_variables['PETSC_PRECISION']
except KeyError:
printer_mapper = {}
petsc_type_to_ctype = {}
else:
petsc_scalar = 'PetscScalar'
# TODO: Check to see whether Petsc is compiled with
# 32-bit or 64-bit integers
printer_mapper = {ctypes.c_int: 'PetscInt'}

if petsc_precision == 'single':
printer_mapper[ctypes.c_float] = petsc_scalar
elif petsc_precision == 'double':
printer_mapper[ctypes.c_double] = petsc_scalar

# Used to construct ctypes.Structures that wrap PETSc objects
petsc_type_to_ctype = {v: k for k, v in printer_mapper.items()}
# Add other PETSc types
petsc_type_to_ctype.update({
'KSPType': ctypes.c_char_p,
'KSPConvergedReason': petsc_type_to_ctype['PetscInt'],
'KSPNormType': petsc_type_to_ctype['PetscInt'],
})
return printer_mapper, petsc_type_to_ctype


petsc_type_mappings, petsc_type_to_ctype = get_petsc_type_mappings()


petsc_languages = ['petsc']
Loading
Loading