2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))


### Changed
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))
- Update demo_tracing_model_torchscript_onnx.ipynb to use make_model_config_json by @thanawan-atc in ([#220](https://github.com/opensearch-project/opensearch-py-ml/pull/220))
@@ -44,6 +45,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Fix pandas dependency issue in nox session by installing pandas package to python directly by @thanawan-atc in ([#266](https://github.com/opensearch-project/opensearch-py-ml/pull/266))
- Fix conditional job execution issue in model upload workflow by @thanawan-atc in ([#294](https://github.com/opensearch-project/opensearch-py-ml/pull/294))
- fix bug in `MLCommonClient_client.upload_model` by @rawwar in ([#336](https://github.com/opensearch-project/opensearch-py-ml/pull/336))
- Fix to ensure `model_id` is only required once when saving a model in `sentencetransformermodel.py` by @Vinay-Vinod in ([#361](https://github.com/opensearch-project/opensearch-py-ml/pull/361))

## [1.1.0]

4 changes: 2 additions & 2 deletions opensearch_py_ml/ml_models/sentencetransformermodel.py
@@ -92,7 +92,7 @@ def __init__(
str("The default folder path already exists at : " + self.folder_path)
)

- self.model_id = model_id
+ self.model_id = model_id if model_id is not None else self.DEFAULT_MODEL_ID
self.torch_script_zip_file_path = None
self.onnx_zip_file_path = None

@@ -806,7 +806,7 @@ def save_as_pt(
:rtype: string
"""

- model = SentenceTransformer(model_id)
+ model = SentenceTransformer(self.model_id)

if model_name is None:
model_name = str(model_id.split("/")[-1] + ".pt")
22 changes: 21 additions & 1 deletion tests/ml_models/test_sentencetransformermodel_pytest.py
@@ -8,11 +8,13 @@
import json
import os
import shutil
from unittest.mock import MagicMock, patch
from zipfile import ZipFile

import pytest

- from opensearch_py_ml.ml_models import SentenceTransformerModel
+ from opensearch_py_ml.ml_commons import MLCommonClient
+ from opensearch_py_ml.ml_models.sentencetransformermodel import SentenceTransformerModel

TEST_FOLDER = os.path.join(
os.path.dirname(os.path.abspath("__file__")), "tests", "test_model_files"
@@ -658,5 +660,23 @@ def test_zip_model_with_license():
clean_test_folder(TEST_FOLDER)


@pytest.fixture
def mock_opensearch_client():
We might not want to mock the OpenSearch client. The tests we are running are integration tests, so mocking might not be useful.

@rawwar, Dec 21, 2023

Seems like I am wrong. In this case, we can just mock the client.


@Vinay-Vinod Thanks for raising the PR. Could you please confirm that this PR addresses all of the following points from the original request:

What solution would you like?
I would like to EITHER 1/ only provide the model_id when I instantiate the SentenceTransformerModel and it gets used in all class methods where applicable, rather than randomly defaulting back to "distilbert-tas-b" OR 2/ use save_as_pt as a static method and never instantiate the class at all. Also it would be nice if I could specify the output location of make_model_config_json.
  1. Provide the model_id only when instantiating SentenceTransformerModel, and have it used in all class methods where applicable, rather than defaulting back to "distilbert-tas-b".
  2. In save_as_pt, if no model id is passed as a function parameter, use the model id that the SentenceTransformerModel instance was initialized with. If a model id is passed as a parameter, save_as_pt uses that one instead. Please let me know if anything is unclear.
  3. It would be nice to be able to specify the output location of make_model_config_json. Can we take care of this one too, or are you planning to raise another PR for it?
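
For reference, the precedence described in point 2 could look roughly like the sketch below. This is only an illustration under stated assumptions: `SentenceTransformerModelSketch`, `resolve_model_id`, and `default_pt_name` are hypothetical names and are not part of opensearch-py-ml. The point is that both the model that gets loaded and the default `.pt` file name would be derived from the same resolved id.

```python
from typing import Optional


class SentenceTransformerModelSketch:
    """Hypothetical stand-in used only to illustrate the id precedence."""

    def __init__(self, model_id: str) -> None:
        # The id given at construction time is the instance-level default.
        self.model_id = model_id

    def resolve_model_id(self, model_id: Optional[str] = None) -> str:
        # An explicit argument wins; otherwise fall back to the instance id.
        return model_id if model_id is not None else self.model_id

    def default_pt_name(self, model_id: Optional[str] = None) -> str:
        # Derive the file name from the *resolved* id so that the saved
        # file name always matches the model that is actually traced.
        resolved = self.resolve_model_id(model_id)
        return resolved.split("/")[-1] + ".pt"
```

With this shape, a call without arguments uses the instance id, and a call with an explicit id overrides it for that call only.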

    with patch("opensearchpy.OpenSearch") as mock_client:
        mock_client.return_value = MagicMock()
        yield mock_client


def test_opensearch_connection(mock_opensearch_client):
    client = MLCommonClient(mock_opensearch_client)
    assert client._client == mock_opensearch_client


def test_init():
    model = SentenceTransformerModel(model_id="test-model", folder_path="/test/folder")
    assert model.folder_path == "/test/folder"
    assert model.model_id == "test-model"


clean_test_folder(TEST_FOLDER)
clean_test_folder(TESTDATA_UNZIP_FOLDER)
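
The added `test_init` only covers the case where a `model_id` is passed explicitly. A minimal sketch of how the fallback branch of the changed `__init__` line might also be checked is shown below. It uses a hypothetical stand-in class (`InitSketch`) and a placeholder default value, not the real `SentenceTransformerModel` or its actual `DEFAULT_MODEL_ID`, so it would need adapting before going into the test suite.

```python
from typing import Optional


class InitSketch:
    """Hypothetical stand-in mirroring only the changed assignment."""

    # Placeholder value chosen for illustration; not the library's default.
    DEFAULT_MODEL_ID = "example-org/default-model"

    def __init__(self, model_id: Optional[str] = None) -> None:
        # Same expression as the line changed in this PR.
        self.model_id = model_id if model_id is not None else self.DEFAULT_MODEL_ID


def test_explicit_model_id_is_kept():
    model = InitSketch(model_id="test-model")
    assert model.model_id == "test-model"


def test_missing_model_id_falls_back_to_default():
    model = InitSketch(model_id=None)
    assert model.model_id == InitSketch.DEFAULT_MODEL_ID
```

Both functions follow the pytest naming convention already used in this file, so they would be collected alongside the existing tests.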