diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fb792360..bf487e5a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,12 +29,12 @@ jobs: with: python-versions: ${{ matrix.python-version }} - - name: Typecheck - run: nox -s typecheck - - name: Test run: nox -s tests -- --verbose --cov=earthaccess --cov-report=term-missing --capture=no --tb=native --log-cli-level=INFO + - name: Typecheck + run: nox -s typecheck + - name: Upload coverage # Don't upload coverage when using the `act` tool to run the workflow locally if: ${{ !env.ACT }} diff --git a/earthaccess/api.py b/earthaccess/api.py index 6b758aa2..992cb357 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -85,7 +85,7 @@ def search_datasets(count: int = -1, **kwargs: Any) -> List[DataCollection]: return query.get_all() -def search_data(count: int = -1, **kwargs: Any) -> List[DataGranule]: +def search_data(count: int = -1, **kwargs: Any) -> DataGranules: """Search dataset granules using NASA's CMR. [https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html](https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html) @@ -122,14 +122,13 @@ def search_data(count: int = -1, **kwargs: Any) -> List[DataGranule]: ``` """ if earthaccess.__auth__.authenticated: - query = DataGranules(earthaccess.__auth__).parameters(**kwargs) + results = DataGranules(earthaccess.__auth__).parameters(**kwargs) else: - query = DataGranules().parameters(**kwargs) - granules_found = query.hits() - logger.info(f"Granules found: {granules_found}") - if count > 0: - return query.get(count) - return query.get_all() + results = DataGranules().parameters(**kwargs) + + logger.info(f"Granules found: {results.hits()}") + results.load(count) + return results def search_services(count: int = -1, **kwargs: Any) -> List[Any]: diff --git a/earthaccess/search.py b/earthaccess/search.py index 3a2b458d..2d394010 100644 --- a/earthaccess/search.py +++ b/earthaccess/search.py @@ -418,6 +418,7 @@ def __init__(self, auth: Optional[Auth] = None, *args: Any, **kwargs: Any) -> No if auth: self.mode(auth.system.cmr_base_url) + self._granules = [] self._debug = False @override @@ -935,3 +936,50 @@ def doi(self, doi: str) -> Self: ) return self + + def load(self, count: int = -1): + if count > 0: + self.granules = self.get(count) + self.granules = self.get_all() + + @property + def granules(self) -> list: + """TODO""" + return self._granules + + @granules.setter + def granules(self, value: list): + self._granules = value + + @granules.deleter + def granules(self): + del self._granules + + def __iter__(self): + return iter(self.granules) + + def __len__(self): + return len(self.granules) + + def __getitem__(self, index: int) -> DataGranule: + # FIXME: allow slicing + # if isinstance(index, slice): + # return DataGranules(self.jobs[index]) + return self.granules[index] + + def __setitem__(self, index: int, granule: DataGranule) -> 'DataGranules': + self.granules[index] = granule + return self + + # FIXME: Is a granule in this results object? what do we use to tell? + # def __contains__(self, job: Job): + # return job in self.jobs + + def __eq__(self, other: 'DataGranules') -> bool: + # FIXME: compare query parameters too? what does it mean to be equal? + return self.granules == other.granules + + # TODO: display methods + def __repr__(self) -> str: + reprs = ", ".join([granule.__repr__() for granule in self.granules]) + return f'DataGranules([{reprs}])'