diff --git a/dataflow/serving/api_llm_serving_request.py b/dataflow/serving/api_llm_serving_request.py index c999a692..f50ac61c 100644 --- a/dataflow/serving/api_llm_serving_request.py +++ b/dataflow/serving/api_llm_serving_request.py @@ -26,14 +26,17 @@ def __init__(self, max_workers: int = 10, max_retries: int = 5, timeout: tuple[float, float] = (10.0, 120.0), # connect timeout, read timeout + **configs ): # Get API key from environment variable or config self.api_url = api_url self.model_name = model_name - self.temperature = temperature + # self.temperature = temperature self.max_workers = max_workers self.max_retries = max_retries self.timeout = timeout + self.configs = configs + self.configs.update({"temperature": temperature}) self.logger = get_logger() @@ -125,17 +128,17 @@ def _api_chat_with_id( start = time.time() try: if is_embedding: - payload = json.dumps({ + payload = { "model": model, "input": payload - }) + } elif json_schema is None: - payload = json.dumps({ + payload = { "model": model, "messages": payload - }) + } else: - payload = json.dumps({ + payload = { "model": model, "messages": payload, "response_format": { @@ -146,7 +149,10 @@ def _api_chat_with_id( "schema": json_schema } } - }) + } + + payload.update(self.configs) + payload = json.dumps(payload) # Make a POST request to the API response = self.session.post(self.api_url, headers=self.headers, data=payload, timeout=self.timeout) cost = time.time() - start