Skip to content

Commit f6799dc

Browse files
authored
Merge pull request #2532 from weaverba137/restore-query-payload
Restore get_query_payload to all methods
2 parents 4362450 + ab8d5b8 commit f6799dc

File tree

4 files changed

+119
-25
lines changed

4 files changed

+119
-25
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ sdss
141141

142142
- The default data release has been changed to DR17. [#2478]
143143

144+
- Optional keyword arguments are now keyword only. [#2477, #2532]
144145

145146

146147
Infrastructure, Utility and Other Changes and Additions

astroquery/sdss/core.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"""
33
Access Sloan Digital Sky Survey database online.
44
"""
5-
import io
65
import warnings
76
import numpy as np
87

@@ -518,9 +517,9 @@ class = 'galaxy' \
518517
timeout=timeout, cache=cache)
519518
return response
520519

521-
def get_spectra_async(self, coordinates=None, radius=2. * u.arcsec,
520+
def get_spectra_async(self, *, coordinates=None, radius=2. * u.arcsec,
522521
matches=None, plate=None, fiberID=None, mjd=None,
523-
timeout=TIMEOUT,
522+
timeout=TIMEOUT, get_query_payload=False,
524523
data_release=conf.default_release, cache=True,
525524
show_progress=True):
526525
"""
@@ -559,6 +558,9 @@ def get_spectra_async(self, coordinates=None, radius=2. * u.arcsec,
559558
timeout : float, optional
560559
Time limit (in seconds) for establishing successful connection with
561560
remote server. Defaults to `SDSSClass.TIMEOUT`.
561+
get_query_payload : bool, optional
562+
If True, this will return the data the query would have sent out,
563+
but does not actually do the query.
562564
data_release : int, optional
563565
The data release of the SDSS to use. With the default server, this
564566
only supports DR8 or later.
@@ -599,12 +601,19 @@ def get_spectra_async(self, coordinates=None, radius=2. * u.arcsec,
599601
if coordinates is None:
600602
matches = self.query_specobj(plate=plate, mjd=mjd, fiberID=fiberID,
601603
fields=['run2d', 'plate', 'mjd', 'fiberID'],
602-
timeout=timeout, data_release=data_release, cache=cache)
604+
timeout=timeout, get_query_payload=get_query_payload,
605+
data_release=data_release, cache=cache)
603606
else:
604-
matches = self.query_crossid(coordinates, radius=radius,
607+
matches = self.query_crossid(coordinates, radius=radius, timeout=timeout,
605608
specobj_fields=['run2d', 'plate', 'mjd', 'fiberID'],
606-
spectro=True,
607-
timeout=timeout, data_release=data_release, cache=cache)
609+
spectro=True, get_query_payload=get_query_payload,
610+
data_release=data_release, cache=cache)
611+
if get_query_payload:
612+
if coordinates is None:
613+
return matches
614+
else:
615+
return matches[0]
616+
608617
if matches is None:
609618
warnings.warn("Query returned no results.", NoResultsWarning)
610619
return
@@ -638,10 +647,10 @@ def get_spectra_async(self, coordinates=None, radius=2. * u.arcsec,
638647
return results
639648

640649
@prepend_docstr_nosections(get_spectra_async.__doc__)
641-
def get_spectra(self, coordinates=None, radius=2. * u.arcsec,
650+
def get_spectra(self, *, coordinates=None, radius=2. * u.arcsec,
642651
matches=None, plate=None, fiberID=None, mjd=None,
643-
timeout=TIMEOUT, cache=True,
644-
data_release=conf.default_release,
652+
timeout=TIMEOUT, get_query_payload=False,
653+
data_release=conf.default_release, cache=True,
645654
show_progress=True):
646655
"""
647656
Returns
@@ -654,9 +663,14 @@ def get_spectra(self, coordinates=None, radius=2. * u.arcsec,
654663
radius=radius, matches=matches,
655664
plate=plate, fiberID=fiberID,
656665
mjd=mjd, timeout=timeout,
666+
get_query_payload=get_query_payload,
657667
data_release=data_release,
668+
cache=cache,
658669
show_progress=show_progress)
659670

671+
if get_query_payload:
672+
return readable_objs
673+
660674
if readable_objs is not None:
661675
if isinstance(readable_objs, dict):
662676
return readable_objs
@@ -666,7 +680,7 @@ def get_spectra(self, coordinates=None, radius=2. * u.arcsec,
666680
def get_images_async(self, coordinates=None, radius=2. * u.arcsec,
667681
matches=None, run=None, rerun=301, camcol=None,
668682
field=None, band='g', timeout=TIMEOUT,
669-
cache=True,
683+
cache=True, get_query_payload=False,
670684
data_release=conf.default_release,
671685
show_progress=True):
672686
"""
@@ -714,6 +728,9 @@ def get_images_async(self, coordinates=None, radius=2. * u.arcsec,
714728
timeout : float, optional
715729
Time limit (in seconds) for establishing successful connection with
716730
remote server. Defaults to `SDSSClass.TIMEOUT`.
731+
get_query_payload : bool, optional
732+
If True, this will return the data the query would have sent out,
733+
but does not actually do the query.
717734
cache : bool, optional
718735
Cache the images using astropy's caching system
719736
data_release : int, optional
@@ -753,12 +770,19 @@ def get_images_async(self, coordinates=None, radius=2. * u.arcsec,
753770
matches = self.query_photoobj(run=run, rerun=rerun,
754771
camcol=camcol, field=field,
755772
fields=['run', 'rerun', 'camcol', 'field'],
756-
timeout=timeout,
773+
timeout=timeout, get_query_payload=get_query_payload,
757774
data_release=data_release, cache=cache)
758775
else:
759-
matches = self.query_crossid(coordinates, radius=radius,
776+
matches = self.query_crossid(coordinates, radius=radius, timeout=timeout,
760777
fields=['run', 'rerun', 'camcol', 'field'],
761-
timeout=timeout, data_release=data_release, cache=cache)
778+
get_query_payload=get_query_payload,
779+
data_release=data_release, cache=cache)
780+
if get_query_payload:
781+
if coordinates is None:
782+
return matches
783+
else:
784+
return matches[0]
785+
762786
if matches is None:
763787
warnings.warn("Query returned no results.", NoResultsWarning)
764788
return
@@ -786,7 +810,7 @@ def get_images_async(self, coordinates=None, radius=2. * u.arcsec,
786810
return results
787811

788812
@prepend_docstr_nosections(get_images_async.__doc__)
789-
def get_images(self, coordinates=None, radius=2. * u.arcsec,
813+
def get_images(self, *, coordinates=None, radius=2. * u.arcsec,
790814
matches=None, run=None, rerun=301, camcol=None, field=None,
791815
band='g', timeout=TIMEOUT, cache=True,
792816
get_query_payload=False, data_release=conf.default_release,
@@ -798,10 +822,22 @@ def get_images(self, coordinates=None, radius=2. * u.arcsec,
798822
799823
"""
800824

801-
readable_objs = self.get_images_async(
802-
coordinates=coordinates, radius=radius, matches=matches, run=run,
803-
rerun=rerun, data_release=data_release, camcol=camcol, field=field,
804-
band=band, timeout=timeout, show_progress=show_progress)
825+
readable_objs = self.get_images_async(coordinates=coordinates,
826+
radius=radius,
827+
matches=matches,
828+
run=run,
829+
rerun=rerun,
830+
camcol=camcol,
831+
field=field,
832+
band=band,
833+
timeout=timeout,
834+
cache=cache,
835+
get_query_payload=get_query_payload,
836+
data_release=data_release,
837+
show_progress=show_progress)
838+
839+
if get_query_payload:
840+
return readable_objs
805841

806842
if readable_objs is not None:
807843
if isinstance(readable_objs, dict):
@@ -906,7 +942,7 @@ def _parse_result(self, response, verbose=False):
906942
else:
907943
return arr
908944

909-
def _args_to_payload(self, coordinates=None,
945+
def _args_to_payload(self, *, coordinates=None,
910946
fields=None, spectro=False, region=False,
911947
plate=None, mjd=None, fiberID=None, run=None,
912948
rerun=301, camcol=None, field=None,

astroquery/sdss/tests/test_sdss.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_sdss_spectrum_mjd(patch_request, patch_get_readable_fileobj, dr):
177177
@pytest.mark.parametrize("dr", dr_list)
178178
def test_sdss_spectrum_coords(patch_request, patch_get_readable_fileobj, dr,
179179
coords=coords):
180-
sp = sdss.SDSS.get_spectra(coords, data_release=dr)
180+
sp = sdss.SDSS.get_spectra(coordinates=coords, data_release=dr)
181181
image_tester(sp, 'spectra')
182182

183183

@@ -220,7 +220,7 @@ def test_sdss_image_run(patch_request, patch_get_readable_fileobj, dr):
220220
@pytest.mark.parametrize("dr", dr_list)
221221
def test_sdss_image_coord(patch_request, patch_get_readable_fileobj, dr,
222222
coord=coords):
223-
img = sdss.SDSS.get_images(coords, data_release=dr)
223+
img = sdss.SDSS.get_images(coordinates=coords, data_release=dr)
224224
image_tester(img, 'images')
225225

226226

@@ -454,6 +454,63 @@ def test_photoobj_run_camcol_field_payload(patch_request, dr):
454454
assert query_payload['format'] == 'csv'
455455

456456

457+
@pytest.mark.parametrize("dr", dr_list)
458+
def test_get_spectra_specobj_payload(patch_request, dr):
459+
expect = ("SELECT DISTINCT "
460+
"s.run2d, s.plate, s.mjd, s.fiberID "
461+
"FROM PhotoObjAll AS p "
462+
"JOIN SpecObjAll AS s ON p.objID = s.bestObjID "
463+
"WHERE "
464+
"(s.plate=751 AND s.mjd=52251)")
465+
query_payload = sdss.SDSS.get_spectra_async(plate=751, mjd=52251,
466+
get_query_payload=True,
467+
data_release=dr)
468+
assert query_payload['cmd'] == expect
469+
assert query_payload['format'] == 'csv'
470+
471+
472+
@pytest.mark.parametrize("dr", dr_list)
473+
def test_get_spectra_coordinates_payload(patch_request, dr):
474+
expect = ("SELECT\r\n"
475+
"s.run2d, s.plate, s.mjd, s.fiberID, s.SpecObjID AS obj_id, dbo.fPhotoTypeN(p.type) AS type "
476+
"FROM #upload u JOIN #x x ON x.up_id = u.up_id JOIN PhotoObjAll AS p ON p.objID = x.objID "
477+
"JOIN SpecObjAll AS s ON p.objID = s.bestObjID "
478+
"ORDER BY x.up_id")
479+
query_payload = sdss.SDSS.get_spectra_async(coordinates=coords_column,
480+
get_query_payload=True,
481+
data_release=dr)
482+
assert query_payload['uquery'] == expect
483+
assert query_payload['format'] == 'csv'
484+
assert query_payload['photoScope'] == 'nearPrim'
485+
486+
487+
@pytest.mark.parametrize("dr", dr_list)
488+
def test_get_images_photoobj_payload(patch_request, dr):
489+
expect = ("SELECT DISTINCT "
490+
"p.run, p.rerun, p.camcol, p.field "
491+
"FROM PhotoObjAll AS p WHERE "
492+
"(p.run=5714 AND p.camcol=6 AND p.rerun=301)")
493+
query_payload = sdss.SDSS.get_images_async(run=5714, camcol=6,
494+
get_query_payload=True,
495+
data_release=dr)
496+
assert query_payload['cmd'] == expect
497+
assert query_payload['format'] == 'csv'
498+
499+
500+
@pytest.mark.parametrize("dr", dr_list)
501+
def test_get_images_coordinates_payload(patch_request, dr):
502+
expect = ("SELECT\r\n"
503+
"p.run, p.rerun, p.camcol, p.field, dbo.fPhotoTypeN(p.type) AS type "
504+
"FROM #upload u JOIN #x x ON x.up_id = u.up_id JOIN PhotoObjAll AS p ON p.objID = x.objID "
505+
"ORDER BY x.up_id")
506+
query_payload = sdss.SDSS.get_images_async(coordinates=coords_column,
507+
get_query_payload=True,
508+
data_release=dr)
509+
assert query_payload['uquery'] == expect
510+
assert query_payload['format'] == 'csv'
511+
assert query_payload['photoScope'] == 'nearPrim'
512+
513+
457514
@pytest.mark.parametrize("dr", dr_list)
458515
def test_spectra_plate_mjd_payload(patch_request, dr):
459516
expect = ("SELECT DISTINCT "

astroquery/sdss/tests/test_sdss_remote.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_sdss_spectrum_mjd(self):
6363
sp = sdss.SDSS.get_spectra(plate=2345, fiberID=572)
6464

6565
def test_sdss_spectrum_coords(self):
66-
sp = sdss.SDSS.get_spectra(self.coords)
66+
sp = sdss.SDSS.get_spectra(coordinates=self.coords)
6767

6868
def test_sdss_sql(self):
6969
query = """
@@ -91,7 +91,7 @@ def test_sdss_image_run(self):
9191
img = sdss.SDSS.get_images(run=1904, camcol=3, field=164)
9292

9393
def test_sdss_image_coord(self):
94-
img = sdss.SDSS.get_images(self.coords)
94+
img = sdss.SDSS.get_images(coordinates=self.coords)
9595

9696
def test_sdss_specobj(self):
9797
colnames = ['ra', 'dec', 'objid', 'run', 'rerun', 'camcol', 'field',
@@ -161,7 +161,7 @@ def test_query_timeout(self):
161161
"self._request, fix it before merging #586"))
162162
def test_spectra_timeout(self):
163163
with pytest.raises(TimeoutError):
164-
sdss.SDSS.get_spectra(self.coords, timeout=self.mintimeout)
164+
sdss.SDSS.get_spectra(coordinates=self.coords, timeout=self.mintimeout)
165165

166166
def test_query_non_default_field(self):
167167
# A regression test for #469

0 commit comments

Comments
 (0)