diff --git a/CHANGELOG.md b/CHANGELOG.md index 4eb78bf78..535b4c5c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332)) - Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345)) + ### Changed - Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211)) - Update demo_tracing_model_torchscript_onnx.ipynb to use make_model_config_json by @thanawan-atc in ([#220](https://github.com/opensearch-project/opensearch-py-ml/pull/220)) @@ -44,6 +45,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fix pandas dependency issue in nox session by installing pandas package to python directly by @thanawan-atc in ([#266](https://github.com/opensearch-project/opensearch-py-ml/pull/266)) - Fix conditional job execution issue in model upload workflow by @thanawan-atc in ([#294](https://github.com/opensearch-project/opensearch-py-ml/pull/294)) - fix bug in `MLCommonClient_client.upload_model` by @rawwar in ([#336](https://github.com/opensearch-project/opensearch-py-ml/pull/336)) +- Fix to ensure `model_id` is only required once when saving a model in `sentencetransformermodel.py` in ([#323](https://github.com/opensearch-project/opensearch-py-ml/pull/361)) ## [1.1.0] diff --git a/opensearch_py_ml/ml_models/sentencetransformermodel.py b/opensearch_py_ml/ml_models/sentencetransformermodel.py index db5d5a226..ebee6899f 100644 --- a/opensearch_py_ml/ml_models/sentencetransformermodel.py +++ b/opensearch_py_ml/ml_models/sentencetransformermodel.py @@ -92,7 +92,7 @@ def __init__( str("The default folder path already exists at : " + self.folder_path) ) - self.model_id = model_id + self.model_id = model_id if model_id is not None else self.DEFAULT_MODEL_ID self.torch_script_zip_file_path = None self.onnx_zip_file_path = None @@ -806,7 +806,7 @@ def save_as_pt( :rtype: string """ - model = SentenceTransformer(model_id) + model = SentenceTransformer(self.model_id) if model_name is None: model_name = str(model_id.split("/")[-1] + ".pt") diff --git a/tests/ml_models/test_sentencetransformermodel_pytest.py b/tests/ml_models/test_sentencetransformermodel_pytest.py index c9c9046ba..493b79958 100644 --- a/tests/ml_models/test_sentencetransformermodel_pytest.py +++ b/tests/ml_models/test_sentencetransformermodel_pytest.py @@ -8,11 +8,13 @@ import json import os import shutil +from unittest.mock import MagicMock, patch from zipfile import ZipFile import pytest -from opensearch_py_ml.ml_models import SentenceTransformerModel +from opensearch_py_ml.ml_commons import MLCommonClient +from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel TEST_FOLDER = os.path.join( os.path.dirname(os.path.abspath("__file__")), "tests", "test_model_files" @@ -658,5 +660,23 @@ def test_zip_model_with_license(): clean_test_folder(TEST_FOLDER) +@pytest.fixture +def mock_opensearch_client(): + with patch("opensearchpy.OpenSearch") as mock_client: + mock_client.return_value = MagicMock() + yield mock_client + + +def test_opensearch_connection(mock_opensearch_client): + client = MLCommonClient(mock_opensearch_client) + assert client._client == mock_opensearch_client + + +def test_init(): + model = SentenceTransformerModel(model_id="test-model", folder_path="/test/folder") + assert model.folder_path == "/test/folder" + assert model.model_id == "test-model" + + clean_test_folder(TEST_FOLDER) clean_test_folder(TESTDATA_UNZIP_FOLDER)