Skip to content

Commit a0f2d4a

Browse files
authored
Disable shared library by default. Set default max_length in api server. (#317)
1 parent c9a4a70 commit a0f2d4a

File tree

8 files changed

+44
-6
lines changed

8 files changed

+44
-6
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ option(CHATGLM_ENABLE_EXAMPLES "chatglm: enable c++ examples" ON)
1717
option(CHATGLM_ENABLE_PYBIND "chatglm: enable python binding" OFF)
1818
option(CHATGLM_ENABLE_TESTING "chatglm: enable testing" OFF)
1919

20+
set(BUILD_SHARED_LIBS OFF CACHE BOOL "")
2021
if (CHATGLM_ENABLE_PYBIND)
2122
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)
2223
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

Dockerfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ RUN \
4747
rm -rf /var/lib/apt/lists/*
4848

4949
COPY --from=build /chatglm.cpp/build/bin/main /chatglm.cpp/build/bin/main
50-
COPY --from=build /chatglm.cpp/build/lib/*.so /chatglm.cpp/build/lib/
5150
COPY --from=build /chatglm.cpp/dist/ /chatglm.cpp/dist/
5251

5352
ADD examples examples

chatglm_cpp/langchain_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
class Settings(BaseSettings):
1414
model: str = "models/chatglm-ggml.bin"
15+
max_length: int = 4096
1516

1617

1718
class ChatRequest(BaseModel):
@@ -48,7 +49,7 @@ class ChatResponse(BaseModel):
4849
settings = Settings()
4950
logging.info(settings)
5051

51-
pipeline = chatglm_cpp.Pipeline(settings.model)
52+
pipeline = chatglm_cpp.Pipeline(settings.model, max_length=settings.max_length)
5253

5354

5455
@app.post("/")

chatglm_cpp/openai_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
class Settings(BaseSettings):
1919
model: str = "models/chatglm3-ggml.bin"
20+
max_length: int = 4096
2021
num_threads: int = 0
2122

2223

@@ -129,7 +130,7 @@ class ChatCompletionResponse(BaseModel):
129130
app.add_middleware(
130131
CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
131132
)
132-
pipeline = chatglm_cpp.Pipeline(settings.model)
133+
pipeline = chatglm_cpp.Pipeline(settings.model, max_length=settings.max_length)
133134
lock = asyncio.Lock()
134135

135136

examples/cli_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main() -> None:
6363
if args.sp:
6464
system = args.sp.read_text()
6565

66-
pipeline = chatglm_cpp.Pipeline(args.model)
66+
pipeline = chatglm_cpp.Pipeline(args.model, max_length=args.max_length)
6767

6868
if args.mode != "chat" and args.interactive:
6969
print("interactive demo is only supported for chat mode, falling back to non-interactive one")

examples/web_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
parser.add_argument("--plain", action="store_true", help="display in plain text without markdown support")
2222
args = parser.parse_args()
2323

24-
pipeline = chatglm_cpp.Pipeline(args.model)
24+
pipeline = chatglm_cpp.Pipeline(args.model, max_length=args.max_length)
2525

2626

2727
def postprocess(text):

tests/perplexity.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ static float cross_entropy(const ggml_tensor *input, const ggml_tensor *target)
9494
// reference: https://huggingface.co/docs/transformers/perplexity
9595
static void perplexity(Args &args) {
9696
std::cout << "Loading model from " << args.model_path << " ...\n";
97-
chatglm::Pipeline pipeline(args.model_path);
97+
chatglm::Pipeline pipeline(args.model_path, args.max_length);
9898

9999
std::cout << "Loading corpus from " << args.corpus_path << " ...\n";
100100
std::string corpus = read_text(args.corpus_path);

tests/test_chatglm_cpp.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,39 @@ def test_internlm7b_pipeline():
147147
@pytest.mark.skipif(not INTERNLM20B_MODEL_PATH.exists(), reason="model file not found")
148148
def test_internlm20b_pipeline():
149149
check_pipeline(model_path=INTERNLM20B_MODEL_PATH, prompt="你好", target="你好!有什么我可以帮助你的吗?")
150+
151+
152+
@pytest.mark.skipif(not CHATGLM4_MODEL_PATH.exists(), reason="model file not found")
153+
def test_langchain_api():
154+
import os
155+
from unittest.mock import patch
156+
157+
from fastapi.testclient import TestClient
158+
159+
with patch.dict(os.environ, {"MODEL": str(CHATGLM4_MODEL_PATH)}):
160+
from chatglm_cpp.langchain_api import app
161+
162+
client = TestClient(app)
163+
response = client.post("/", json={"prompt": "你好", "temperature": 0})
164+
assert response.status_code == 200
165+
assert response.json()["response"] == "你好👋!有什么可以帮助你的吗?"
166+
167+
168+
@pytest.mark.skipif(not CHATGLM4_MODEL_PATH.exists(), reason="model file not found")
169+
def test_openai_api():
170+
import os
171+
from unittest.mock import patch
172+
173+
from fastapi.testclient import TestClient
174+
175+
with patch.dict(os.environ, {"MODEL": str(CHATGLM4_MODEL_PATH)}):
176+
from chatglm_cpp.openai_api import app
177+
178+
client = TestClient(app)
179+
response = client.post(
180+
"/v1/chat/completions", json={"messages": [{"role": "user", "content": "你好"}], "temperature": 0}
181+
)
182+
assert response.status_code == 200
183+
response_message = response.json()["choices"][0]["message"]
184+
assert response_message["role"] == "assistant"
185+
assert response_message["content"] == "你好👋!有什么可以帮助你的吗?"

0 commit comments

Comments
 (0)