Skip to content

Commit c68e634

Browse files
committed
update typehints and code for _get methods to use Path consistently
1 parent 95ce27f commit c68e634

File tree

1 file changed

+11
-12
lines changed

1 file changed

+11
-12
lines changed

earthaccess/store.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def get(
480480
def _get(
481481
self,
482482
granules: Union[List[DataGranule], List[str]],
483-
local_path: str,
483+
local_path: Path,
484484
provider: Optional[str] = None,
485485
threads: int = 8,
486486
) -> List[str]:
@@ -508,7 +508,7 @@ def _get(
508508
def _get_urls(
509509
self,
510510
granules: List[str],
511-
local_path: str,
511+
local_path: Path,
512512
provider: Optional[str] = None,
513513
threads: int = 8,
514514
) -> List[str]:
@@ -524,8 +524,8 @@ def _get_urls(
524524
s3_fs = self.get_s3fs_session(provider=provider)
525525
# TODO: make this parallel or concurrent
526526
for file in data_links:
527-
s3_fs.get(file, local_path)
528-
file_name = Path(local_path) / Path(file).name
527+
s3_fs.get(file, str(local_path))
528+
file_name = local_path / Path(file).name
529529
print(f"Downloaded: {file_name}")
530530
downloaded_files.append(file_name)
531531
return downloaded_files
@@ -538,7 +538,7 @@ def _get_urls(
538538
def _get_granules(
539539
self,
540540
granules: List[DataGranule],
541-
local_path: str,
541+
local_path: Path,
542542
provider: Optional[str] = None,
543543
threads: int = 8,
544544
) -> List[str]:
@@ -570,15 +570,15 @@ def _get_granules(
570570
s3_fs = self.get_s3fs_session(provider=provider)
571571
# TODO: make this async
572572
for file in data_links:
573-
s3_fs.get(file, local_path)
574-
file_name = Path(local_path) / Path(file).name
573+
s3_fs.get(file, str(local_path))
574+
file_name = local_path / Path(file).name
575575
print(f"Downloaded: {file_name}")
576576
downloaded_files.append(file_name)
577577
return downloaded_files
578578
else:
579579
# if the data are cloud-based, but we are not in AWS,
580580
# it will be downloaded as if it was on prem
581-
return self._download_onprem_granules(data_links, local_path, threads)
581+
return self._download_onprem_granules(data_links, str(local_path), threads)
582582

583583
def _download_file(self, url: str, directory: str) -> str:
584584
"""Download a single file from an on-prem location, a DAAC data center.
@@ -595,8 +595,7 @@ def _download_file(self, url: str, directory: str) -> str:
595595
url = url.replace(".html", "")
596596
local_filename = url.split("/")[-1]
597597
path = Path(directory) / Path(local_filename)
598-
local_path = str(path)
599-
if not Path(local_path).exists():
598+
if not path.exists():
600599
try:
601600
session = self.auth.get_session()
602601
with session.get(
@@ -605,7 +604,7 @@ def _download_file(self, url: str, directory: str) -> str:
605604
allow_redirects=True,
606605
) as r:
607606
r.raise_for_status()
608-
with open(local_path, "wb") as f:
607+
with open(path, "wb") as f:
609608
# This is to cap memory usage for large files at 1MB per write to disk per thread
610609
# https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
611610
shutil.copyfileobj(r.raw, f, length=1024 * 1024)
@@ -615,7 +614,7 @@ def _download_file(self, url: str, directory: str) -> str:
615614
raise Exception
616615
else:
617616
print(f"File {local_filename} already downloaded")
618-
return local_path
617+
return str(path)
619618

620619
def _download_onprem_granules(
621620
self, urls: List[str], directory: str, threads: int = 8

0 commit comments

Comments
 (0)