pref: Generate prompt #4127
Conversation
Adding the "do-not-merge/release-note-label-needed" label because no release-note block was detected; please follow our release note process to remove it. Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes-sigs/prow repository.
[APPROVALNOTIFIER] This PR is NOT APPROVED. This pull-request has been approved by: The full list of commands accepted by this bot can be found here. Needs approval from an approver in each of these files. Approvers can indicate their approval by writing
    content=m.get('content')) for m in messages]]):
        yield 'data: ' + json.dumps({'content': r.content}) + '\n\n'

    return to_stream_response_simple(process())
Your code seems generally clean and well-organized. Here are a few suggestions:
- Imports: Ensure you import all functions and classes used before they are needed.
- Code Reorganization: The PromptGenerateSerializer methods could be separated for better modularization.
- Error Handling: Consider adding more descriptive error messages or handling potential exceptions more gracefully, especially around file reading (use try-except blocks).
- Performance Optimization: Check if there might be unnecessary calculations or computations that can be optimized (e.g., caching results where applicable).
- Security Considerations: Be cautious about paths involving environment variables like PROJECT_DIR. Avoid hardcoding sensitive information directly in the path strings.
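The error-handling point about wrapping file reads in try-except can be sketched as below; load_prompt_template is a hypothetical helper standing in for get_file_content, not part of the PR:

```python
from gettext import gettext as _


def load_prompt_template(path: str) -> str:
    # Hypothetical helper: wrap the file read so a missing or unreadable
    # template surfaces as a descriptive error instead of a bare OSError.
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        raise FileNotFoundError(_("Prompt template not found: %s") % path)
    except OSError as e:
        raise RuntimeError(_("Could not read prompt template %s: %s") % (path, e))
```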
Here is an improved version of the key areas mentioned above:
# Importing required libraries at the top
import os
import json
from typing import List, Dict

import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from rest_framework import serializers

from application.chat_pipeline.pipeline_manage import PipelineManage
from application.models import Application  # referenced by generate_prompt below
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.openai_to_response import OpenaiToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.utils.common import flat_map, get_file_content
from maxkb.conf import PROJECT_DIR
from models_provider.models import Model, Status
from models_provider.tools import get_model_instance_by_model_workspace_id


class AppApiException(Exception):
    """Custom exception class."""

    def __init__(self, code, message):
        self.code = code
        self.message = message
        super().__init__(message)


class ChatMessageSerializers(serializers.Serializer):
    message = serializers.CharField(required=True, label=_("User Questions"))
    stream = serializers.BooleanField(required=True,
                                      help_text=_("Indicates whether responses should be streamed."))


class PromptGenerateSerializer(serializers.Serializer):
    application_id = serializers.UUIDField(required=True, label=_('Application ID'))
    workspace_id = serializers.CharField(required=False, label=_('Workspace ID'))
    model_id = serializers.CharField(required=True, label=_('Model ID'))
    prompt = serializers.JSONField(label=_('prompt'))

    def validate(self, attrs):
        """Additional validation checks beyond basic field validation, such as the model type check."""
        if not QuerySet(Model).filter(id=attrs.get('model_id'), model_type='LLM').exists():
            raise serializers.ValidationError(_('This endpoint only supports LLM models.'))
        return attrs

    def is_valid_application_exists(self, query_set):
        """Check that an application exists based on its ID and return it."""
        if not query_set.exists():
            raise AppApiException(500, _('Application id does not exist'))
        return query_set.first()

    def generate_prompt(self, instance: dict):
        self.is_valid(raise_exception=True)  # also runs validate() above
        app = self.is_valid_application_exists(
            QuerySet(Application).filter(id=self.validated_data['application_id']))
        workspace_id = self.validated_data.get('workspace_id')
        model_id = self.validated_data.get('model_id')
        prompt = instance.get('prompt', [])
        system_content = get_file_content(
            os.path.join(PROJECT_DIR, "apps", "chat", 'template', 'generate_prompt_system')
        ).format(application_name=app.name, desc=app.desc)
        if not QuerySet(Model).filter(id=model_id, model_type="LLM").exists():
            raise AppApiException(500, _("Model does not exist or is not an LLM model"))
        model = get_model_instance_by_model_workspace_id(
            model_id=model_id, workspace_id=workspace_id, **app.model_params_setting)

        def process():
            for response in model.stream([
                SystemMessage(content=system_content),
                *[HumanMessage(content=m.get('content')) if m.get('role') == 'user'
                  else AIMessage(content=m.get('content')) for m in prompt]
            ]):
                yield f"data: {json.dumps({'completion': response.content})}\n\n"

        # to_stream_response_simple is assumed to be provided by the surrounding module
        return to_stream_response_simple(process())
Key Changes:
- Moved Imports: Moved imports to the beginning of the module.
- Separated Method Calls: Slightly restructured method calls within the generate_prompt method.
- Added Application Existence Check: A helper method that encapsulates the validation step ensuring the existence of an application.
- Detailed Error Message: Added a description/help text to the serializer fields for clarity.
- Encapsulated Additional Checks: Used another method (validate) to handle additional validations beyond basic data validation; it runs when is_valid() is called initially.
These changes ensure robustness and maintainability while adhering to best practices in coding conventions and security considerations.
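The caching point from the performance suggestion above can be sketched with functools.lru_cache; read_template here is a hypothetical stand-in for get_file_content, not code from the PR:

```python
from functools import lru_cache


@lru_cache(maxsize=32)
def read_template(path: str) -> str:
    # Hypothetical stand-in for get_file_content: each path is read from
    # disk once; repeated calls return the cached string.
    with open(path, encoding="utf-8") as f:
        return f.read()
```

Since the system-prompt template rarely changes at runtime, caching the read avoids hitting the filesystem on every request.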
        model = provider.get_model(model_type, model_name, model_credential, **model_params)
        model.check_auth()
    except Exception as e:
        traceback.print_exc()
The provided code has a few potential issues and improvements:
- Missing model_params: The function is_valid calls provider.get_model() without passing arguments for model_params. This might cause an error if get_model expects additional parameters.
- Potential Exception Handling: While catching exceptions (Exception) can be useful for handling unexpected errors, it's important to handle specific exception types rather than just Exception to get more detailed information. For example, you could catch ValueError, TypeError, etc.
- Logging Instead of Printing Stack Trace: Using traceback.print_exc() prints the stack trace to the console, which is generally not recommended when integrating into larger systems or APIs where logs should be managed centrally.
- Docstring Consistency: Ensure that all methods have docstrings to describe their purpose and inputs/outputs.
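The logging point above can be sketched as follows; check_model is a hypothetical wrapper (the provider object and its get_model/check_auth methods are assumed from the diff, not defined here):

```python
import logging

logger = logging.getLogger(__name__)


def check_model(provider, model_type, model_name, credential):
    # Hypothetical wrapper: route the full traceback through the logging
    # framework instead of printing it to stderr with traceback.print_exc().
    try:
        model = provider.get_model(model_type, model_name, credential)
        model.check_auth()
        return True
    except Exception:
        # logger.exception records the message at ERROR level plus the
        # current traceback, so central log handlers capture both.
        logger.exception("Model auth check failed for %s/%s", model_type, model_name)
        return False
```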
Here are some improvement suggestions:
import traceback  # assumed to be imported at module level


def is_valid(self,
             model_type: str,
             model_name: str,
             model_credential: dict,
             model_params=None) -> bool:
    """
    Validates whether the given model configuration is acceptable.

    Args:
        model_type (str): Type of the model (e.g., 'transformer', 'gpt').
        model_name (str): Name of the model configuration file.
        model_credential (dict): Credentials required to access the model.
        model_params (dict, optional): Additional parameters for model instantiation.

    Returns:
        bool: True if valid, otherwise False.

    Raises:
        KeyError: If any expected key is missing from model_credential.
    """
    if "name" not in model_credential or "credential" not in model_credential:
        raise KeyError("Model credentials must contain at least 'name' and 'credential'.")
    try:
        model = provider.get_model(model_type, model_name, model_credential,
                                   **(model_params or {}))
        model.check_auth()
        return True
    except ValueError as ve:
        print(f"Invalid input value: {ve}")
        return False
    except TypeError as te:
        print(f"Type mismatch: {te}")
        return False
    except Exception as e:
        print(f"An error occurred: {e}")
        print(traceback.format_exc())  # Consider logging instead
        return False
Key Changes:
- Added self as the first argument to indicate this is an instance method.
- Fixed the typo in modelParams to model_params.
- Ensured model_params is passed with a check model_params or {} to avoid an error when it is None.
- Improved exception handling by catching specific types and providing user-friendly messages.
- Updated the docstring to include descriptions and parameters clearly.
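The model_params or {} guard can be demonstrated in isolation; get_model and build below are hypothetical stand-ins, not the provider API from the PR:

```python
def get_model(model_type, model_name, credential, **model_params):
    # Hypothetical signature mirroring provider.get_model: extra keyword
    # parameters are optional and collected into model_params.
    return {"type": model_type, "name": model_name, **model_params}


def build(model_params=None):
    # `**(model_params or {})` unpacks safely even when the caller passed
    # None, whereas `**model_params` would raise a TypeError in that case.
    return get_model("LLM", "demo", {}, **(model_params or {}))
```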