Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions weatherbenchX/data_loaders/latency_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,18 @@ def __init__(
process_chunk_fn=data_loader._process_chunk_fn,
)

def get_available_init_time(self, init_time: np.datetime64) -> np.datetime64:
def get_available_init_time(
self, init_time: np.datetime64
) -> np.datetime64 | None:
"""Return most recent available nominal init time for requested init time."""
issue_time = self.nominal_init_times + self.latency
diff = (issue_time - init_time).astype(int)
# Find index of issue time that is closest to requested init_time.
# on the left, i.e. with issue_time > nominal init_time.
available_idx = np.nanargmax(np.where(diff <= 0, diff, np.nan))
diff = np.where(diff <= 0, diff, np.nan)
if np.all(np.isnan(diff)):
return None
available_idx = np.nanargmax(diff)
available_init_time = self.nominal_init_times[available_idx]
return available_init_time

Expand Down Expand Up @@ -259,13 +264,20 @@ def _get_data_loader(self, init_time):
lead_time_offsets_and_latencies = []
for data_loader in self._data_loaders:
available_init_time = data_loader.get_available_init_time(init_time)
lead_time_offset = init_time - available_init_time
# Break ties by picking the data loader with largest latency -- note that
# we make latency negative here because we want the smallest
# lead_time_offset, but the largest data loader latency.
lead_time_offsets_and_latencies.append(
(lead_time_offset, -data_loader.latency)
)
if available_init_time is None:
# If there is no available init time, we will assign an infinite lead
# time offset and latency. Since there is no actual "inf" timedelta,
# we'll just use 1e6 days.
inf_time = np.timedelta64(int(1e6), 'D')
lead_time_offsets_and_latencies.append((inf_time, inf_time))
else:
lead_time_offset = init_time - available_init_time
# Break ties by picking the data loader with largest latency -- note
# that we make latency negative here because we want the smallest
# lead_time_offset, but the largest data loader latency.
lead_time_offsets_and_latencies.append(
(lead_time_offset, -data_loader.latency)
)
lead_time_offsets_and_latencies = np.array(
lead_time_offsets_and_latencies,
dtype=[
Expand All @@ -281,7 +293,7 @@ def _get_data_loader(self, init_time):
logging.info(
'Init time: %s, data loader latency: %s',
init_time,
most_recent_data_loader.latency,
most_recent_data_loader.latency.astype('timedelta64[m]'),
)
return most_recent_data_loader

Expand Down
122 changes: 122 additions & 0 deletions weatherbenchX/data_loaders/latency_wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,128 @@ def test_multiple_latency_wrappers(self):
correct_output['2m_temperature'].values,
)

def test_multiple_latency_wrappers_tie_breaking(self):
# Setup two loaders with same nominal init times but different latencies.
# We want a case where for a specific query time, they both return the SAME
# available init time.
prediction_1 = test_utils.mock_prediction_data(
time_start='2020-01-01T00',
time_stop='2020-01-02T00',
time_resolution=np.timedelta64(12, 'h'), # 00, 12
lead_start='0 hours',
lead_stop='24 hours',
lead_resolution='1 hours',
) + 1.0 # Add 1 to distinguish
prediction_2 = test_utils.mock_prediction_data(
time_start='2020-01-01T00',
time_stop='2020-01-02T00',
time_resolution=np.timedelta64(12, 'h'), # 00, 12
lead_start='0 hours',
lead_stop='24 hours',
lead_resolution='1 hours',
) + 2.0 # Add 2 to distinguish

prediction_path_1 = self.create_tempdir('prediction_1.zarr').full_path
prediction_path_2 = self.create_tempdir('prediction_2.zarr').full_path
prediction_1.to_zarr(prediction_path_1)
prediction_2.to_zarr(prediction_path_2)

data_loader_1 = xarray_loaders.PredictionsFromXarray(
path=prediction_path_1, variables=['2m_temperature']
)
data_loader_2 = xarray_loaders.PredictionsFromXarray(
path=prediction_path_2, variables=['2m_temperature']
)

# Loader 1: Latency 6h.
# Loader 2: Latency 12h.
wrapper1 = latency_wrappers.XarrayConstantLatencyWrapper(
data_loader_1, latency=np.timedelta64(6, 'h')
)
wrapper2 = latency_wrappers.XarrayConstantLatencyWrapper(
data_loader_2, latency=np.timedelta64(12, 'h')
)

multi_wrapper = latency_wrappers.MultipleConstantLatencyWrapper(
[wrapper1, wrapper2]
)

# Query time: 2020-01-01T13.
# Nominal times: 00, 12.

# Wrapper 1 (6h): Issue times T06, T18.
# Query T13 -> Issue time T06 (nominal init T00).

# Wrapper 2 (12h): Issue times T12, T24.
# Query T13 -> Issue time T12 (nominal init T00).

# Both return init T00. Wrapper 2 should be chosen (larger latency).
init_times = np.array(['2020-01-01T13'], dtype='datetime64[ns]')
lead_times = np.array([6], dtype='timedelta64[h]')

wrapped_output = multi_wrapper.load_chunk(init_times, lead_times)

# Should match prediction_2
correct_output = data_loader_2.load_chunk(
np.array(['2020-01-01T00'], dtype='datetime64[ns]'),
# offset = 13 - 0 = 13h. lead = 6 + 13 = 19h.
np.array([19], dtype='timedelta64[h]'),
)

np.testing.assert_allclose(
wrapped_output.isel(init_time=[0])['2m_temperature'].values,
correct_output['2m_temperature'].values,
)

def test_multiple_latency_wrappers_with_missing_init_time(self):
prediction = test_utils.mock_prediction_data(
time_start='2020-01-01T00',
time_stop='2020-01-02T00',
time_resolution=np.timedelta64(24, 'h'), # Only 00
lead_start='0 hours',
lead_stop='12 hours',
lead_resolution='1 hours',
)
prediction_path = self.create_tempdir('prediction.zarr').full_path
prediction.to_zarr(prediction_path)

data_loader = xarray_loaders.PredictionsFromXarray(
path=prediction_path, variables=['2m_temperature']
)

# Loader 1: Latency 6h. Issue time T06.
# Loader 2: Latency 1h. Issue time T01.
wrapper1 = latency_wrappers.XarrayConstantLatencyWrapper(
data_loader, latency=np.timedelta64(6, 'h')
)
wrapper2 = latency_wrappers.XarrayConstantLatencyWrapper(
data_loader, latency=np.timedelta64(1, 'h')
)

multi_wrapper = latency_wrappers.MultipleConstantLatencyWrapper(
[wrapper1, wrapper2]
)

# Query time: 05.
# Wrapper 1 (6h): returns None (there is no init time before T05).
# Wrapper 2 (1h): returns issue time T06, nominal init T00.
# Wrapper 2 should be chosen.
init_times = np.array(['2020-01-01T05'], dtype='datetime64[ns]')
lead_times = np.array([1], dtype='timedelta64[h]')

wrapped_output = multi_wrapper.load_chunk(init_times, lead_times)

# Check that we got data (not all NaNs or error)
# Lead loaded = 1 + 5 = 6h.
correct_output = data_loader.load_chunk(
np.array(['2020-01-01T00'], dtype='datetime64[ns]'),
np.array([6], dtype='timedelta64[h]'),
)
np.testing.assert_allclose(
wrapped_output.isel(init_time=[0])['2m_temperature'].values,
correct_output['2m_temperature'].values,
)


if __name__ == '__main__':
absltest.main()
Loading