Skip to content

Commit 1666baa

Browse files
authored
#284: Add support for connectors (#345)
* add connectors Signed-off-by: kalyan <[email protected]> * update Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * rename Signed-off-by: kalyanr <[email protected]> * add tests Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * fix Signed-off-by: kalyan <[email protected]> * lint fix Signed-off-by: kalyan <[email protected]> * update changelog Signed-off-by: kalyan <[email protected]> * increase test coverage Signed-off-by: kalyan <[email protected]> --------- Signed-off-by: kalyan <[email protected]> Signed-off-by: kalyanr <[email protected]>
1 parent 2031705 commit 1666baa

File tree

4 files changed

+212
-0
lines changed

4 files changed

+212
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
1010
- Add example notebook for tracing and registering a CLIPTextModel to OpenSearch with the Neural Search plugin by @patrickbarnhart in ([#283](https://github.com/opensearch-project/opensearch-py-ml/pull/283))
1111
- Add support for train api functionality by @rawwar in ([#310](https://github.com/opensearch-project/opensearch-py-ml/pull/310))
1212
- Add support for Model Access Control - Register, Update, Search and Delete by @rawwar in ([#332](https://github.com/opensearch-project/opensearch-py-ml/pull/332))
13+
- Add support for model connectors by @rawwar in ([#345](https://github.com/opensearch-project/opensearch-py-ml/pull/345))
1314

1415
### Changed
1516
- Modify ml-models.JenkinsFile so that it takes model format into account and can be triggered with generic webhook by @thanawan-atc in ([#211](https://github.com/opensearch-project/opensearch-py-ml/pull/211))

opensearch_py_ml/ml_commons/ml_commons_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TIMEOUT,
2323
)
2424
from opensearch_py_ml.ml_commons.model_access_control import ModelAccessControl
25+
from opensearch_py_ml.ml_commons.model_connector import Connector
2526
from opensearch_py_ml.ml_commons.model_execute import ModelExecute
2627
from opensearch_py_ml.ml_commons.model_uploader import ModelUploader
2728

@@ -37,6 +38,7 @@ def __init__(self, os_client: OpenSearch):
3738
self._model_uploader = ModelUploader(os_client)
3839
self._model_execute = ModelExecute(os_client)
3940
self.model_access_control = ModelAccessControl(os_client)
41+
self.connector = Connector(os_client)
4042

4143
def execute(self, algorithm_name: str, input_json: dict) -> dict:
4244
"""
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# The OpenSearch Contributors require contributions made to
3+
# this file be licensed under the Apache-2.0 license or a
4+
# compatible open source license.
5+
# Any modifications Copyright OpenSearch Contributors. See
6+
# GitHub history for details.
7+
8+
from opensearchpy import OpenSearch
9+
10+
from opensearch_py_ml.ml_commons.ml_common_utils import ML_BASE_URI
11+
12+
13+
class Connector:
14+
def __init__(self, os_client: OpenSearch):
15+
self.client = os_client
16+
17+
def create_standalone_connector(self, payload: dict):
18+
if not isinstance(payload, dict):
19+
raise ValueError("payload needs to be a dictionary")
20+
21+
return self.client.transport.perform_request(
22+
method="POST", url=f"{ML_BASE_URI}/connectors/_create", body=payload
23+
)
24+
25+
def list_connectors(self):
26+
search_query = {"query": {"match_all": {}}}
27+
return self.search_connectors(search_query)
28+
29+
def search_connectors(self, search_query: dict):
30+
if not isinstance(search_query, dict):
31+
raise ValueError("search_query needs to be a dictionary")
32+
33+
return self.client.transport.perform_request(
34+
method="POST", url=f"{ML_BASE_URI}/connectors/_search", body=search_query
35+
)
36+
37+
def get_connector(self, connector_id: str):
38+
if not isinstance(connector_id, str):
39+
raise ValueError("connector_id needs to be a string")
40+
41+
return self.client.transport.perform_request(
42+
method="GET", url=f"{ML_BASE_URI}/connectors/{connector_id}"
43+
)
44+
45+
def delete_connector(self, connector_id: str):
46+
if not isinstance(connector_id, str):
47+
raise ValueError("connector_id needs to be a string")
48+
49+
return self.client.transport.perform_request(
50+
method="DELETE", url=f"{ML_BASE_URI}/connectors/{connector_id}"
51+
)
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# The OpenSearch Contributors require contributions made to
3+
# this file be licensed under the Apache-2.0 license or a
4+
# compatible open source license.
5+
# Any modifications Copyright OpenSearch Contributors. See
6+
# GitHub history for details.
7+
8+
import os
9+
10+
import pytest
11+
from opensearchpy.exceptions import NotFoundError, RequestError
12+
from packaging.version import parse as parse_version
13+
14+
from opensearch_py_ml.ml_commons.model_connector import Connector
15+
from tests import OPENSEARCH_TEST_CLIENT
16+
17+
OPENSEARCH_VERSION = parse_version(os.environ.get("OPENSEARCH_VERSION", "2.11.0"))
18+
CONNECTOR_MIN_VERSION = parse_version("2.9.0")
19+
20+
21+
@pytest.fixture
22+
def client():
23+
return Connector(OPENSEARCH_TEST_CLIENT)
24+
25+
26+
def _safe_delete_connector(client, connector_id):
27+
try:
28+
client.delete_connector(connector_id=connector_id)
29+
except NotFoundError:
30+
pass
31+
32+
33+
@pytest.fixture
34+
def connector_payload():
35+
return {
36+
"name": "Test Connector",
37+
"description": "Connector for testing",
38+
"version": 1,
39+
"protocol": "http",
40+
"parameters": {"endpoint": "api.openai.com", "model": "gpt-3.5-turbo"},
41+
"credential": {"openAI_key": "..."},
42+
"actions": [
43+
{
44+
"action_type": "predict",
45+
"method": "POST",
46+
"url": "https://${parameters.endpoint}/v1/chat/completions",
47+
"headers": {"Authorization": "Bearer ${credential.openAI_key}"},
48+
"request_body": '{ "model": "${parameters.model}", "messages": ${parameters.messages} }',
49+
}
50+
],
51+
}
52+
53+
54+
@pytest.fixture
55+
def test_connector(client: Connector, connector_payload: dict):
56+
res = client.create_standalone_connector(connector_payload)
57+
connector_id = res["connector_id"]
58+
yield connector_id
59+
60+
_safe_delete_connector(client, connector_id)
61+
62+
63+
@pytest.mark.skipif(
64+
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
65+
reason="Connectors are supported in OpenSearch 2.9.0 and above",
66+
)
67+
def test_create_standalone_connector(client: Connector, connector_payload: dict):
68+
res = client.create_standalone_connector(connector_payload)
69+
assert "connector_id" in res
70+
71+
_safe_delete_connector(client, res["connector_id"])
72+
73+
with pytest.raises(ValueError):
74+
client.create_standalone_connector("")
75+
76+
77+
@pytest.mark.skipif(
78+
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
79+
reason="Connectors are supported in OpenSearch 2.9.0 and above",
80+
)
81+
def test_list_connectors(client, test_connector):
82+
try:
83+
res = client.list_connectors()
84+
assert len(res["hits"]["hits"]) > 0
85+
86+
# check if test_connector id is in the response
87+
found = False
88+
for each in res["hits"]["hits"]:
89+
if each["_id"] == test_connector:
90+
found = True
91+
break
92+
assert found, "Test connector not found in list connectors response"
93+
except Exception as ex:
94+
assert False, f"Failed to list connectors due to {ex}"
95+
96+
97+
@pytest.mark.skipif(
98+
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
99+
reason="Connectors are supported in OpenSearch 2.9.0 and above",
100+
)
101+
def test_search_connectors(client, test_connector):
102+
try:
103+
query = {"query": {"match": {"name": "Test Connector"}}}
104+
res = client.search_connectors(query)
105+
assert len(res["hits"]["hits"]) > 0
106+
107+
# check if test_connector id is in the response
108+
found = False
109+
for each in res["hits"]["hits"]:
110+
if each["_id"] == test_connector:
111+
found = True
112+
break
113+
assert found, "Test connector not found in search connectors response"
114+
except Exception as ex:
115+
assert False, f"Failed to search connectors due to {ex}"
116+
117+
with pytest.raises(ValueError):
118+
client.search_connectors("test")
119+
120+
121+
@pytest.mark.skipif(
122+
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
123+
reason="Connectors are supported in OpenSearch 2.9.0 and above",
124+
)
125+
def test_get_connector(client, test_connector):
126+
try:
127+
res = client.get_connector(connector_id=test_connector)
128+
assert res["name"] == "Test Connector"
129+
except Exception as ex:
130+
assert False, f"Failed to get connector due to {ex}"
131+
132+
with pytest.raises(ValueError):
133+
client.get_connector(connector_id=None)
134+
135+
with pytest.raises(RequestError) as exec_info:
136+
client.get_connector(connector_id="test-unknown")
137+
assert exec_info.value.status_code == 400
138+
139+
140+
@pytest.mark.skipif(
141+
OPENSEARCH_VERSION < CONNECTOR_MIN_VERSION,
142+
reason="Connectors are supported in OpenSearch 2.9.0 and above",
143+
)
144+
def test_delete_connector(client, test_connector):
145+
try:
146+
res = client.delete_connector(connector_id=test_connector)
147+
assert res["result"] == "deleted"
148+
except Exception as ex:
149+
assert False, f"Failed to delete connector due to {ex}"
150+
151+
try:
152+
res = client.delete_connector(connector_id="unknown")
153+
assert res["result"] == "not_found"
154+
except Exception as ex:
155+
assert False, f"Failed to delete connector due to {ex}"
156+
157+
with pytest.raises(ValueError):
158+
client.delete_connector(connector_id={"test": "fail"})

0 commit comments

Comments
 (0)