|
25 | 25 | from opensearch_py_ml.ml_commons.model_connector import Connector |
26 | 26 | from opensearch_py_ml.ml_commons.model_execute import ModelExecute |
27 | 27 | from opensearch_py_ml.ml_commons.model_uploader import ModelUploader |
| 28 | +from opensearch_py_ml.ml_commons.validators.profile import validate_profile_input |
28 | 29 |
|
29 | 30 |
|
30 | 31 | class MLCommonClient: |
@@ -606,3 +607,111 @@ def delete_task(self, task_id: str) -> object: |
606 | 607 | method="DELETE", |
607 | 608 | url=API_URL, |
608 | 609 | ) |
| 610 | + |
| 611 | + def _get_profile(self, payload: Optional[dict] = None): |
| 612 | + """ |
| 613 | + Get the profile using the given payload. |
| 614 | +
|
| 615 | + :param payload: The payload to be used for getting the profile. Defaults to None. |
| 616 | + :type payload: Optional[dict] |
| 617 | + :return: The response from the server after performing the request. |
| 618 | + :rtype: Any |
| 619 | + """ |
| 620 | + validate_profile_input(None, payload) |
| 621 | + return self._client.transport.perform_request( |
| 622 | + method="GET", url=f"{ML_BASE_URI}/profile", body=payload |
| 623 | + ) |
| 624 | + |
| 625 | + def _get_models_profile( |
| 626 | + self, model_id: Optional[str] = "", payload: Optional[dict] = None |
| 627 | + ): |
| 628 | + """ |
| 629 | + Get the profile of a model. |
| 630 | +
|
| 631 | + Args: |
| 632 | + model_id (str, optional): The ID of the model. Defaults to "". |
| 633 | + payload (dict, optional): Additional payload for the request. Defaults to None. |
| 634 | +
|
| 635 | + Returns: |
| 636 | + dict: The response from the API. |
| 637 | + """ |
| 638 | + validate_profile_input(model_id, payload) |
| 639 | + |
| 640 | + url = f"{ML_BASE_URI}/profile/models/{model_id if model_id else ''}" |
| 641 | + return self._client.transport.perform_request( |
| 642 | + method="GET", url=url, body=payload |
| 643 | + ) |
| 644 | + |
| 645 | + def _get_tasks_profile( |
| 646 | + self, task_id: Optional[str] = "", payload: Optional[dict] = None |
| 647 | + ): |
| 648 | + """ |
| 649 | + Retrieves the profile of a task from the API. |
| 650 | +
|
| 651 | + Parameters: |
| 652 | + task_id (str, optional): The ID of the task to retrieve the profile for. Defaults to an empty string. |
| 653 | + payload (dict, optional): Additional payload for the request. Defaults to None. |
| 654 | +
|
| 655 | + Returns: |
| 656 | + dict: The profile of the task. |
| 657 | +
|
| 658 | + Raises: |
| 659 | + ValueError: If the input validation fails. |
| 660 | +
|
| 661 | + """ |
| 662 | + validate_profile_input(task_id, payload) |
| 663 | + |
| 664 | + url = f"{ML_BASE_URI}/profile/tasks/{task_id if task_id else ''}" |
| 665 | + return self._client.transport.perform_request( |
| 666 | + method="GET", url=url, body=payload |
| 667 | + ) |
| 668 | + |
| 669 | + def get_profile( |
| 670 | + self, |
| 671 | + profile_type: str = "all", |
| 672 | + ids: Optional[Union[str, List[str]]] = None, |
| 673 | + request_body: Optional[dict] = None, |
| 674 | + ) -> dict: |
| 675 | + """ |
| 676 | + Get profile information based on the profile type. |
| 677 | +
|
| 678 | + Args: |
| 679 | + profile_type: The type of profile to retrieve. Valid values are 'all', 'model', or 'task'. Default is 'all'. |
| 680 | + 'all': Retrieves all profiles available. |
| 681 | + 'model': Retrieves the profile(s) of the specified model(s). The model(s) to retrieve are specified by the 'ids' parameter. |
| 682 | + 'task': Retrieves the profile(s) of the specified task(s). The task(s) to retrieve are specified by the 'ids' parameter. |
| 683 | + ids: Either a single profile ID as a string, or a list of profile IDs to retrieve. Default is None. |
| 684 | + request_body: The request body containing additional information. Default is None. |
| 685 | +
|
| 686 | + Returns: |
| 687 | + The profile information. |
| 688 | +
|
| 689 | + Raises: |
| 690 | + ValueError: If the profile_type is not 'all', 'model', or 'task'. |
| 691 | +
|
| 692 | + Example: |
| 693 | + get_profile() |
| 694 | +
|
| 695 | + get_profile(profile_type='model', ids='model1') |
| 696 | +
|
| 697 | + get_profile(profile_type='model', ids=['model1', 'model2']) |
| 698 | +
|
| 699 | + get_profile(profile_type='task', ids='task1', request_body={"node_ids": ["KzONM8c8T4Od-NoUANQNGg"],"return_all_tasks": true,"return_all_models": true}) |
| 700 | +
|
| 701 | + get_profile(profile_type='task', ids=['task1', 'task2'], request_body={'additional': 'info'}) |
| 702 | + """ |
| 703 | + |
| 704 | + if profile_type == "all": |
| 705 | + return self._get_profile(request_body) |
| 706 | + elif profile_type == "model": |
| 707 | + if ids and isinstance(ids, list): |
| 708 | + ids = ",".join(ids) |
| 709 | + return self._get_models_profile(ids, request_body) |
| 710 | + elif profile_type == "task": |
| 711 | + if ids and isinstance(ids, list): |
| 712 | + ids = ",".join(ids) |
| 713 | + return self._get_tasks_profile(ids, request_body) |
| 714 | + else: |
| 715 | + raise ValueError( |
| 716 | + "Invalid profile type. Profile type must be 'all', 'model' or 'task'." |
| 717 | + ) |
0 commit comments