Skip to content

Commit a81506f

Browse files
authored
Merge pull request #947 from CitrineInformatics/feature/pne-416-v4-listing-endpoints
[PNE-416] Use v4 listing endpoints.
2 parents 104bc0c + 9f5cb7f commit a81506f

File tree

6 files changed

+147
-35
lines changed

6 files changed

+147
-35
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.2.15"
1+
__version__ = "3.3.0"

src/citrine/_rest/pageable.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def _fetch_page(self,
2525
per_page: Optional[int] = None,
2626
json_body: Optional[dict] = None,
2727
additional_params: Optional[dict] = None,
28+
*,
29+
version: Optional[str] = None
2830
) -> Tuple[Iterable[dict], str]:
2931
"""
3032
Fetch visible elements. This does not handle pagination.
@@ -58,6 +60,9 @@ def _fetch_page(self,
5860
}
5961
additional_params: dict, optional
6062
A dict that allows extra parameters to be added to the request parameters
63+
version: str, optional
64+
A string denoting which version of the underlying API endpoint will be called. Defaults
65+
to the collection's API version.
6166
6267
Returns
6368
-------
@@ -68,15 +73,16 @@ def _fetch_page(self,
6873
6974
"""
7075
# To avoid setting defaults -> reduce mutation risk, and to make more extensible
71-
path = self._get_path() if path is None else path
72-
fetch_func = self.session.get_resource if fetch_func is None else fetch_func
73-
json_body = {} if json_body is None else json_body
76+
path = path or self._get_path()
77+
fetch_func = fetch_func or self.session.get_resource
78+
json_body = json_body or {}
7479

75-
module_type = getattr(self, '_module_type', None)
76-
params = self._page_params(page, per_page, module_type)
80+
params = self._page_params(page, per_page)
7781
params.update(additional_params or {})
7882

79-
data = fetch_func(path, params=params, version=self._api_version, **json_body)
83+
version = version or self._api_version
84+
85+
data = fetch_func(path, params=params, version=version, **json_body)
8086

8187
try:
8288
next_uri = data.get('next', "")

src/citrine/resources/design_space.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Resources that represent collections of design spaces."""
2-
from typing import Optional, TypeVar, Union
2+
from functools import partial
3+
from typing import Iterable, Optional, TypeVar, Union
34
from uuid import UUID
45

56
from gemd.enumeration.base_enumeration import BaseEnumeration
@@ -127,6 +128,30 @@ def restore(self, uid: Union[UUID, str]) -> DesignSpace:
127128
entity = self.session.put_resource(url, {}, version=self._api_version)
128129
return self.build(entity)
129130

131+
def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None):
132+
filters = {}
133+
if archived is not None:
134+
filters["archived"] = archived
135+
136+
fetcher = partial(self._fetch_page,
137+
fetch_func=partial(self.session.get_resource, version="v4"),
138+
additional_params=filters)
139+
return self._paginator.paginate(page_fetcher=fetcher,
140+
collection_builder=self._build_collection_elements,
141+
per_page=per_page)
142+
143+
def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
144+
"""List the most recent version of all design spaces."""
145+
return self._list_base(per_page=per_page)
146+
147+
def list(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
148+
"""List the most recent version of all non-archived design spaces."""
149+
return self._list_base(per_page=per_page, archived=False)
150+
151+
def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
152+
"""List the most recent version of all archived predictors."""
153+
return self._list_base(per_page=per_page, archived=True)
154+
130155
def create_default(self,
131156
*,
132157
predictor_id: UUID,

src/citrine/resources/predictor.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -367,20 +367,37 @@ def restore(self, uid: Union[UUID, str]):
367367
raise NotImplementedError("The restore() method is no longer supported. You most likely "
368368
"want restore_root(), or possibly restore_version().")
369369

370+
def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None):
371+
filters = {}
372+
if archived is not None:
373+
filters["archived"] = archived
374+
375+
fetcher = partial(self._fetch_page,
376+
additional_params=filters,
377+
version="v4")
378+
return self._paginator.paginate(page_fetcher=fetcher,
379+
collection_builder=self._build_collection_elements,
380+
per_page=per_page)
381+
382+
def list_all(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
383+
"""List the most recent version of all predictors."""
384+
return self._list_base(per_page=per_page)
385+
386+
def list(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
387+
"""List the most recent version of all non-archived predictors."""
388+
return self._list_base(per_page=per_page, archived=False)
389+
390+
def list_archived(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
391+
"""List the most recent version of all archived predictors."""
392+
return self._list_base(per_page=per_page, archived=True)
393+
370394
def list_versions(self,
371395
uid: Union[UUID, str] = None,
372396
*,
373397
per_page: int = 100) -> Iterable[GraphPredictor]:
374398
"""List all non-archived versions of the given Predictor."""
375399
return self._versions_collection.list(uid, per_page=per_page)
376400

377-
def list_archived(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
378-
"""List archived Predictors."""
379-
fetcher = partial(self._fetch_page, additional_params={"filter": "archived eq 'true'"})
380-
return self._paginator.paginate(page_fetcher=fetcher,
381-
collection_builder=self._build_collection_elements,
382-
per_page=per_page)
383-
384401
def list_archived_versions(self,
385402
uid: Union[UUID, str] = None,
386403
*,

tests/resources/test_design_space.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,45 @@ def test_list_design_spaces(valid_formulation_design_space_data, valid_enumerate
263263

264264
# Then
265265
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
266-
params={'per_page': 20, 'page': 1})
266+
params={'per_page': 20, 'page': 1, 'archived': False})
267+
assert 1 == session.num_calls, session.calls
268+
assert expected_call == session.calls[0]
269+
assert len(design_spaces) == 2
270+
271+
272+
def test_list_all_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data):
273+
# Given
274+
session = FakeSession()
275+
collection = DesignSpaceCollection(uuid.uuid4(), session)
276+
session.set_response({
277+
'response': [valid_formulation_design_space_data, valid_enumerated_design_space_data]
278+
})
279+
280+
# When
281+
design_spaces = list(collection.list_all(per_page=25))
282+
283+
# Then
284+
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
285+
params={'per_page': 25, 'page': 1})
286+
assert 1 == session.num_calls, session.calls
287+
assert expected_call == session.calls[0]
288+
assert len(design_spaces) == 2
289+
290+
291+
def test_list_archived_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data):
292+
# Given
293+
session = FakeSession()
294+
collection = DesignSpaceCollection(uuid.uuid4(), session)
295+
session.set_response({
296+
'response': [valid_formulation_design_space_data, valid_enumerated_design_space_data]
297+
})
298+
299+
# When
300+
design_spaces = list(collection.list_archived(per_page=25))
301+
302+
# Then
303+
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
304+
params={'per_page': 25, 'page': 1, 'archived': True})
267305
assert 1 == session.num_calls, session.calls
268306
assert expected_call == session.calls[0]
269307
assert len(design_spaces) == 2

tests/resources/test_predictor.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -278,30 +278,70 @@ def test_train(valid_graph_predictor_data):
278278
assert session.calls == expected_calls
279279

280280

281-
def test_list_predictors(valid_graph_predictor_data, valid_graph_predictor_data_empty):
281+
def test_list(valid_graph_predictor_data, valid_graph_predictor_data_empty):
282282
# Given
283283
session = FakeSession()
284284
collection = PredictorCollection(uuid.uuid4(), session)
285285
session.set_responses(
286286
{
287287
'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty],
288-
'next': ''
288+
'page': 1,
289+
'per_page': 25
289290
},
290291
basic_predictor_report_data,
291292
basic_predictor_report_data
292293
)
293294

294295
# When
295-
predictors = list(collection.list(per_page=20))
296+
predictors = list(collection.list(per_page=25))
296297

297298
# Then
298-
expected_call = FakeCall(method='GET', path='/projects/{}/predictors'.format(collection.project_id),
299-
params={'per_page': 20, 'page': 1})
299+
expected_call = FakeCall(method='GET',
300+
path='/projects/{}/predictors'.format(collection.project_id),
301+
params={'per_page': 25, 'page': 1, 'archived': False})
300302
assert 1 == session.num_calls, session.calls
301303
assert expected_call == session.calls[0]
302304
assert len(predictors) == 2
303305

304306

307+
def test_list_all(valid_graph_predictor_data, valid_graph_predictor_data_empty):
308+
# Given
309+
session = FakeSession()
310+
collection = PredictorCollection(uuid.uuid4(), session)
311+
session.set_responses(
312+
{'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty]},
313+
basic_predictor_report_data,
314+
basic_predictor_report_data
315+
)
316+
317+
# When
318+
predictors = list(collection.list_all(per_page=25))
319+
320+
# Then
321+
expected_call = FakeCall(method='GET',
322+
path='/projects/{}/predictors'.format(collection.project_id),
323+
params={'per_page': 25, 'page': 1})
324+
assert 1 == session.num_calls, session.calls
325+
assert expected_call == session.calls[0]
326+
assert len(predictors) == 2
327+
328+
329+
def test_list_archived(valid_graph_predictor_data):
330+
# Given
331+
session = FakeSession()
332+
session.set_response({'response': [valid_graph_predictor_data]})
333+
pc = PredictorCollection(uuid.uuid4(), session)
334+
335+
# When
336+
list(pc.list_archived())
337+
338+
# Then
339+
assert session.num_calls == 1
340+
assert session.last_call == FakeCall(method='GET',
341+
path=f"/projects/{pc.project_id}/predictors",
342+
params={'per_page': 20, 'page': 1, 'archived': True})
343+
344+
305345
def test_get(valid_graph_predictor_data):
306346
# Given
307347
session = FakeSession()
@@ -445,20 +485,6 @@ def test_returned_predictor(valid_graph_predictor_data):
445485
assert isinstance(result.predictors[-1], AutoMLPredictor)
446486

447487

448-
def test_predictor_list_archived(valid_graph_predictor_data):
449-
# Given
450-
session = FakeSession()
451-
session.set_response({'response': [valid_graph_predictor_data]})
452-
pc = PredictorCollection(uuid.uuid4(), session)
453-
454-
# When
455-
list(pc.list_archived())
456-
457-
# Then
458-
assert session.num_calls == 1
459-
assert session.last_call == FakeCall(method='GET', path=f"/projects/{pc.project_id}/predictors", params={"filter": "archived eq 'true'", 'per_page': 20, 'page': 1})
460-
461-
462488
def test_list_versions(valid_graph_predictor_data):
463489
# Given
464490
session = FakeSession()

0 commit comments

Comments
 (0)