150 changes: 150 additions & 0 deletions mindsdb_sdk/config.py
@@ -0,0 +1,150 @@
from mindsdb_sdk.connectors.rest_api import RestAPI


class Config():
"""
**Configuration for MindsDB**

This class provides methods to set and get the various configuration aspects of MindsDB.

Working with configuration:

Set default LLM configuration:

>>> server.config.set_default_llm(
... provider='openai',
... model_name='gpt-4',
... api_key='sk-...'
... )

Get default LLM configuration:

>>> llm_config = server.config.get_default_llm()
>>> print(llm_config)

Set default embedding model:

>>> server.config.set_default_embedding_model(
... provider='openai',
... model_name='text-embedding-ada-002',
... api_key='sk-...'
... )

Get default embedding model:

>>> embedding_config = server.config.get_default_embedding_model()

Set default reranking model:

>>> server.config.set_default_reranking_model(
... provider='openai',
... model_name='gpt-4',
... api_key='sk-...'
... )

Get default reranking model:

>>> reranking_config = server.config.get_default_reranking_model()
"""
def __init__(self, api: RestAPI):
self.api = api

def set_default_llm(
self,
provider: str,
model_name: str,
api_key: str = None,
**kwargs
):
"""
Set the default LLM configuration for MindsDB.

:param provider: The name of the LLM provider (e.g., 'openai', 'google').
:param model_name: The name of the model to use.
:param api_key: Optional API key for the provider.
:param kwargs: Additional parameters for the LLM configuration.
"""
config = {
"default_llm": {
"provider": provider,
"model_name": model_name,
"api_key": api_key,
**kwargs
}
}
self.api.update_config(config)

def get_default_llm(self):
"""
Get the default LLM configuration for MindsDB.

:return: Dictionary containing the default LLM configuration.
"""
return self.api.get_config().get("default_llm", {})

def set_default_embedding_model(
self,
provider: str,
model_name: str,
api_key: str = None,
**kwargs
):
"""
Set the default embedding model configuration for MindsDB.

:param provider: The name of the embedding model provider (e.g., 'openai', 'google').
:param model_name: The name of the embedding model to use.
:param api_key: Optional API key for the provider.
:param kwargs: Additional parameters for the embedding model configuration.
"""
config = {
"default_embedding_model": {
"provider": provider,
"model_name": model_name,
"api_key": api_key,
**kwargs
}
}
self.api.update_config(config)

def get_default_embedding_model(self):
"""
Get the default embedding model configuration for MindsDB.

:return: Dictionary containing the default embedding model configuration.
"""
return self.api.get_config().get("default_embedding_model", {})

def set_default_reranking_model(
self,
provider: str,
model_name: str,
api_key: str = None,
**kwargs
):
"""
Set the default reranking model configuration for MindsDB.

:param provider: The name of the reranking model provider (e.g., 'openai', 'google').
:param model_name: The name of the reranking model to use.
:param api_key: Optional API key for the provider.
:param kwargs: Additional parameters for the reranking model configuration.
"""
config = {
"default_reranking_model": {
"provider": provider,
"model_name": model_name,
"api_key": api_key,
**kwargs
}
}
self.api.update_config(config)

def get_default_reranking_model(self):
"""
Get the default reranking model configuration for MindsDB.

:return: Dictionary containing the default reranking model configuration.
"""
return self.api.get_config().get("default_reranking_model", {})

21 changes: 21 additions & 0 deletions mindsdb_sdk/connectors/rest_api.py
@@ -468,3 +468,24 @@ def knowledge_base_completion(self, project: str, knowledge_base_name, payload):
)
_raise_for_status(r)
return r.json()

def get_config(self):
"""
Get MindsDB configuration.

:return: Dictionary containing MindsDB configuration.
"""
url = self.url + '/api/config'
r = self.session.get(url)
_raise_for_status(r)
return r.json()

def update_config(self, config: dict):
"""
Update MindsDB configuration with the provided settings.

:param config: Dictionary containing configuration settings.
"""
url = self.url + '/api/config'
r = self.session.put(url, json=config)
_raise_for_status(r)
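For reference, get_config() and update_config() are thin wrappers around the /api/config endpoint shown above (GET and PUT respectively). A rough equivalent with plain requests, assuming an unauthenticated local instance at a placeholder URL:

>>> import requests
>>> base_url = 'http://127.0.0.1:47334'  # placeholder; use your MindsDB server URL
>>> requests.get(base_url + '/api/config').json()      # same request get_config() issues
>>> requests.put(base_url + '/api/config',             # same request update_config() issues
...              json={'default_llm': {'provider': 'openai', 'model_name': 'gpt-4'}})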
3 changes: 3 additions & 0 deletions mindsdb_sdk/server.py
@@ -4,6 +4,7 @@
from .ml_engines import MLEngines
from .handlers import Handlers
from .skills import Skills
from .config import Config


class Server(Project):
@@ -48,6 +49,8 @@ def __init__(self, api, skills: Skills = None, agents: Agents = None):
self.ml_handlers = Handlers(self.api, 'ml')
self.data_handlers = Handlers(self.api, 'data')

self.config = Config(api)

def status(self) -> dict:
"""
Get server information. It could contain version
103 changes: 103 additions & 0 deletions tests/test_sdk.py
@@ -1788,6 +1788,7 @@ def test_add_database(self, mock_post, mock_put, mock_get):
}
assert agent_update_json == expected_agent_json


class TestSkills():
@patch('requests.Session.get')
def test_list(self, mock_get):
@@ -1896,3 +1897,105 @@ def test_delete(self, mock_delete):
server.skills.drop('test_skill')
# Check API call.
assert mock_delete.call_args[0][0] == f'{DEFAULT_LOCAL_API_URL}/api/projects/mindsdb/skills/test_skill'


class TestConfig():
@patch('requests.Session.put')
@patch('requests.Session.get')
def test_set_and_get_default_llm(self, mock_get, mock_put):
server = mindsdb_sdk.connect()
response_mock(mock_put, {})
response_mock(mock_get, {
'default_llm': {
'provider': 'openai',
'model_name': 'gpt-4',
'api_key': 'sk-test123'
}
})

server.config.set_default_llm(
provider='openai',
model_name='gpt-4',
api_key='sk-test123'
)
assert mock_put.call_args[1]['json'] == {
'default_llm': {
'provider': 'openai',
'model_name': 'gpt-4',
'api_key': 'sk-test123'
}
}

llm_config = server.config.get_default_llm()
assert llm_config == {
'provider': 'openai',
'model_name': 'gpt-4',
'api_key': 'sk-test123'
}

@patch('requests.Session.put')
@patch('requests.Session.get')
def test_set_and_get_default_embedding_model(self, mock_get, mock_put):
server = mindsdb_sdk.connect()
response_mock(mock_put, {})
response_mock(mock_get, {
'default_embedding_model': {
'provider': 'openai',
'model_name': 'text-embedding-ada-002',
'api_key': 'sk-test456'
}
})

server.config.set_default_embedding_model(
provider='openai',
model_name='text-embedding-ada-002',
api_key='sk-test456'
)
assert mock_put.call_args[1]['json'] == {
'default_embedding_model': {
'provider': 'openai',
'model_name': 'text-embedding-ada-002',
'api_key': 'sk-test456'
}
}

embedding_config = server.config.get_default_embedding_model()
assert embedding_config == {
'provider': 'openai',
'model_name': 'text-embedding-ada-002',
'api_key': 'sk-test456'
}

@patch('requests.Session.put')
@patch('requests.Session.get')
def test_set_and_get_default_reranking_model(self, mock_get, mock_put):
server = mindsdb_sdk.connect()
response_mock(mock_put, {})
response_mock(mock_get, {
'default_reranking_model': {
'provider': 'cohere',
'model_name': 'rerank-english-v2.0',
'api_key': 'cohere-test789'
}
})

server.config.set_default_reranking_model(
provider='cohere',
model_name='rerank-english-v2.0',
api_key='cohere-test789'
)
assert mock_put.call_args[1]['json'] == {
'default_reranking_model': {
'provider': 'cohere',
'model_name': 'rerank-english-v2.0',
'api_key': 'cohere-test789'
}
}

reranking_config = server.config.get_default_reranking_model()
assert reranking_config == {
'provider': 'cohere',
'model_name': 'rerank-english-v2.0',
'api_key': 'cohere-test789'
}