Skip to content

Commit 9b2f686

Browse files
authored
#285: Add support for Model profile (#358)
* init Signed-off-by: kalyanr <[email protected]> * update changelog Signed-off-by: kalyanr <[email protected]> * update Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * lint fix Signed-off-by: kalyanr <[email protected]> * reuse validate input Signed-off-by: kalyanr <[email protected]> * update comment Signed-off-by: kalyanr <[email protected]> * change Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * update changelog Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * remove separate model profile module Signed-off-by: kalyanr <[email protected]> * fix tests Signed-off-by: kalyanr <[email protected]> * fix lint Signed-off-by: kalyanr <[email protected]> * fix lint Signed-off-by: kalyanr <[email protected]> * fix Signed-off-by: kalyanr <[email protected]> * Update ml_commons_client.py Signed-off-by: Kalyan <[email protected]> --------- Signed-off-by: kalyanr <[email protected]> Signed-off-by: Kalyan <[email protected]>
1 parent 31ee5dc commit 9b2f686

File tree

4 files changed

+188
-0
lines changed

4 files changed

+188
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
1111
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
1212
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
1313
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
14+
- Add support for model profiles by @rawwar in ([#358](https://github.com/opensearch-project/opensearch-py-ml/pull/358))
1415

1516
### Changed
1617
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))

opensearch_py_ml/ml_commons/ml_commons_client.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from opensearch_py_ml.ml_commons.model_connector import Connector
2626
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
2727
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
28+
from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input
2829

2930

3031
class MLCommonClient:
@@ -606,3 +607,111 @@ def delete_task(self, task_id: str) -> object:
606607
method="DELETE",
607608
url=API_URL,
608609
)
610+
611+
def _get_profile(self, payload: Optional[dict] = None):
612+
"""
613+
Get the profile using the given payload.
614+
615+
:param payload: The payload to be used for getting the profile. Defaults to None.
616+
:type payload: Optional[dict]
617+
:return: The response from the server after performing the request.
618+
:rtype: Any
619+
"""
620+
validate_profile_input(None, payload)
621+
return self._client.transport.perform_request(
622+
method="GET", url=f"{ML_BASE_URI}/profile", body=payload
623+
)
624+
625+
def _get_models_profile(
626+
self, model_id: Optional[str] = "", payload: Optional[dict] = None
627+
):
628+
"""
629+
Get the profile of a model.
630+
631+
Args:
632+
model_id (str, optional): The ID of the model. Defaults to "".
633+
payload (dict, optional): Additional payload for the request. Defaults to None.
634+
635+
Returns:
636+
dict: The response from the API.
637+
"""
638+
validate_profile_input(model_id, payload)
639+
640+
url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}"
641+
return self._client.transport.perform_request(
642+
method="GET", url=url, body=payload
643+
)
644+
645+
def _get_tasks_profile(
646+
self, task_id: Optional[str] = "", payload: Optional[dict] = None
647+
):
648+
"""
649+
Retrieves the profile of a task from the API.
650+
651+
Parameters:
652+
task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string.
653+
payload (dict, optional): Additional payload for the request. Defaults to None.
654+
655+
Returns:
656+
dict: The profile of the task.
657+
658+
Raises:
659+
ValueError: If the input validation fails.
660+
661+
"""
662+
validate_profile_input(task_id, payload)
663+
664+
url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}"
665+
return self._client.transport.perform_request(
666+
method="GET", url=url, body=payload
667+
)
668+
669+
def get_profile(
670+
self,
671+
profile_type: str = "all",
672+
ids: Optional[Union[str, List[str]]] = None,
673+
request_body: Optional[dict] = None,
674+
) -> dict:
675+
"""
676+
Get profile information based on the profile type.
677+
678+
Args:
679+
profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'.
680+
'all': Retrieves all profiles available.
681+
'model': Retrieves the profile(s) of the specified model(s). The model(s) to retrieve are specified by the 'ids' parameter.
682+
'task': Retrieves the profile(s) of the specified task(s). The task(s) to retrieve are specified by the 'ids' parameter.
683+
ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None.
684+
request_body: The request body containing additional information. Default is None.
685+
686+
Returns:
687+
The profile information.
688+
689+
Raises:
690+
ValueError: If the profile_type is not 'all', 'model', or 'task'.
691+
692+
Example:
693+
get_profile()
694+
695+
get_profile(profile_type='model', ids='model1')
696+
697+
get_profile(profile_type='model', ids=['model1', 'model2'])
698+
699+
get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true})
700+
701+
get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'})
702+
"""
703+
704+
if profile_type == "all":
705+
return self._get_profile(request_body)
706+
elif profile_type == "model":
707+
if ids and isinstance(ids, list):
708+
ids = ",".join(ids)
709+
return self._get_models_profile(ids, request_body)
710+
elif profile_type == "task":
711+
if ids and isinstance(ids, list):
712+
ids = ",".join(ids)
713+
return self._get_tasks_profile(ids, request_body)
714+
else:
715+
raise ValueError(
716+
"Invalid profile type. Profile type must be 'all', 'model' or 'task'."
717+
)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# The OpenSearch Contributors require contributions made to
3+
# this file be licensed under the Apache-2.0 license or a
4+
# compatible open source license.
5+
# Any modifications Copyright OpenSearch Contributors. See
6+
# GitHub history for details.
7+
8+
"""Module for validating Profile API parameters """
9+
10+
11+
def validate_profile_input(path_parameter, payload):
12+
if path_parameter is not None and not isinstance(path_parameter, str):
13+
raise ValueError("path_parameter needs to be a string or None")
14+
15+
if payload is not None and not isinstance(payload, dict):
16+
raise ValueError("payload needs to be a dictionary or None")

tests/ml_commons/test_ml_commons_client.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,3 +573,65 @@ def test_search():
573573
except: # noqa: E722
574574
raised = True
575575
assert raised == False, "Raised Exception in searching model"
576+
577+
578+
# Model Profile Tests. These tests will need some model train/predict run data. Hence, need
579+
# to be run at the end after the training/prediction tests are done.
580+
581+
582+
def test_get_profile():
583+
res = ml_client.get_profile()
584+
assert isinstance(res, dict)
585+
assert "nodes" in res
586+
test_model_id = None
587+
test_task_id = None
588+
for node_id, val in res["nodes"].items():
589+
if test_model_id is None and "models" in val:
590+
for model_id, model_val in val["models"].items():
591+
test_model_id = {"node_id": node_id, "model_id": model_id}
592+
break
593+
if test_task_id is None and "tasks" in val:
594+
for task_id, task_val in val["tasks"].items():
595+
test_task_id = {"node_id": node_id, "task_id": task_id}
596+
break
597+
598+
res = ml_client.get_profile(profile_type="model")
599+
assert isinstance(res, dict)
600+
assert "nodes" in res
601+
for node_id, node_val in res["nodes"].items():
602+
assert "models" in node_val
603+
604+
res = ml_client.get_profile(profile_type="model", ids=[test_model_id["model_id"]])
605+
assert isinstance(res, dict)
606+
assert "nodes" in res
607+
assert test_model_id["model_id"] in res["nodes"][test_model_id["node_id"]]["models"]
608+
609+
res = ml_client.get_profile(profile_type="model", ids=["randomid1", "random_id2"])
610+
assert isinstance(res, dict)
611+
assert len(res) == 0
612+
613+
res = ml_client.get_profile(profile_type="task")
614+
assert isinstance(res, dict)
615+
if len(res) > 0:
616+
assert "nodes" in res
617+
for node_id, node_val in res["nodes"].items():
618+
assert "tasks" in node_val
619+
620+
res = ml_client.get_profile(profile_type="task", ids=["random1", "random2"])
621+
assert isinstance(res, dict)
622+
assert len(res) == 0
623+
624+
with pytest.raises(ValueError):
625+
ml_client.get_profile(profile_type="test")
626+
627+
with pytest.raises(ValueError):
628+
ml_client.get_profile(profile_type="model", ids=1)
629+
630+
with pytest.raises(ValueError):
631+
ml_client.get_profile(profile_type="model", request_body=10)
632+
633+
with pytest.raises(ValueError):
634+
ml_client.get_profile(profile_type="task", ids=1)
635+
636+
with pytest.raises(ValueError):
637+
ml_client.get_profile(profile_type="task", request_body=10)

0 commit comments

Comments
 (0)