Hey @schopra6! 👋 I'm here to help you with bugs, questions, and becoming a contributor. Let's tackle this together!

To customize LLM decoding with `ChatOpenAI` in LangChain, you can create a custom LLM class by subclassing `LLM` and implementing the `_call` method:

```python
import json
import logging
from typing import Any, Dict, List, Optional, Set

import requests
from langchain_community.llms.utils import enforce_stop_tokens
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

# langchain-core 0.2.x models are pydantic v1; use the compatibility shim.
from langchain_core.pydantic_v1 import Field

logger = logging.getLogger(__name__)


class CustomLLM(LLM):
    """A custom LLM that allows customization of the inference method
    and token decoding parameters."""

    infer_api: str = "http://127.0.0.1:8000/custom"
    max_tokens: int = 1024
    temp: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 0
    do_sample: bool = False
    echo: Optional[bool] = False
    stop: Optional[List[str]] = Field(default_factory=list)
    repeat_last_n: Optional[int] = 64
    repeat_penalty: Optional[float] = 1.18
    streaming: bool = False
    history: List[str] = Field(default_factory=list)
    use_history: bool = False

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # top_p (nucleus) and top_k sampling are mutually exclusive here.
        if (self.top_p or 0) > 0 and (self.top_k or 0) > 0:
            logger.warning(
                "top_p and top_k cannot be set simultaneously; "
                "setting top_k to 0 instead..."
            )
            self.top_k = 0

    @property
    def _llm_type(self) -> str:
        return "CustomLLM"

    @staticmethod
    def _model_param_names() -> Set[str]:
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        return {
            "do_sample": self.do_sample,
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "repeat_penalty": self.repeat_penalty,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v
                for k, v in self.__dict__.items()
                if k in self._model_param_names()
            },
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Optionally prepend prior turns, joined by the model's separator token.
        if self.use_history:
            self.history.append(prompt)
            input_text = "<n>".join(self.history)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [{"id": "000", "ques": input_text}],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )
        logger.debug("CustomLLM prompt: %s", input_text)

        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}")

        logger.debug(f"CustomLLM response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            resp = response.json()
            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['exceptionMsg']}]"
                )
            if "resData" in resp:
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        if self.use_history:
            self.history.append(generate_text)
            logger.debug(f"history: {self.history}")

        return generate_text


# Example usage
llm = CustomLLM(
    infer_api="http://127.0.0.1:8000/custom",
    max_tokens=50,
    temp=0.8,
    top_p=0.95,
)
result = llm.invoke("This is a foobar thing")
print(result)
```

In this example, `CustomLLM` sends the prompt to a custom inference endpoint and exposes the decoding parameters (`max_tokens`, `temp`, `top_p`, `top_k`, `do_sample`) as fields you can set at construction time.
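If you want to use the custom model inside a chain, here's a minimal sketch of wiring it up with the LCEL pipe syntax; the template text and question are illustrative only:

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# Compose prompt -> model -> string parser with the LCEL pipe syntax.
prompt = PromptTemplate.from_template("Answer briefly: {question}")
chain = prompt | llm | StrOutputParser()

print(chain.invoke({"question": "What does top_p control?"}))
```

Because `CustomLLM` subclasses `LLM`, it picks up the standard Runnable methods (`invoke`, `batch`, `stream`) automatically.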
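Also worth noting: if you are calling the OpenAI API through `ChatOpenAI` itself, the common decoding knobs are already plain constructor arguments, so for simple cases you may not need a custom class at all. A minimal sketch (the model name is illustrative, and `OPENAI_API_KEY` must be set):

```python
from langchain_openai import ChatOpenAI

# Standard decoding parameters are exposed directly on the constructor.
chat = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0.2,
    top_p=0.9,
    max_tokens=256,
)
print(chat.invoke("Say hello in one word.").content)
```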
System Info
langchain==0.2.14
langchain-cli==0.0.30
langchain-community==0.2.12
langchain-core==0.2.35
langchain-experimental==0.0.64
langchain-openai==0.1.22
langchain-text-splitters==0.2.2