Skip to content

Commit dd25609

Browse files
committed
allow datetime object to be provided for temporal queries
1 parent dd61f23 commit dd25609

File tree

3 files changed

+60
-20
lines changed

3 files changed

+60
-20
lines changed

earthaccess/search.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime as dt
22
from inspect import getmembers, ismethod
3-
from typing import Any, Dict, List, Optional, Tuple, Type
3+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
44

55
import dateutil.parser as parser # type: ignore
66
from cmr import CollectionQuery, GranuleQuery # type: ignore
@@ -262,31 +262,36 @@ def get(self, limit: int = 2000) -> list:
262262
return results
263263

264264
def temporal(
265-
self, date_from: str, date_to: str, exclude_boundary: bool = False
265+
self,
266+
date_from: Optional[Union[str, dt.datetime]] = None,
267+
date_to: Optional[Union[str, dt.datetime]] = None,
268+
exclude_boundary: bool = False,
266269
) -> Type[CollectionQuery]:
267270
"""Filter by an open or closed date range. Dates can be provided as datetime objects
268271
or ISO 8601 formatted strings. Multiple ranges can be provided by successive calls.
269272
to this method before calling execute().
270273
271274
Parameters:
272-
date_from (String): earliest date of temporal range
273-
date_to (string): latest date of temporal range
275+
date_from (String or Datetime object): earliest date of temporal range
276+
date_to (String or Datetime object): latest date of temporal range
274277
exclude_boundary (Boolean): whether or not to exclude the date_from/to in the matched range
275278
"""
276279
DEFAULT = dt.datetime(1979, 1, 1)
277-
if date_from is not None:
280+
if date_from is not None and not isinstance(date_from, dt.datetime):
278281
try:
279-
parsed_from = parser.parse(date_from, default=DEFAULT).isoformat() + "Z"
282+
date_from = parser.parse(date_from, default=DEFAULT).isoformat() + "Z"
280283
except Exception:
281284
print("The provided start date was not recognized")
282-
parsed_from = ""
283-
if date_to is not None:
285+
date_from = ""
286+
287+
if date_to is not None and not isinstance(date_to, dt.datetime):
284288
try:
285-
parsed_to = parser.parse(date_to, default=DEFAULT).isoformat() + "Z"
289+
date_to = parser.parse(date_to, default=DEFAULT).isoformat() + "Z"
286290
except Exception:
287291
print("The provided end date was not recognized")
288-
parsed_to = ""
289-
super().temporal(parsed_from, parsed_to, exclude_boundary)
292+
date_to = ""
293+
294+
super().temporal(date_from, date_to, exclude_boundary)
290295
return self
291296

292297

@@ -614,8 +619,8 @@ def debug(self, debug: bool = True) -> Type[GranuleQuery]:
614619

615620
def temporal(
616621
self,
617-
date_from: Optional[str] = None,
618-
date_to: Optional[str] = None,
622+
date_from: Optional[Union[str, dt.datetime]] = None,
623+
date_to: Optional[Union[str, dt.datetime]] = None,
619624
exclude_boundary: bool = False,
620625
) -> Type[GranuleQuery]:
621626
"""Filter by an open or closed date range.
@@ -628,19 +633,21 @@ def temporal(
628633
exclude_boundary (Boolean): whether or not to exclude the date_from/to in the matched range
629634
"""
630635
DEFAULT = dt.datetime(1979, 1, 1)
631-
if date_from is not None:
636+
if date_from is not None and not isinstance(date_from, dt.datetime):
632637
try:
633-
parsed_from = parser.parse(date_from, default=DEFAULT).isoformat() + "Z"
638+
date_from = parser.parse(date_from, default=DEFAULT).isoformat() + "Z"
634639
except Exception:
635640
print("The provided start date was not recognized")
636-
parsed_from = ""
637-
if date_to is not None:
641+
date_from = ""
642+
643+
if date_to is not None and not isinstance(date_to, dt.datetime):
638644
try:
639-
parsed_to = parser.parse(date_to, default=DEFAULT).isoformat() + "Z"
645+
date_to = parser.parse(date_to, default=DEFAULT).isoformat() + "Z"
640646
except Exception:
641647
print("The provided end date was not recognized")
642-
parsed_to = ""
643-
super().temporal(parsed_from, parsed_to, exclude_boundary)
648+
date_to = ""
649+
650+
super().temporal(date_from, date_to, exclude_boundary)
644651
return self
645652

646653
def version(self, version: str = "") -> Type[GranuleQuery]:

tests/unit/test_collection_queries.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
11
# package imports
2+
import datetime as dt
3+
import pytest
24
from earthaccess.search import DataCollections
35

46

7+
valid_single_dates = [
8+
("2001-12-12", "2001-12-21", "2001-12-12T00:00:00Z,2001-12-21T00:00:00Z"),
9+
("2021-02-01", "", "2021-02-01T00:00:00Z,"),
10+
("1999-02-01 06:00", "2009-01-01", "1999-02-01T06:00:00Z,2009-01-01T00:00:00Z"),
11+
(dt.datetime(2021, 2, 1), dt.datetime(2021, 2, 2), "2021-02-01T00:00:00Z,2021-02-02T00:00:00Z")
12+
]
13+
14+
invalid_single_dates = [
15+
("2001-12-45", "2001-12-21", None),
16+
("2021w1", "", None),
17+
("2999-02-01", "2009-01-01", None),
18+
]
19+
520
def test_query_can_find_cloud_provider():
621
query = DataCollections().daac("PODAAC").cloud_hosted(True)
722
assert query.params["provider"] == "POCLOUD"
@@ -18,3 +33,19 @@ def test_querybuilder_can_handle_doi():
1833
assert query.params["doi"] == doi
1934
query = DataCollections().cloud_hosted(True).daac("PODAAC").doi(doi)
2035
assert query.params["doi"] == doi
36+
37+
38+
@pytest.mark.parametrize("start,end,expected", valid_single_dates)
39+
def test_query_can_parse_single_dates(start, end, expected):
40+
query = DataCollections().temporal(start, end)
41+
assert query.params["temporal"][0] == expected
42+
43+
44+
@pytest.mark.parametrize("start,end,expected", invalid_single_dates)
45+
def test_query_can_handle_invalid_dates(start, end, expected):
46+
query = DataCollections()
47+
try:
48+
query = query.temporal(start, end)
49+
except Exception as e:
50+
assert isinstance(e, ValueError)
51+
assert "temporal" not in query.params

tests/unit/test_granule_queries.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# package imports
2+
import datetime as dt
23
import pytest
34
from earthaccess.search import DataGranules
45

56
valid_single_dates = [
67
("2001-12-12", "2001-12-21", "2001-12-12T00:00:00Z,2001-12-21T00:00:00Z"),
78
("2021-02-01", "", "2021-02-01T00:00:00Z,"),
89
("1999-02-01 06:00", "2009-01-01", "1999-02-01T06:00:00Z,2009-01-01T00:00:00Z"),
10+
(dt.datetime(2021, 2, 1), dt.datetime(2021, 2, 2), "2021-02-01T00:00:00Z,2021-02-02T00:00:00Z")
911
]
1012

1113
invalid_single_dates = [

0 commit comments

Comments
 (0)