Skip to content

Commit 27d0277

Browse files
committed
add automatic host guess in download_data
1 parent fe1995c commit 27d0277

File tree

4 files changed

+57
-2
lines changed

4 files changed

+57
-2
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ heasarc
7575
- Add ``query_by_column`` to allow querying of different catalog columns. [#3403]
7676
- Add support for uploading tables when using TAP directly through ``query_tap``. [#3403]
7777
- Improve how maxrec works. If it is bigger than the default server limit, add a TOP statement. [#3403]
78+
- Add automatic guessing for the data host in ``download_data``. [#3403]
7879

7980
alma
8081
^^^^

astroquery/heasarc/core.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,36 @@ def enable_cloud(self, provider='aws', profile=None):
750750

751751
self.s3_client = self.s3_resource.meta.client
752752

753+
def _guess_host(self, host):
754+
"""Guess the host to use for downloading data
755+
756+
Parameters
757+
----------
758+
host : str
759+
The host provided by the user
760+
761+
Returns
762+
-------
763+
host : str
764+
The guessed host
765+
766+
"""
767+
if host in ['heasarc', 'sciserver', 'aws']:
768+
return host
769+
elif host is not None:
770+
raise ValueError(
771+
'host has to be one of heasarc, sciserver, aws or None')
772+
773+
# host is None, so we guess
774+
if os.environ['HOME'] == '/home/idies' and os.path.exists('/FTP/'):
775+
# we are on idies, so we can use sciserver
776+
return 'sciserver'
777+
778+
for var in ['AWS_REGION', 'AWS_DEFAULT_REGION', 'AWS_ROLE_ARN']:
779+
if var in os.environ:
780+
return 'aws'
781+
return 'heasarc'
782+
753783
def download_data(self, links, host='heasarc', location='.'):
754784
"""Download data products in links with a choice of getting the
755785
data from either the heasarc server, sciserver, or the cloud in AWS.
@@ -781,8 +811,8 @@ def download_data(self, links, host='heasarc', location='.'):
781811
if isinstance(links, Row):
782812
links = links.table[[links.index]]
783813

784-
if host not in ['heasarc', 'sciserver', 'aws']:
785-
raise ValueError('host has to be one of heasarc, sciserver, aws')
814+
# guess the host if not provided
815+
host = self._guess_host(host)
786816

787817
host_column = 'access_url' if host == 'heasarc' else host
788818
if host_column not in links.colnames:

astroquery/heasarc/tests/test_heasarc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,29 @@ def test_locate_data_row():
506506
Heasarc.locate_data(table[0:2], catalog_name="xray")
507507

508508

509+
def test__guess_host_default():
510+
# Use a new HeasarcClass object
511+
assert Heasarc._guess_host(host=None) == 'heasarc'
512+
513+
514+
@pytest.mark.parametrize("host", ["heasarc", "sciserver", "aws"])
515+
def test__guess_host_know(host):
516+
# Use a new HeasarcClass object
517+
assert Heasarc._guess_host(host=host) == host
518+
519+
520+
def test__guess_host_sciserver(monkeypatch):
521+
monkeypatch.setenv("HOME", "/home/idies")
522+
monkeypatch.setattr("os.path.exists", lambda path: path.startswith('/FTP'))
523+
assert Heasarc._guess_host(host=None) == 'sciserver'
524+
525+
526+
@pytest.mark.parametrize("var", ["AWS_REGION", "AWS_REGION_DEFAULT", "AWS_ROLE_ARN"])
527+
def test__guess_host_aws(monkeypatch, var):
528+
monkeypatch.setenv("AWS_REGION", var)
529+
assert Heasarc._guess_host(host=None) == 'aws'
530+
531+
509532
def test_download_data__empty():
510533
with pytest.raises(ValueError, match="Input links table is empty"):
511534
Heasarc.download_data(Table())

docs/heasarc/heasarc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ You can specify where the data are to be downloaded using the ``location`` param
247247

248248
To download the data, you can pass ``links`` table (or row) to `~astroquery.heasarc.HeasarcClass.download_data`,
249249
specifying from where you want the data to be fetched by specifying the ``host`` parameter. By default,
250+
the function will try to guess the best host based on your environment. If it cannot guess, then
250251
the data is fetched from the main HEASARC servers.
251252
The recommendation is to use different hosts depending on where your code is running:
252253
* ``host='sciserver'``: Use this option if you running you analysis on Sciserver. Because

0 commit comments

Comments
 (0)