Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from parcels._core.uxgrid import UxGrid
from parcels._core.xgrid import XGrid, _transpose_xfield_data_to_tzyx, assert_all_field_dims_have_axis
from parcels._python import assert_same_function_signature
from parcels._reprs import default_repr
from parcels._reprs import field_repr, vectorfield_repr
from parcels._typing import VectorType
from parcels.interpolators import (
ZeroInterpolator,
Expand Down Expand Up @@ -148,6 +148,9 @@ def __init__(
if "time" not in self.data.coords:
raise ValueError("Field data is missing a 'time' coordinate.")

def __repr__(self):
return field_repr(self)

@property
def units(self):
return self._units
Expand Down Expand Up @@ -277,11 +280,7 @@ def __init__(
self._vector_interp_method = vector_interp_method

def __repr__(self):
return f"""<{type(self).__name__}>
name: {self.name!r}
U: {default_repr(self.U)}
V: {default_repr(self.V)}
W: {default_repr(self.W)}"""
return vectorfield_repr(self)

@property
def vector_interp_method(self):
Expand Down
4 changes: 4 additions & 0 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from parcels._core.uxgrid import UxGrid
from parcels._core.xgrid import _DEFAULT_XGCM_KWARGS, XGrid
from parcels._logger import logger
from parcels._reprs import fieldset_repr
from parcels._typing import Mesh
from parcels.interpolators import UxPiecewiseConstantFace, UxPiecewiseLinearNode, XConstantField, XLinear

Expand Down Expand Up @@ -75,6 +76,9 @@ def __getattr__(self, name):
else:
raise AttributeError(f"FieldSet has no attribute '{name}'")

def __repr__(self):
return fieldset_repr(self)

@property
def time_interval(self):
"""Returns the valid executable time interval of the FieldSet,
Expand Down
7 changes: 3 additions & 4 deletions src/parcels/_core/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from parcels._core.statuscodes import StatusCode
from parcels._core.utils.string import _assert_str_and_python_varname
from parcels._core.utils.time import TimeInterval
from parcels._reprs import _format_list_items_multiline
from parcels._reprs import particleclass_repr, variable_repr

__all__ = ["Particle", "ParticleClass", "Variable"]
_TO_WRITE_OPTIONS = [True, False, "once"]
Expand Down Expand Up @@ -70,7 +70,7 @@ def name(self):
return self._name

def __repr__(self):
return f"Variable(name={self._name!r}, dtype={self.dtype!r}, initial={self.initial!r}, to_write={self.to_write!r}, attrs={self.attrs!r})"
return variable_repr(self)


class ParticleClass:
Expand All @@ -92,8 +92,7 @@ def __init__(self, variables: list[Variable]):
self.variables = variables

def __repr__(self):
vars = [repr(v) for v in self.variables]
return f"ParticleClass(variables={_format_list_items_multiline(vars)})"
return particleclass_repr(self)

def add_variable(self, variable: Variable | list[Variable]):
"""Add a new variable to the Particle class. This returns a new Particle class with the added variable(s).
Expand Down
8 changes: 2 additions & 6 deletions src/parcels/_core/particlefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import parcels
from parcels._core.particle import ParticleClass
from parcels._core.utils.time import timedelta_to_float
from parcels._reprs import particlefile_repr

if TYPE_CHECKING:
from parcels._core.particle import Variable
Expand Down Expand Up @@ -96,12 +97,7 @@ def __init__(self, store, outputdt, chunks=None, create_new_zarrfile=True):
# TODO v4: Add check that if create_new_zarrfile is False, the store already exists

def __repr__(self) -> str:
return (
f"{type(self).__name__}("
f"outputdt={self.outputdt!r}, "
f"chunks={self.chunks!r}, "
f"create_new_zarrfile={self.create_new_zarrfile!r})"
)
return particlefile_repr(self)

def set_metadata(self, parcels_grid_mesh: Literal["spherical", "flat"]):
self.metadata.update(
Expand Down
14 changes: 3 additions & 11 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import xarray as xr
from tqdm import tqdm
from zarr.storage import DirectoryStore

from parcels._core.converters import _convert_to_flat_array
from parcels._core.kernel import Kernel
Expand All @@ -21,7 +20,7 @@
)
from parcels._core.warnings import ParticleSetWarning
from parcels._logger import logger
from parcels._reprs import particleset_repr
from parcels._reprs import _format_zarr_output_location, particleset_repr

__all__ = ["ParticleSet"]

Expand Down Expand Up @@ -70,7 +69,6 @@ def __init__(
**kwargs,
):
self._data = None
self._repeat_starttime = None
self._kernel = None

self.fieldset = fieldset
Expand Down Expand Up @@ -167,7 +165,7 @@ def __getattr__(self, name):

def __getitem__(self, index):
"""Get a single particle by index."""
return ParticleSetView(self._data, index=index)
return ParticleSetView(self._data, index=index, ptype=self._ptype)

def __setattr__(self, name, value):
if name in ["_data"]:
Expand Down Expand Up @@ -447,7 +445,7 @@ def execute(

# Set up pbar
if output_file:
logger.info(f"Output files are stored in {_format_output_location(output_file.store)}")
logger.info(f"Output files are stored in {_format_zarr_output_location(output_file.store)}")

if verbose_progress:
pbar = tqdm(total=end_time - start_time, file=sys.stdout)
Expand Down Expand Up @@ -592,9 +590,3 @@ def _get_start_time(first_release_time, time_interval, sign_dt, runtime):

start_time = first_release_time if not np.isnan(first_release_time) else fieldset_start
return start_time


def _format_output_location(zarr_obj):
if isinstance(zarr_obj, DirectoryStore):
return zarr_obj.path
return repr(zarr_obj)
16 changes: 11 additions & 5 deletions src/parcels/_core/particlesetview.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import numpy as np

from parcels._reprs import particlesetview_repr


class ParticleSetView:
"""Class to be used in a kernel that links a View of the ParticleSet (on the kernel level) to a ParticleSet."""

def __init__(self, data, index):
def __init__(self, data, index, ptype):
self._data = data
self._index = index
self._ptype = ptype

def __getattr__(self, name):
# Return a proxy that behaves like the underlying numpy array but
Expand All @@ -25,11 +28,14 @@ def __getattr__(self, name):
return self._data[name][self._index]

def __setattr__(self, name, value):
if name in ["_data", "_index"]:
if name in ["_data", "_index", "_ptype"]:
object.__setattr__(self, name, value)
else:
self._data[name][self._index] = value

def __repr__(self):
return particlesetview_repr(self)

def __getitem__(self, index):
# normalize single-element tuple indexing (e.g., (inds,))
if isinstance(index, tuple) and len(index) == 1:
Expand All @@ -50,7 +56,7 @@ def __getitem__(self, index):
raise ValueError(
f"Boolean index has incompatible length {arr.size} for selection of size {int(np.sum(base))}"
)
return ParticleSetView(self._data, new_index)
return ParticleSetView(self._data, new_index, self._ptype)

# Integer array/list, slice or single integer relative to the local view
# (boolean masks were handled above). Normalize and map to global
Expand All @@ -65,12 +71,12 @@ def __getitem__(self, index):
base_arr = np.asarray(base)
sel = base_arr[idx]
new_index[sel] = True
return ParticleSetView(self._data, new_index)
return ParticleSetView(self._data, new_index, self._ptype)

# Fallback: try to assign directly (preserves previous behaviour for other index types)
try:
new_index[base] = index
return ParticleSetView(self._data, new_index)
return ParticleSetView(self._data, new_index, self._ptype)
except Exception as e:
raise TypeError(f"Unsupported index type for ParticleSetView.__getitem__: {type(index)!r}") from e

Expand Down
4 changes: 3 additions & 1 deletion src/parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import cftime
import numpy as np

from parcels._reprs import timeinterval_repr

if TYPE_CHECKING:
from parcels._typing import TimeLike

Expand Down Expand Up @@ -61,7 +63,7 @@ def is_all_time_in_interval(self, time: float):
return (0 <= item).all() and (item <= self.time_length_as_flt).all()

def __repr__(self) -> str:
return f"TimeInterval(left={self.left!r}, right={self.right!r})"
return timeinterval_repr(self)

def __eq__(self, other: object) -> bool:
if not isinstance(other, TimeInterval):
Expand Down
4 changes: 4 additions & 0 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from parcels._core.basegrid import BaseGrid
from parcels._core.index_search import _search_1d_array, _search_indices_curvilinear_2d
from parcels._reprs import xgrid_repr
from parcels._typing import assert_valid_mesh

_XGRID_AXES = Literal["X", "Y", "Z"]
Expand Down Expand Up @@ -135,6 +136,9 @@ def from_dataset(cls, ds: xr.Dataset, mesh, xgcm_kwargs=None):
grid = xgcm.Grid(ds, **xgcm_kwargs)
return cls(grid, mesh=mesh)

def __repr__(self):
return xgrid_repr(self)

@property
def axes(self) -> list[_XGRID_AXES]:
return _get_xgrid_axes(self.xgcm_grid)
Expand Down
Loading
Loading