
Commit 8b00ea3

Merge pull request #459 from kvenkman/improvement/use-pathlib
Improvement/use pathlib
2 parents 6b86f49 + 453e7f9 commit 8b00ea3

6 files changed (+28, -32 lines)

earthaccess/auth.py

Lines changed: 2 additions & 4 deletions
@@ -257,9 +257,7 @@ def _netrc(self) -> bool:
         try:
             my_netrc = Netrc()
         except FileNotFoundError as err:
-            raise FileNotFoundError(
-                f"No .netrc found in {os.path.expanduser('~')}"
-            ) from err
+            raise FileNotFoundError(f"No .netrc found in {Path.home()}") from err
         except NetrcParseError as err:
             raise NetrcParseError("Unable to parse .netrc") from err
         if my_netrc["urs.earthdata.nasa.gov"] is not None:

@@ -365,7 +363,7 @@ def _persist_user_credentials(self, username: str, password: str) -> bool:
         try:
             netrc_path = Path().home().joinpath(".netrc")
             netrc_path.touch(exist_ok=True)
-            os.chmod(netrc_path.absolute(), 0o600)
+            netrc_path.chmod(0o600)
         except Exception as e:
             print(e)
             return False
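For context on the auth.py change: os.path.expanduser("~") becomes Path.home(), and os.chmod() on a path string becomes the chmod() method of the Path object. A minimal, self-contained sketch of those two equivalences (run against a throwaway temp file, not the real ~/.netrc):

    import tempfile
    from pathlib import Path

    # Path.home() returns the same directory os.path.expanduser("~") did
    print(f"No .netrc found in {Path.home()}")

    # Path.chmod() replaces os.chmod(); demonstrated on a throwaway file
    with tempfile.TemporaryDirectory() as tmp:
        netrc_path = Path(tmp) / ".netrc"
        netrc_path.touch(exist_ok=True)
        netrc_path.chmod(0o600)  # owner read/write only, the mode the commit sets
        print(oct(netrc_path.stat().st_mode & 0o777))  # 0o600 on POSIX systems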

earthaccess/store.py

Lines changed: 19 additions & 23 deletions
@@ -1,5 +1,4 @@
 import datetime
-import os
 import shutil
 import traceback
 from functools import lru_cache

@@ -443,7 +442,7 @@ def _open_urls(
     def get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: Optional[str] = None,
+        local_path: Optional[Path] = None,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:

@@ -466,11 +465,10 @@ def get(
             List of downloaded files
         """
         if local_path is None:
-            local_path = os.path.join(
-                ".",
-                "data",
-                f"{datetime.datetime.today().strftime('%Y-%m-%d')}-{uuid4().hex[:6]}",
-            )
+            today = datetime.datetime.today().strftime("%Y-%m-%d")
+            uuid = uuid4().hex[:6]
+            local_path = Path.cwd() / "data" / f"{today}-{uuid}"
+
         if len(granules):
             files = self._get(granules, local_path, provider, threads)
             return files

@@ -481,7 +479,7 @@ def get(
     def _get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:

@@ -509,7 +507,7 @@ def _get(
     def _get_urls(
         self,
         granules: List[str],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:

@@ -525,8 +523,8 @@ def _get_urls(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this parallel or concurrent
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files

@@ -539,7 +537,7 @@ def _get_urls(
     def _get_granules(
         self,
         granules: List[DataGranule],
-        local_path: str,
+        local_path: Path,
         provider: Optional[str] = None,
         threads: int = 8,
     ) -> List[str]:

@@ -571,8 +569,8 @@ def _get_granules(
             s3_fs = self.get_s3fs_session(provider=provider)
             # TODO: make this async
             for file in data_links:
-                s3_fs.get(file, local_path)
-                file_name = os.path.join(local_path, os.path.basename(file))
+                s3_fs.get(file, str(local_path))
+                file_name = local_path / Path(file).name
                 print(f"Downloaded: {file_name}")
                 downloaded_files.append(file_name)
             return downloaded_files

@@ -581,7 +579,7 @@ def _get_granules(
         # it will be downloaded as if it was on prem
         return self._download_onprem_granules(data_links, local_path, threads)

-    def _download_file(self, url: str, directory: str) -> str:
+    def _download_file(self, url: str, directory: Path) -> str:
         """Download a single file from an on-prem location, a DAAC data center.

         Parameters:

@@ -595,9 +593,8 @@ def _download_file(self, url: str, directory: str) -> str:
         if "opendap" in url and url.endswith(".html"):
             url = url.replace(".html", "")
         local_filename = url.split("/")[-1]
-        path = Path(directory) / Path(local_filename)
-        local_path = str(path)
-        if not os.path.exists(local_path):
+        path = directory / Path(local_filename)
+        if not path.exists():
             try:
                 session = self.auth.get_session()
                 with session.get(

@@ -606,7 +603,7 @@ def _download_file(self, url: str, directory: str) -> str:
                     allow_redirects=True,
                 ) as r:
                     r.raise_for_status()
-                    with open(local_path, "wb") as f:
+                    with open(path, "wb") as f:
                         # This is to cap memory usage for large files at 1MB per write to disk per thread
                         # https://docs.python-requests.org/en/latest/user/quickstart/#raw-response-content
                         shutil.copyfileobj(r.raw, f, length=1024 * 1024)

@@ -616,10 +613,10 @@ def _download_file(self, url: str, directory: str) -> str:
                 raise Exception
         else:
             print(f"File {local_filename} already downloaded")
-        return local_path
+        return str(path)

     def _download_onprem_granules(
-        self, urls: List[str], directory: str, threads: int = 8
+        self, urls: List[str], directory: Path, threads: int = 8
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.

@@ -638,8 +635,7 @@ def _download_onprem_granules(
             raise ValueError(
                 "We need to be logged into NASA EDL in order to download data granules"
             )
-        if not os.path.exists(directory):
-            os.makedirs(directory)
+        directory.mkdir(parents=True, exist_ok=True)

         arguments = [(url, directory) for url in urls]
         results = pqdm(
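The store.py hunks follow one pattern: build paths with the / operator, create directories with Path.mkdir(), and take basenames with Path.name. A small standalone sketch of those idioms, run under a temporary root with an illustrative URL (not earthaccess code itself):

    import datetime
    import tempfile
    from pathlib import Path
    from uuid import uuid4

    # Path.cwd() / "data" / ... replaces os.path.join(".", "data", ...)
    today = datetime.datetime.today().strftime("%Y-%m-%d")
    uuid = uuid4().hex[:6]
    print(Path.cwd() / "data" / f"{today}-{uuid}")

    with tempfile.TemporaryDirectory() as tmp:
        # mkdir(parents=True, exist_ok=True) replaces the os.path.exists()/os.makedirs() pair
        directory = Path(tmp) / "data" / f"{today}-{uuid}"
        directory.mkdir(parents=True, exist_ok=True)

        # directory / Path(url).name replaces os.path.join(directory, os.path.basename(url))
        url = "https://example.com/granules/granule.nc"  # illustrative URL
        print(directory / Path(url).name)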

tests/integration/test_api.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest

@@ -84,7 +85,7 @@ def test_download(tmp_path, selection, use_url):
     result = results[selection]
     files = earthaccess.download(result, str(tmp_path))
     assertions.assertIsInstance(files, list)
-    assert all(os.path.exists(f) for f in files)
+    assert all(Path(f).exists() for f in files)


 def test_auth_environ():

tests/integration/test_auth.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def activate_netrc():
         f.write(
             f"machine urs.earthdata.nasa.gov login {username} password {password}\n"
         )
-    os.chmod(NETRC_PATH, 0o600)
+    NETRC_PATH.chmod(0o600)


 def delete_netrc():

tests/integration/test_cloud_download.py

Lines changed: 1 addition & 1 deletion
@@ -166,4 +166,4 @@ def test_multi_file_granule(tmp_path):
     urls = granules[0].data_links()
     assert len(urls) > 1
     files = earthaccess.download(granules, str(tmp_path))
-    assert set(map(os.path.basename, urls)) == set(map(os.path.basename, files))
+    assert set([Path(f).name for f in urls]) == set([Path(f).name for f in files])

tests/integration/test_kerchunk.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 import logging
 import os
 import unittest
+from pathlib import Path

 import earthaccess
 import pytest

@@ -32,14 +33,14 @@ def granules():
 @pytest.mark.parametrize("protocol", ["", "file://"])
 def test_consolidate_metadata_outfile(tmp_path, granules, protocol):
     outfile = f"{protocol}{tmp_path / 'metadata.json'}"
-    assert not os.path.exists(outfile)
+    assert not Path(outfile).exists()
     result = earthaccess.consolidate_metadata(
         granules,
         outfile=outfile,
         access="indirect",
         kerchunk_options={"concat_dims": "Time"},
     )
-    assert os.path.exists(strip_protocol(outfile))
+    assert Path(strip_protocol(outfile)).exists()
     assert result == outfile
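The test updates above lean on two pathlib idioms: Path(f).exists() in place of os.path.exists(f) and Path(f).name in place of os.path.basename(f). A quick sketch with made-up values showing the two forms agree:

    import os
    from pathlib import Path

    # Path(...).name matches os.path.basename(...) for these file URLs/paths
    urls = ["https://example.com/a/file1.nc", "https://example.com/b/file2.nc"]  # made-up URLs
    assert [Path(u).name for u in urls] == [os.path.basename(u) for u in urls]

    # Path(...).exists() matches os.path.exists(...)
    assert Path(__file__).exists() == os.path.exists(__file__)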
