6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog

## Version 11.4.0 - 2024-11-21

- Added `get_prediction`
- Support `run_async` argument for `create_prediction`
- Support `statistics_last_n_days` for `get_training`

## Version 11.3.0 - 2024-03-27

- Support profile argument when creating a Client
2 changes: 1 addition & 1 deletion las/__version__.py
@@ -7,4 +7,4 @@
__maintainer_email__ ='[email protected]'
__title__ = 'lucidtech-las'
__url__ = 'https://github.com/LucidtechAI/las-sdk-python'
__version__ = '11.3.0'
__version__ = '11.4.0'
30 changes: 27 additions & 3 deletions las/client.py
Expand Up @@ -646,7 +646,7 @@ def delete_dataset(self, dataset_id: str, delete_documents: bool = False) -> Dic

:param dataset_id: Id of the dataset
:type dataset_id: str
:param delete_documents: Set to true to delete documents in dataset before deleting dataset
:param delete_documents: Set to True to delete documents in dataset before deleting dataset
:type delete_documents: bool
:return: Dataset response from REST API
:rtype: dict
@@ -1342,20 +1342,23 @@ def create_training(
body.update(**optional_args)
return self._make_request(requests.post, f'/models/{model_id}/trainings', body=body)

def get_training(self, model_id: str, training_id: str) -> Dict:
def get_training(self, model_id: str, training_id: str, statistics_last_n_days: Optional[int] = None) -> Dict:
"""Get training, calls the GET /models/{modelId}/trainings/{trainingId} endpoint.

:param model_id: ID of the model
:type model_id: str
:param training_id: ID of the training
:type training_id: str
:param statistics_last_n_days: Integer between 1 and 30
:type statistics_last_n_days: int, optional
:return: Training response from REST API
:rtype: dict

:raises: :py:class:`~las.InvalidCredentialsException`, :py:class:`~las.TooManyRequestsException`,\
:py:class:`~las.LimitExceededException`, :py:class:`requests.exception.RequestException`
"""
return self._make_request(requests.get, f'/models/{model_id}/trainings/{training_id}')
params = {'statisticsLastNDays': statistics_last_n_days}
return self._make_request(requests.get, f'/models/{model_id}/trainings/{training_id}', params=params)

def list_trainings(self, model_id, *, max_results: Optional[int] = None, next_token: Optional[str] = None) -> Dict:
"""List trainings available, calls the GET /models/{modelId}/trainings endpoint.
@@ -1528,6 +1531,7 @@ def create_prediction(
training_id: Optional[str] = None,
preprocess_config: Optional[dict] = None,
postprocess_config: Optional[dict] = None,
run_async: Optional[bool] = None,
) -> Dict:
"""Create a prediction on a document using specified model, calls the POST /predictions endpoint.

@@ -1568,6 +1572,8 @@
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3}}
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3, 'collapse': False}}
:type postprocess_config: dict, optional
:param run_async: If True, run the prediction asynchronously; if False or omitted, run it synchronously.
:type run_async: bool, optional
:return: Prediction response from REST API
:rtype: dict

@@ -1580,6 +1586,7 @@
'trainingId': training_id,
'preprocessConfig': preprocess_config,
'postprocessConfig': postprocess_config,
'async': run_async,
}
return self._make_request(requests.post, '/predictions', body=dictstrip(body))
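A minimal usage sketch of the new `run_async` flag, assuming a configured `Client` and existing document and model ids; the call style mirrors the parametrized test further down:

```python
from las.client import Client

client = Client()

# run_async=True asks the API to run the prediction asynchronously.
# Because the body goes through dictstrip, the 'async' field is only sent
# when run_async is not None, so existing callers are unaffected.
response = client.create_prediction(
    '<document id>',
    '<model id>',
    run_async=True,
)
print(response['predictionId'])
```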

@@ -1623,6 +1630,23 @@ def list_predictions(
}
return self._make_request(requests.get, '/predictions', params=dictstrip(params))

def get_prediction(self, prediction_id: str) -> Dict:
"""Get prediction, calls the GET /predictions/{predictionId} endpoint.

>>> from las.client import Client
>>> client = Client()
>>> client.get_prediction(prediction_id='<prediction id>')

:param prediction_id: Id of the prediction
:type prediction_id: str
:return: Prediction response from REST API
:rtype: dict

:raises: :py:class:`~las.InvalidCredentialsException`, :py:class:`~las.TooManyRequestsException`,\
:py:class:`~las.LimitExceededException`, :py:class:`requests.exception.RequestException`
"""
return self._make_request(requests.get, f'/predictions/{prediction_id}')
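A short sketch of fetching a prediction back, for example one created with `run_async=True`; the id is a placeholder:

```python
from las.client import Client

client = Client()

prediction = client.get_prediction('<prediction id>')
# Per tests/test_predictions.py below, the response is expected to carry
# 'predictionId', 'inferenceTime' and 'predictions'.
print(prediction['predictions'])
```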

def get_plan(self, plan_id: str) -> Dict:
"""Get information about a specific plan, calls the GET /plans/{plan_id} endpoint.

4 changes: 4 additions & 0 deletions tests/service.py
@@ -34,6 +34,10 @@ def create_payment_method_id():
return f'las:payment-method:{uuid4().hex}'


def create_prediction_id():
return f'las:prediction:{uuid4().hex}'


def create_deployment_environment_id():
return f'las:deployment-environment:{uuid4().hex}'

13 changes: 12 additions & 1 deletion tests/test_predictions.py
@@ -20,14 +20,16 @@
{'strategy': 'BEST_N_PAGES', 'parameters': {'n': 3, 'collapse': False}},
None,
])
def test_create_prediction(client: Client, preprocess_config, postprocess_config):
@pytest.mark.parametrize('run_async', [True, False, None])
def test_create_prediction(client: Client, preprocess_config, postprocess_config, run_async):
document_id = service.create_document_id()
model_id = service.create_model_id()
response = client.create_prediction(
document_id,
model_id,
preprocess_config=dictstrip(preprocess_config) if preprocess_config else None,
postprocess_config=postprocess_config,
run_async=run_async,
)
assert 'predictionId' in response, 'Missing predictionId in response'

@@ -39,3 +41,12 @@ def test_list_predictions(client: Client, sort_by, order, model_id):
response = client.list_predictions(sort_by=sort_by, order=order, model_id=model_id)
logging.info(response)
assert 'predictions' in response, 'Missing predictions in response'


@pytest.mark.parametrize('prediction_id', [service.create_prediction_id(), None])
def test_get_prediction(client: Client, prediction_id):
response = client.get_prediction(prediction_id)
logging.info(response)
assert 'predictionId' in response, 'Missing predictionId in response'
assert 'inferenceTime' in response, 'Missing inferenceTime in response'
assert 'predictions' in response, 'Missing predictions in response'