Commit 1fb5b9a

Merge pull request #421 from devchat-ai/add_max_tokens_for_llm_api
Add max_tokens configuration for LLM API
2 parents ed1a2ac + 8b6dd86 commit 1fb5b9a

File tree

1 file changed: +35 −1 lines changed


devchat/llm/openai.py

Lines changed: 35 additions & 1 deletion
@@ -7,12 +7,15 @@
 import json
 import os
 import re
-from typing import Dict, List
+from pathlib import Path
+from typing import Any, Dict, List
 
 import httpx
 import openai
+import oyaml as yaml
 
 from devchat.ide import IDEService
+from devchat.workflow.path import CHAT_CONFIG_FILENAME, CHAT_DIR
 
 from .pipeline import (
     RetryException,  # Import RetryException class
@@ -42,6 +45,35 @@ def _try_remove_markdown_block_flag(content):
     return content
 
 
+# Module-level variable used to cache the configuration
+_chat_config: Dict[str, Any] = {}
+
+
+def _load_chat_config() -> None:
+    """Load the chat configuration into the module-level variable."""
+    global _chat_config
+    chat_config_path = Path(CHAT_DIR) / CHAT_CONFIG_FILENAME
+    with open(chat_config_path, "r", encoding="utf-8") as file:
+        _chat_config = yaml.safe_load(file)
+
+
+def get_maxtokens_by_model(model: str) -> int:
+    # Load the configuration if it has not been loaded yet
+    if not _chat_config:
+        _load_chat_config()
+
+    # The default value is 1024
+    default_max_tokens = 1024
+
+    # Check whether the model is present in the configuration
+    if model in _chat_config.get("models", {}):
+        # If the model exists, read max_tokens, falling back to the default
+        return _chat_config["models"][model].get("max_tokens", default_max_tokens)
+    else:
+        # If the model is not in the configuration, return the default
+        return default_max_tokens
+
+
 def chat_completion_stream_commit(
     messages: List[Dict],  # [{"role": "user", "content": "hello"}]
     llm_config: Dict,  # {"model": "...", ...}
@@ -62,6 +94,7 @@ def chat_completion_stream_commit(
     # Update llm_config dictionary
     llm_config["stream"] = True
     llm_config["timeout"] = 60
+    llm_config["max_tokens"] = get_maxtokens_by_model(llm_config["model"])
     # Return chat completions
     return client.chat.completions.create(messages=messages, **llm_config)

@@ -83,6 +116,7 @@ def chat_completion_stream_raw(**kwargs):
     # Update kwargs dictionary
     kwargs["stream"] = True
     kwargs["timeout"] = 60
+    kwargs["max_tokens"] = get_maxtokens_by_model(kwargs["model"])
     # Return chat completions
     return client.chat.completions.create(**kwargs)
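For reference, get_maxtokens_by_model assumes the chat config file at CHAT_DIR / CHAT_CONFIG_FILENAME contains a top-level "models" mapping whose entries may carry a "max_tokens" field. The sketch below is a standalone illustration of that lookup, not code from this commit; the model names and token limits are hypothetical.

import oyaml as yaml

# Hypothetical config content mirroring the structure the lookup expects.
# The real file lives at CHAT_DIR / CHAT_CONFIG_FILENAME.
SAMPLE_CONFIG = yaml.safe_load(
    "models:\n"
    "  gpt-4o-mini:\n"  # hypothetical model entry
    "    max_tokens: 2048\n"
)

def lookup_max_tokens(model: str, config: dict = SAMPLE_CONFIG, default: int = 1024) -> int:
    # Same resolution order as get_maxtokens_by_model: per-model value, then the default.
    models = config.get("models", {})
    if model in models:
        return models[model].get("max_tokens", default)
    return default

print(lookup_max_tokens("gpt-4o-mini"))    # 2048, taken from the sample config
print(lookup_max_tokens("unknown-model"))  # 1024, the fallback default

Note that _chat_config is cached at module level, so the config file is read once per process; later edits to the file only take effect after a restart.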

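After this change, both streaming entry points pass an explicit max_tokens to the OpenAI client. A minimal sketch of the resulting request, assuming an OpenAI-compatible endpoint and a hypothetical model name (the client setup and literal max_tokens value are illustrative, not part of this commit):

import openai

# Assumes OPENAI_API_KEY (and, if needed, a custom base_url) are configured.
client = openai.OpenAI()

response = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical model name
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
    timeout=60,
    max_tokens=1024,  # the value get_maxtokens_by_model would resolve from the chat config
)
for chunk in response:
    # Each streamed chunk carries a delta with part of the assistant's reply.
    print(chunk.choices[0].delta.content or "", end="")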