diff --git a/weatherbenchX/data_loaders/latency_wrappers.py b/weatherbenchX/data_loaders/latency_wrappers.py index a3d96c7..2c2c12b 100644 --- a/weatherbenchX/data_loaders/latency_wrappers.py +++ b/weatherbenchX/data_loaders/latency_wrappers.py @@ -91,13 +91,18 @@ def __init__( process_chunk_fn=data_loader._process_chunk_fn, ) - def get_available_init_time(self, init_time: np.datetime64) -> np.datetime64: + def get_available_init_time( + self, init_time: np.datetime64 + ) -> np.datetime64 | None: """Return most recent available nominal init time for requested init time.""" issue_time = self.nominal_init_times + self.latency diff = (issue_time - init_time).astype(int) # Find index of issue time that is closest to requested init_time. # on the left, i.e. with issue_time > nominal init_time. - available_idx = np.nanargmax(np.where(diff <= 0, diff, np.nan)) + diff = np.where(diff <= 0, diff, np.nan) + if np.all(np.isnan(diff)): + return None + available_idx = np.nanargmax(diff) available_init_time = self.nominal_init_times[available_idx] return available_init_time @@ -259,13 +264,20 @@ def _get_data_loader(self, init_time): lead_time_offsets_and_latencies = [] for data_loader in self._data_loaders: available_init_time = data_loader.get_available_init_time(init_time) - lead_time_offset = init_time - available_init_time - # Break ties by picking the data loader with largest latency -- note that - # we make latency negative here because we want the smallest - # lead_time_offset, but the largest data loader latency. - lead_time_offsets_and_latencies.append( - (lead_time_offset, -data_loader.latency) - ) + if available_init_time is None: + # If there is no available init time, we will assign an infinite lead + # time offset and latency. Since there is no actual "inf" timedelta, + # we'll just use 1e6 days. 
+ inf_time = np.timedelta64(int(1e6), 'D') + lead_time_offsets_and_latencies.append((inf_time, inf_time)) + else: + lead_time_offset = init_time - available_init_time + # Break ties by picking the data loader with largest latency -- note + # that we make latency negative here because we want the smallest + # lead_time_offset, but the largest data loader latency. + lead_time_offsets_and_latencies.append( + (lead_time_offset, -data_loader.latency) + ) lead_time_offsets_and_latencies = np.array( lead_time_offsets_and_latencies, dtype=[ @@ -281,7 +293,7 @@ def _get_data_loader(self, init_time): logging.info( 'Init time: %s, data loader latency: %s', init_time, - most_recent_data_loader.latency, + most_recent_data_loader.latency.astype('timedelta64[m]'), ) return most_recent_data_loader diff --git a/weatherbenchX/data_loaders/latency_wrappers_test.py b/weatherbenchX/data_loaders/latency_wrappers_test.py index 82904ab..8c25c9f 100644 --- a/weatherbenchX/data_loaders/latency_wrappers_test.py +++ b/weatherbenchX/data_loaders/latency_wrappers_test.py @@ -159,6 +159,128 @@ def test_multiple_latency_wrappers(self): correct_output['2m_temperature'].values, ) + def test_multiple_latency_wrappers_tie_breaking(self): + # Setup two loaders with same nominal init times but different latencies. + # We want a case where for a specific query time, they both return the SAME + # available init time. 
+ prediction_1 = test_utils.mock_prediction_data( + time_start='2020-01-01T00', + time_stop='2020-01-02T00', + time_resolution=np.timedelta64(12, 'h'), # 00, 12 + lead_start='0 hours', + lead_stop='24 hours', + lead_resolution='1 hours', + ) + 1.0 # Add 1 to distinguish + prediction_2 = test_utils.mock_prediction_data( + time_start='2020-01-01T00', + time_stop='2020-01-02T00', + time_resolution=np.timedelta64(12, 'h'), # 00, 12 + lead_start='0 hours', + lead_stop='24 hours', + lead_resolution='1 hours', + ) + 2.0 # Add 2 to distinguish + + prediction_path_1 = self.create_tempdir('prediction_1.zarr').full_path + prediction_path_2 = self.create_tempdir('prediction_2.zarr').full_path + prediction_1.to_zarr(prediction_path_1) + prediction_2.to_zarr(prediction_path_2) + + data_loader_1 = xarray_loaders.PredictionsFromXarray( + path=prediction_path_1, variables=['2m_temperature'] + ) + data_loader_2 = xarray_loaders.PredictionsFromXarray( + path=prediction_path_2, variables=['2m_temperature'] + ) + + # Loader 1: Latency 6h. + # Loader 2: Latency 12h. + wrapper1 = latency_wrappers.XarrayConstantLatencyWrapper( + data_loader_1, latency=np.timedelta64(6, 'h') + ) + wrapper2 = latency_wrappers.XarrayConstantLatencyWrapper( + data_loader_2, latency=np.timedelta64(12, 'h') + ) + + multi_wrapper = latency_wrappers.MultipleConstantLatencyWrapper( + [wrapper1, wrapper2] + ) + + # Query time: 2020-01-01T13. + # Nominal times: 00, 12. + + # Wrapper 1 (6h): Issue times T06, T18. + # Query T13 -> Issue time T06 (nominal init T00). + + # Wrapper 2 (12h): Issue times T12, T24. + # Query T13 -> Issue time T12 (nominal init T00). + + # Both return init T00. Wrapper 2 should be chosen (larger latency). 
+ init_times = np.array(['2020-01-01T13'], dtype='datetime64[ns]') + lead_times = np.array([6], dtype='timedelta64[h]') + + wrapped_output = multi_wrapper.load_chunk(init_times, lead_times) + + # Should match prediction_2 + correct_output = data_loader_2.load_chunk( + np.array(['2020-01-01T00'], dtype='datetime64[ns]'), + # offset = 13 - 0 = 13h. lead = 6 + 13 = 19h. + np.array([19], dtype='timedelta64[h]'), + ) + + np.testing.assert_allclose( + wrapped_output.isel(init_time=[0])['2m_temperature'].values, + correct_output['2m_temperature'].values, + ) + + def test_multiple_latency_wrappers_with_missing_init_time(self): + prediction = test_utils.mock_prediction_data( + time_start='2020-01-01T00', + time_stop='2020-01-02T00', + time_resolution=np.timedelta64(24, 'h'), # Only 00 + lead_start='0 hours', + lead_stop='12 hours', + lead_resolution='1 hours', + ) + prediction_path = self.create_tempdir('prediction.zarr').full_path + prediction.to_zarr(prediction_path) + + data_loader = xarray_loaders.PredictionsFromXarray( + path=prediction_path, variables=['2m_temperature'] + ) + + # Loader 1: Latency 6h. Issue time T06. + # Loader 2: Latency 1h. Issue time T01. + wrapper1 = latency_wrappers.XarrayConstantLatencyWrapper( + data_loader, latency=np.timedelta64(6, 'h') + ) + wrapper2 = latency_wrappers.XarrayConstantLatencyWrapper( + data_loader, latency=np.timedelta64(1, 'h') + ) + + multi_wrapper = latency_wrappers.MultipleConstantLatencyWrapper( + [wrapper1, wrapper2] + ) + + # Query time: 05. + # Wrapper 1 (6h): returns None (there is no init time before T05). + # Wrapper 2 (1h): returns issue time T01, nominal init T00. + # Wrapper 2 should be chosen. + init_times = np.array(['2020-01-01T05'], dtype='datetime64[ns]') + lead_times = np.array([1], dtype='timedelta64[h]') + + wrapped_output = multi_wrapper.load_chunk(init_times, lead_times) + + # Check that we got data (not all NaNs or error) + # Lead loaded = 1 + 5 = 6h. 
+ correct_output = data_loader.load_chunk( + np.array(['2020-01-01T00'], dtype='datetime64[ns]'), + np.array([6], dtype='timedelta64[h]'), + ) + np.testing.assert_allclose( + wrapped_output.isel(init_time=[0])['2m_temperature'].values, + correct_output['2m_temperature'].values, + ) + if __name__ == '__main__': absltest.main()