# coding=utf-8
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _


class VllmImageModelParams(BaseForm):
    """Tunable generation parameters shown in the model's settings form."""

    temperature = forms.SliderField(
        TooltipLabel(_('Temperature'),
                     _('Higher values make the output more random, while lower values make it more focused and deterministic')),
        required=True, default_value=0.7,
        _min=0.1,
        _max=1.0,
        _step=0.01,
        precision=2)

    max_tokens = forms.SliderField(
        TooltipLabel(_('Output the maximum Tokens'),
                     _('Specify the maximum number of tokens that the model can generate')),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class VllmImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form for an image (vision) model served through a vLLM-compatible API."""

    api_base = forms.TextInputField('API URL', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        # Reject model types the provider does not declare.
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))

        # Both connection fields must be present before attempting a live check.
        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
                return False
        try:
            # Instantiate the model and stream a minimal "hello" prompt (你好)
            # to confirm that the endpoint and API key actually work.
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
            for chunk in res:
                print(chunk)
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      _('Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Encrypt the API key; all other credential fields are stored as-is.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        return VllmImageModelParams()
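

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): roughly how
# the provider layer might exercise these forms when a user saves credentials.
# The 'IMAGE' type value, the model name and `vllm_provider` are assumptions
# standing in for whatever the real model-provider registry supplies.
#
#   credential_form = VllmImageModelCredential()
#   ok = credential_form.is_valid(
#       'IMAGE', 'llava-v1.6',
#       {'api_base': 'http://127.0.0.1:8000/v1', 'api_key': 'sk-xxx'},
#       {'temperature': 0.7, 'max_tokens': 800},
#       vllm_provider, raise_exception=False)
#   params_form = credential_form.get_model_params_setting_form('llava-v1.6')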