Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/building/sources/accumulate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,18 @@ For full control, provide an explicit list of ``(basetime, steps)`` pairs.

These two examples are equivalent to those shown in Option 1 above.

Controlling the fields regrouped within accumulation
====================================================

It is possible to control the fields accumulated together through their metadata.
The ``group_by`` keyword allows to ignore some metadata when deciding to group field to accumulate them together.
Ignored keys mean that fields with different values will be accumulated together.
Note that ``date,time,step`` should be ignored by default.

.. literalinclude:: yaml/accumulations-mars-groupby.yaml
:language: yaml


.. [1]

For ECMWF forecasts, the forecasts at 00Z and 12Z are from the stream
Expand Down
32 changes: 27 additions & 5 deletions src/anemoi/datasets/create/sources/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def _compute_accumulations(
dates: list[datetime.datetime],
period: datetime.timedelta,
source: dict,
group_by: dict,
availability: dict[str, Any] | None = None,
patch: dict | None = None,
**kwargs,
Expand All @@ -338,6 +339,8 @@ def _compute_accumulations(
The source configuration to request fields from
period: datetime.timedelta,
The interval over which to accumulate (user-defined)
group_by: dict,
The keys in fields metadata that are grouped together
availability: Any, optional
A description of the available periods in the data source. See documentation.
patch: list[dict] | None, optional
Expand Down Expand Up @@ -399,10 +402,10 @@ def _compute_accumulations(

values = field.values.copy()

key = field.metadata(namespace="mars")
key = {k: v for k, v in key.items() if k not in ["date", "time", "step"]}
key = field.metadata(namespace=group_by["namespace"])
key = {k: v for k, v in key.items() if k not in group_by["ignore"]}
key = tuple(sorted(key.items()))
log = " ".join(f"{k}={v}" for k, v in field.metadata(namespace="mars").items())
log = " ".join(f"{k}={v}" for k, v in field.metadata(namespace=group_by["namespace"]).items())

field_interval = field_to_interval(field)

Expand Down Expand Up @@ -468,6 +471,21 @@ def check_missing_accumulators():
return ds


def patch_groupby_keys(group_by: dict | None = None):
if group_by is None:
return {"namespace": "mars", "ignore": ["date", "time", "step"]}
else:
namespace = group_by.get("namespace", None)
if namespace is None:
raise ValueError("No namespace in group_by (set namespace: mars for default)")
if namespace != "mars":
raise ValueError(f"Namespace {namespace} not supported, use 'mars'")
ignore = group_by.get("ignore", [])
for key in ["date", "time", "step"]:
assert key in ignore, f"{key} absent in ignore list {ignore}, at least 'date', 'time', 'step' required"
return group_by


@source_registry.register("accumulate")
class AccumulateSource(LegacySource):

Expand All @@ -479,6 +497,7 @@ def _execute(
period: str | int | datetime.timedelta,
availability=None,
patch: Any = None,
group_by: dict | None = None,
) -> Any:
"""Accumulation source callable function.
Read the recipe for accumulation in the request dictionary, check main arguments and call computation.
Expand All @@ -497,7 +516,8 @@ def _execute(
A description of the available periods in the data source. See documentation.
patch: Any, optional
A description of patches to apply to fields returned by the source to fix metadata issues.

group_by: dict, optional
A description for field metadata to be regrouped in accumulation
Return
------
The accumulated data source.
Expand All @@ -511,7 +531,9 @@ def _execute(
if "accumulation_period" in source:
raise ValueError("'accumulation_period' should be define outside source for accumulate action as 'period'")

group_by = patch_groupby_keys(group_by)

period = frequency_to_timedelta(period)
return _compute_accumulations(
context, dates, source=source, period=period, availability=availability, patch=patch
context, dates, source=source, period=period, availability=availability, patch=patch, group_by=group_by
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,52 @@
import datetime
import logging

from anemoi.utils.dates import frequency_to_timedelta

from .covering_intervals import SignedInterval

LOG = logging.getLogger(__name__)


def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=None):
# Because the data wrongly encode start_step, but end_step is correct
# and we know that accumulations are always reseted every multiple of 24 hours
#
# 1-1 -> 0-1
# 2-2 -> 0-2
# ...
# 23-23 -> 0-23
# 24-24 -> 0-24
# 25-25 -> 24-25
# 26-26 -> 24-26
# ...
# 47-47 -> 24-47
# 48-48 -> 24-48
# 49-49 -> 48-49
# 50-50 -> 48-50
# etc.
if endStep % 24 == 0:
# Special case: endStep is exactly 24, 48, 72, etc.
# Map to previous 24-hour boundary (24 -> 0, 48 -> 24, etc.)
return endStep - 24, endStep

# General case: floor to the nearest 24-hour boundary
# (1-23 -> 0, 25-47 -> 24, etc.)
return endStep - (endStep % 24), endStep


patch_registry = {"reset_24h_accumulations": _set_start_step_from_end_step_ceiled_to_24_hours}


class FieldToInterval:
"""Convert a field to its accumulation interval, applying patches if needed."""

def __init__(self, patches: dict | None = None):
if patches is None:
patches = {}
assert isinstance(patches, dict), ("patches must be a dict", patches)
patches = []
assert isinstance(patches, list), ("patches must be a list", patches)

self.patches = patches
for key in patches:
if key not in (
"start_step_is_zero",
"start_step_is_end_step",
"start_step_greater_than_end_step",
"set_start_step_to_zero",
):
if key not in patch_registry:
raise ValueError(f"Unknown patch key: {key}")

def __call__(self, field) -> SignedInterval:
Expand All @@ -44,19 +67,18 @@ def __call__(self, field) -> SignedInterval:
endStep = field.metadata("endStep")
startStep = field.metadata("startStep")

LOG.debug(f" field before patching: {startStep=}, {endStep=}")
LOG.debug(f" 🌧️: field before patching: {startStep=}, {endStep=}")

if self.patches.get("set_start_step_to_zero", False):
startStep, endStep = 0, endStep
for patch_name in self.patches:
patch_func = patch_registry[patch_name]
startStep, endStep = patch_func(startStep, endStep, field)

LOG.debug(f" 🌧️: field after user patches: {startStep=}, {endStep=}")

if startStep > endStep:
startStep, endStep = self.start_step_greater_than_end_step(startStep, endStep, field=field)
startStep, endStep = endStep, startStep
elif startStep == endStep:
startStep, endStep = self.start_step_is_end_step(startStep, endStep, field=field)
elif frequency_to_timedelta(startStep).total_seconds() == 0:
startStep, endStep = self.start_step_is_zero(startStep, endStep, field=field)

LOG.debug(f" field after patching : {startStep=}, {endStep=}")
startStep, endStep = 0, endStep

start_step = datetime.timedelta(hours=startStep)
end_step = datetime.timedelta(hours=endStep)
Expand All @@ -72,82 +94,3 @@ def __call__(self, field) -> SignedInterval:
assert valid_date == interval.max, (valid_date, interval)

return interval

def start_step_is_zero(self, startStep, endStep, field=None):
# Patch to handle cases where start_step is zero
# No patch yet implemented
match self.patches.get("start_step_is_zero", None):
case False | None:
pass # do nothing
case _ as options:
raise ValueError(f"Unknown option for patch.start_step_is_zero: {options}")

return startStep, endStep

def start_step_is_end_step(self, startStep, endStep, field=None):
# Patch to handle cases where start_step equals end_step
# this should not happen in normal cases but some datasets have this issue
# The default is to set start_step to zero
# This can be disabled by setting the patch to False

match self.patches.get("start_step_is_end_step", "set_start_step_to_zero"):
case False | None:
pass # do nothing

case "set_from_end_step_ceiled_to_24_hours":
startStep, endStep = _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=field)

case "set_start_step_to_zero":
startStep, endStep = 0, endStep

case _ as options:
raise ValueError(f"Unknown option for patch.start_step_is_end_step: {options}")

return startStep, endStep

def start_step_greater_than_end_step(self, startStep, endStep, field=None):

# Patch to handle cases where start_step is greater than end_step
# this should not happen in normal cases but some datasets have this issue
# The default is to do swap the values of start_step and end_step
# This can be disabled by setting the patch to False

match self.patches.get("start_step_greater_than_end_step", None):

case False | None:
pass # do nothing

case "swap":
startStep, endStep = endStep, startStep

case _ as options:
raise ValueError(f"Unknown option for patch.start_step_greater_than_end_step: {options}")

return startStep, endStep


def _set_start_step_from_end_step_ceiled_to_24_hours(startStep, endStep, field=None):
# Because the data wrongly encode start_step, but end_step is correct
# and we know that accumulations are always reseted every multiple of 24 hours
#
# 1-1 -> 0-1
# 2-2 -> 0-2
# ...
# 23-23 -> 0-23
# 24-24 -> 0-24
# 25-25 -> 24-25
# 26-26 -> 24-26
# ...
# 47-47 -> 24-47
# 48-48 -> 24-48
# 49-49 -> 48-49
# 50-50 -> 48-50
# etc.
if endStep % 24 == 0:
# Special case: endStep is exactly 24, 48, 72, etc.
# Map to previous 24-hour boundary (24 -> 0, 48 -> 24, etc.)
return endStep - 24, endStep

# General case: floor to the nearest 24-hour boundary
# (1-23 -> 0, 25-47 -> 24, etc.)
return endStep - (endStep % 24), endStep
4 changes: 3 additions & 1 deletion tests/create/accumulate-mars-ea-oper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ input:
grid: 20./20.
levtype: sfc
param: [ tp, cp ]

group_by:
namespace: 'mars'
ignore: ['date','time','step','timespan']

checks:
none: {}
2 changes: 1 addition & 1 deletion tests/create/accumulate-mars-reset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ input:
# print("/".join(steps))
patch:
# because the data wrongly encode start_step as 0 even after 24
start_step_is_end_step: set_from_end_step_ceiled_to_24_hours
- reset_24h_accumulations

source:
mars:
Expand Down
Loading