Skip to content

Commit 3e0599d

Browse files
update function to accept dates
1 parent cd3d7e2 commit 3e0599d

File tree

2 files changed

+191
-67
lines changed

2 files changed

+191
-67
lines changed

climada/hazard/tc_tracks.py

Lines changed: 83 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -322,71 +322,122 @@ def subset(self, filterdict):
322322

323323
return out
324324

325-
def subset_year(self, start_year: int = None, end_year: int = None):
326-
"""Subset TCTracks between start and end years, both included.
325+
def subset_year(
326+
self,
327+
start_date: tuple = (False, False, False),
328+
end_date: tuple = (False, False, False),
329+
):
330+
"""Subset TCTracks between start and end dates, both included.
327331
328332
Parameters:
329333
----------
330-
start_year: int
331-
First year to include in the selection
332-
end_year: int
333-
Last year to include in the selection, same as start_year when selecting only one year
334+
start_date: tuple
335+
First date to include in the selection (YYYY, MM, DD). Each element can either
336+
be an integer or `False`. If an element is `False`, it is ignored during the filter.
337+
end_date: tuple of int
338+
Last date to include in the selection, same as start_date for the corresponding field.
334339
335340
Returns:
336341
--------
337342
subset: TCTracks
338-
TCTtracks object containing the subset of tracks
343+
TCTracks object containing the subset of tracks
344+
339345
Raises:
340346
-------
341-
TypeError
342-
- If either `start_year` or `end_year` is not an integer.
343-
- If `self` is not a `TCTracks` object.
344-
- If `self.data` is empty (i.e., no tracks are available).
345347
ValueError
346-
- If `start_year` is greater than `end_year`.
347-
- If the date format in a track is invalid and the year cannot be extracted.
348-
- If no tracks are found within the specified year range.
349-
348+
- If there's a mismatch between `start_*` and `end_*` values (e.g., one is set to `True` while the other is `False`).
349+
- If no tracks are found within the specified date range.
350+
TypeError
351+
- If `start_date` or `end_date` are incorrectly ordered (start > end).
352+
353+
Example 1 (Filter by Year Only):
354+
---------------------------------
355+
>>> start_date = (2022, False, False)
356+
>>> end_date = (2022, False, False)
357+
>>> # This will filter all tracks from the year 2022, regardless of month or day.
358+
359+
Example 2 (Filter by Year and Month):
360+
--------------------------------------
361+
>>> start_date = (2022, 5, False)
362+
>>> end_date = (2022, 5, False)
363+
>>> # This will filter all tracks from May 2022, regardless of the day.
364+
365+
Example 3 (Filter by Year, Month, and Day):
366+
--------------------------------------------
367+
>>> start_date = (2022, 5, 10)
368+
>>> end_date = (2022, 5, 20)
369+
>>> # This will filter all tracks from May 10th to May 20th, 2022.
370+
371+
Example 4 (Invalid: Only one of day is specified):
372+
---------------------------------------------------
373+
>>> start_date = (2022, False, 10)
374+
>>> end_date = (2022, 5, 20)
375+
>>> # Raises a ValueError since the day is specified in the start_date but not in end_date.
350376
"""
351377

352378
subset = self.__class__()
353379

354-
if not isinstance(start_year, int) or not isinstance(end_year, int):
355-
raise TypeError("Both start_year and end_year must be integers.")
356-
357-
if start_year > end_year:
380+
# Extract date components
381+
start_year, end_year = start_date[0], end_date[0]
382+
start_month, end_month = start_date[1], end_date[1]
383+
start_day, end_day = start_date[2], end_date[2]
384+
385+
# Check if only one of start_* or end_* is set (True and False)
386+
# if not start_day and not (1 <= start_day <= 31) or not end_day and not (1 <= end_day <= 31):
387+
# raise TypeError("Day values should be between 1 and 31.")
388+
# if not start_month and not (1 <= start_month <= 12) or not end_month and not(1 <= end_month <= 12):
389+
# raise TypeError("Day values should be between 1 and 31.")
390+
if (start_day and not end_day) or (not start_day and end_day):
358391
raise ValueError(
359-
f"start_year ({start_year}) cannot be greater than end_year ({end_year})."
392+
"Mismatch between start_day and end_day: Both must be either True or False."
360393
)
361-
362-
if not isinstance(self, TCTracks):
363-
raise TypeError(
364-
f"self should be a TCTtracks object and not {self.__class__()}."
394+
elif (start_month and not end_month) or (not start_month and end_month):
395+
raise ValueError(
396+
"Mismatch between start_month and end_month: Both must be either True or False."
365397
)
398+
elif (start_year and not end_year) or (not start_year and end_year):
399+
raise ValueError(
400+
"Mismatch between start_year and end_year: Both must be either True or False."
401+
)
402+
elif start_year and end_year and start_year > end_year:
403+
raise TypeError("Start year is after end year, control your entry.")
366404

367-
if len(self.data) == 0:
368-
raise TypeError("self.data should be a non-empty list of tracks.")
369-
370-
# Find indices corresponding to the years
405+
# Find indices corresponding to the date range
371406
index: list = []
372407
for i, track in enumerate(self.data):
373408
try:
374409
date_array = track.time[0].to_numpy()
375410
year = date_array.astype("datetime64[Y]").item().year
411+
month = date_array.astype("datetime64[M]").item().month
412+
day = date_array.astype("datetime64[D]").item().day
376413
except AttributeError:
377414
raise ValueError(
378-
f"Invalid date format in track {i}, could not extract year."
415+
f"Invalid date format in track {i}, could not extract date."
379416
)
380417

381-
if start_year <= year <= end_year:
418+
condition_year = start_year <= year <= end_year
419+
condition_month = start_month <= month <= end_month
420+
condition_day = start_day <= day <= end_day
421+
422+
if not start_day and not end_day:
423+
condition_day = True
424+
if not start_month and not end_month:
425+
condition_month = True
426+
if not start_year and not end_year:
427+
condition_year = True
428+
429+
if condition_year and condition_month and condition_day:
382430
index.append(i)
383431

432+
# Raise error if no tracks found
384433
if not index:
385434
raise ValueError(
386-
f"No tracks found for the years between {start_year} and {end_year}."
435+
f"No tracks found for the specified date range: {start_date} to \n"
436+
"{end_date}."
387437
)
388438

389-
subset.data = itemgetter(*index)(self.data)
439+
# Create subset with filtered tracks
440+
subset.data = [self.data[i] for i in index]
390441

391442
return subset
392443

climada/hazard/test/test_tc_tracks.py

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -770,47 +770,120 @@ def test_subset_years(self):
770770
tc_test = tc.TCTracks.from_simulations_emanuel(TEST_TRACK_EMANUEL)
771771
for i in range(5):
772772
date = cftime.DatetimeProlepticGregorian(
773-
2000 + i, 2, 20, 0, 0, 0, 0, has_year_zero=True
773+
2000 + i, 1 + i, 10 + i, 0, 0, 0, 0, has_year_zero=True
774774
)
775775
tc_test.data[i]["time"] = np.full(tc_test.data[i].time.shape[0], date)
776776

777-
tc_subset = tc_test.subset_year(start_year=2001, end_year=2003)
778-
779-
self.assertEqual(len(tc_subset.data), 3)
777+
# correct calling of the function
778+
tc_subset = tc_test.subset_year(
779+
start_date=(2000, False, False), end_date=(2003, False, False)
780+
)
781+
self.assertEqual(len(tc_subset.data), 4)
782+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
783+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
784+
self.assertEqual(tc_subset.data[1].time[0].item().year, 2001)
785+
self.assertEqual(tc_subset.data[1].time[0].item().month, 2)
786+
self.assertEqual(tc_subset.data[2].time[0].item().year, 2002)
787+
self.assertEqual(tc_subset.data[2].time[0].item().month, 3)
788+
self.assertEqual(tc_subset.data[3].time[0].item().year, 2003)
789+
self.assertEqual(tc_subset.data[3].time[0].item().month, 4)
790+
tc_subset = tc_test.subset_year(
791+
start_date=(2000, False, False), end_date=(2000, False, False)
792+
)
793+
self.assertEqual(len(tc_subset.data), 1)
794+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
795+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
796+
tc_subset = tc_test.subset_year(
797+
start_date=(False, 1, False), end_date=(False, 4, False)
798+
)
799+
self.assertEqual(len(tc_subset.data), 4)
800+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
801+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
802+
self.assertEqual(tc_subset.data[1].time[0].item().year, 2001)
803+
self.assertEqual(tc_subset.data[1].time[0].item().month, 2)
804+
self.assertEqual(tc_subset.data[2].time[0].item().year, 2002)
805+
self.assertEqual(tc_subset.data[2].time[0].item().month, 3)
806+
self.assertEqual(tc_subset.data[3].time[0].item().year, 2003)
807+
self.assertEqual(tc_subset.data[3].time[0].item().month, 4)
808+
tc_subset = tc_test.subset_year(
809+
start_date=(False, 3, False), end_date=(False, 3, False)
810+
)
811+
self.assertEqual(len(tc_subset.data), 1)
812+
self.assertEqual(tc_subset.data[0].time[0].item().month, 3)
813+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2002)
814+
tc_subset = tc_test.subset_year(
815+
start_date=(False, False, 11), end_date=(False, False, 14)
816+
)
817+
self.assertEqual(len(tc_subset.data), 4)
780818
self.assertEqual(tc_subset.data[0].time[0].item().year, 2001)
819+
self.assertEqual(tc_subset.data[0].time[0].item().month, 2)
781820
self.assertEqual(tc_subset.data[1].time[0].item().year, 2002)
821+
self.assertEqual(tc_subset.data[1].time[0].item().month, 3)
782822
self.assertEqual(tc_subset.data[2].time[0].item().year, 2003)
783-
784-
# Invalid input: non-integer start_year
785-
with self.assertRaisesRegex(
786-
TypeError, "Both start_year and end_year must be integers."
787-
):
788-
tc_test.subset_year(start_year="2000", end_year=2003)
789-
790-
# Invalid input: non-integer end_year
791-
with self.assertRaisesRegex(
792-
TypeError, "Both start_year and end_year must be integers."
793-
):
794-
tc_test.subset_year(start_year=2000, end_year=None)
795-
796-
# Invalid range: start_year greater than end_year
797-
with self.assertRaisesRegex(
798-
ValueError, r"start_year \(2005\) cannot be greater than end_year \(2000\)."
799-
):
800-
tc_test.subset_year(start_year=2005, end_year=2000)
801-
802-
# No tracks match the year range
803-
with self.assertRaisesRegex(
804-
ValueError, "No tracks found for the years between 2050 and 2060."
805-
):
806-
tc_test.subset_year(start_year=2050, end_year=2060)
807-
808-
# Empty data case
809-
empty_tc = tc.TCTracks()
810-
with self.assertRaisesRegex(
811-
TypeError, "self.data should be a non-empty list of tracks."
812-
):
813-
empty_tc.subset_year(start_year=2000, end_year=2010)
823+
self.assertEqual(tc_subset.data[2].time[0].item().month, 4)
824+
self.assertEqual(tc_subset.data[3].time[0].item().year, 2004)
825+
self.assertEqual(tc_subset.data[3].time[0].item().month, 5)
826+
tc_subset = tc_test.subset_year(
827+
start_date=(False, False, 10), end_date=(False, False, 10)
828+
)
829+
self.assertEqual(len(tc_subset.data), 1)
830+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
831+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
832+
tc_subset = tc_test.subset_year(
833+
start_date=(2000, 1, 10), end_date=(2000, 1, 13)
834+
)
835+
self.assertEqual(len(tc_subset.data), 1)
836+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
837+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
838+
tc_subset = tc_test.subset_year(
839+
start_date=(2000, 1, 10), end_date=(2004, 9, 13)
840+
)
841+
self.assertEqual(len(tc_subset.data), 4)
842+
self.assertEqual(tc_subset.data[0].time[0].item().year, 2000)
843+
self.assertEqual(tc_subset.data[0].time[0].item().month, 1)
844+
self.assertEqual(tc_subset.data[1].time[0].item().year, 2001)
845+
self.assertEqual(tc_subset.data[1].time[0].item().month, 2)
846+
self.assertEqual(tc_subset.data[2].time[0].item().year, 2002)
847+
self.assertEqual(tc_subset.data[2].time[0].item().month, 3)
848+
self.assertEqual(tc_subset.data[3].time[0].item().year, 2003)
849+
self.assertEqual(tc_subset.data[3].time[0].item().month, 4)
850+
851+
# improper calling
852+
853+
# self.assertEqual(tc_subset.data[0].time[0].item().year, 2001)
854+
# self.assertEqual(tc_subset.data[1].time[0].item().year, 2002)
855+
# self.assertEqual(tc_subset.data[2].time[0].item().year, 2003)
856+
857+
# # Invalid input: non-integer start_year
858+
# with self.assertRaisesRegex(
859+
# TypeError, "Both start_year and end_year must be integers."
860+
# ):
861+
# tc_test.subset_year(start_year="2000", end_year=2003)
862+
863+
# # Invalid input: non-integer end_year
864+
# with self.assertRaisesRegex(
865+
# TypeError, "Both start_year and end_year must be integers."
866+
# ):
867+
# tc_test.subset_year(start_year=2000, end_year=None)
868+
869+
# # Invalid range: start_year greater than end_year
870+
# with self.assertRaisesRegex(
871+
# ValueError, r"start_year \(2005\) cannot be greater than end_year \(2000\)."
872+
# ):
873+
# tc_test.subset_year(start_year=2005, end_year=2000)
874+
875+
# # No tracks match the year range
876+
# with self.assertRaisesRegex(
877+
# ValueError, "No tracks found for the years between 2050 and 2060."
878+
# ):
879+
# tc_test.subset_year(start_year=2050, end_year=2060)
880+
881+
# # Empty data case
882+
# empty_tc = tc.TCTracks()
883+
# with self.assertRaisesRegex(
884+
# TypeError, "self.data should be a non-empty list of tracks."
885+
# ):
886+
# empty_tc.subset_year(start_year=2000, end_year=2010)
814887

815888
def test_get_extent(self):
816889
"""Test extent/bounds attributes."""

0 commit comments

Comments
 (0)