This repository was archived by the owner on Dec 6, 2024. It is now read-only.

Commit 9d7bee7
Bump version to 0.1.2
1 parent c54bf25

File tree: 7 files changed, +187 -5 lines

MANIFEST.in

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+global-include CMakeLists.txt *.cmake README.md LICENSE
+include *.cpp *.h
+
+# absl
+graft third_party/abseil-cpp/absl
+graft third_party/abseil-cpp/CMake
+include third_party/abseil-cpp/*
+
+# re2
+graft third_party/re2/re2
+graft third_party/re2/util
+include third_party/re2/*
+
+# ggml
+graft third_party/ggml/include
+graft third_party/ggml/src
+include third_party/ggml/*
+
+# pybind11
+graft third_party/pybind11/include
+graft third_party/pybind11/tools

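These rules control which vendored sources (abseil, re2, ggml, pybind11) are shipped in the source distribution so the PyPI install can compile them locally. A minimal sketch of checking a built sdist, assuming one has already been produced under `dist/` with this commit's version (the archive path is an assumption, not part of the commit):

```python
# Sketch only: verify the sdist built from this MANIFEST.in contains the
# grafted third_party sources. The archive path below is an assumption.
import tarfile

with tarfile.open("dist/qwen_cpp-0.1.2.tar.gz") as sdist:
    names = sdist.getnames()

for prefix in ("third_party/abseil-cpp/absl", "third_party/re2/re2", "third_party/ggml/src"):
    # Every grafted directory should contribute at least one file to the archive.
    assert any(prefix in name for name in names), f"missing {prefix}"
```
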
README.md

Lines changed: 4 additions & 1 deletion
@@ -79,7 +79,10 @@ The Python binding provides high-level `chat` and `stream_chat` interface simila
 
 **Installation**
 
-Install from PyPI (recommended): WIP.
+Install from PyPI (recommended): will trigger compilation on your platform.
+```sh
+pip install -U qwen-cpp
+```
 
 You may also install from source.
 ```sh

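Since installing from PyPI now triggers a local build of the C++ extension, a quick way to confirm the build succeeded is to import the compiled module; a minimal smoke test, not part of the commit:

```python
# Post-install smoke test (illustrative). Importing qwen_cpp._C fails if the
# local CMake build of the extension did not complete successfully.
import qwen_cpp._C as _C

print(_C.Pipeline)          # C++ pipeline binding
print(_C.GenerationConfig)  # generation parameters exposed to Python
```
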
pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -33,3 +33,4 @@ dynamic = ["version"]
 [project.urls]
 Homepage = "https://github.com/QwenLM/qwen.cpp"
 Repository = "https://github.com/QwenLM/qwen.cpp.git"
+BugTracker = "https://github.com/QwenLM/qwen.cpp/issues"

qwen_cpp/__init__.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+import tempfile
+from pathlib import Path
+from typing import Iterator, List, Optional, Union
+
+import qwen_cpp._C as _C
+
+
+class Pipeline(_C.Pipeline):
+    def __init__(
+        self, model_path: str, tiktoken_path: str, *, dtype: Optional[str] = None
+    ) -> None:
+        if Path(model_path).is_file() and Path(tiktoken_path).is_file():
+            super().__init__(str(model_path), str(tiktoken_path))
+        else:
+            from qwen_cpp.convert import convert
+
+            if dtype is None:
+                dtype = "q4_0"  # default dtype
+
+            with tempfile.NamedTemporaryFile("wb") as f:
+                convert(f, model_path, dtype=dtype)
+                super().__init__(f.name, str(tiktoken_path))
+
+    def chat(
+        self,
+        history: List[str],
+        *,
+        max_length: int = 2048,
+        max_context_length: int = 512,
+        do_sample: bool = True,
+        top_k: int = 0,
+        top_p: float = 0.7,
+        temperature: float = 0.95,
+        repetition_penalty: float = 1.0,
+        num_threads: int = 0,
+        stream: bool = False,
+    ) -> Union[Iterator[str], str]:
+        input_ids = self.tokenizer.encode_history(history, max_context_length)
+        return self._generate(
+            input_ids=input_ids,
+            max_length=max_length,
+            max_context_length=max_context_length,
+            do_sample=do_sample,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            num_threads=num_threads,
+            stream=stream,
+        )
+
+    def _generate(
+        self,
+        input_ids: List[int],
+        *,
+        max_length: int = 2048,
+        max_context_length: int = 512,
+        do_sample: bool = True,
+        top_k: int = 0,
+        top_p: float = 0.7,
+        temperature: float = 0.95,
+        repetition_penalty: float = 1.0,
+        num_threads: int = 0,
+        stream: bool = False,
+    ) -> Union[Iterator[str], str]:
+        gen_config = _C.GenerationConfig(
+            max_length=max_length,
+            max_context_length=max_context_length,
+            do_sample=do_sample,
+            top_k=top_k,
+            top_p=top_p,
+            temperature=temperature,
+            repetition_penalty=repetition_penalty,
+            num_threads=num_threads,
+        )
+
+        generate_fn = self._stream_generate if stream else self._sync_generate
+        return generate_fn(input_ids=input_ids, gen_config=gen_config)
+
+    def _stream_generate(
+        self, input_ids: List[int], gen_config: _C.GenerationConfig
+    ) -> Iterator[str]:
+        input_ids = [x for x in input_ids]  # make a copy
+        n_past = 0
+        n_ctx = len(input_ids)
+
+        token_cache = []
+        print_len = 0
+        while len(input_ids) < gen_config.max_length:
+            next_token_id = self.model.generate_next_token(
+                input_ids, gen_config, n_past, n_ctx
+            )
+            n_past = len(input_ids)
+            input_ids.append(next_token_id)
+
+            token_cache.append(next_token_id)
+            output = self.tokenizer.decode(token_cache)
+
+            if output.endswith("\n"):
+                yield output[print_len:]
+                token_cache = []
+                print_len = 0
+            elif output.endswith((",", "!", ":", ";", "?", "�")):
+                pass
+            else:
+                yield output[print_len:]
+                print_len = len(output)
+
+            if next_token_id in (
+                self.model.config.eos_token_id,
+                self.model.config.im_start_id,
+                self.model.config.im_end_id,
+            ):
+                break
+
+        output = self.tokenizer.decode(token_cache)
+        yield output[print_len:]
+
+    def _sync_generate(
+        self, input_ids: List[int], gen_config: _C.GenerationConfig
+    ) -> str:
+        input_ids = [x for x in input_ids]  # make a copy
+        n_past = 0
+        n_ctx = len(input_ids)
+
+        while len(input_ids) < gen_config.max_length:
+            next_token_id = self.model.generate_next_token(
+                input_ids, gen_config, n_past, n_ctx
+            )
+            n_past = len(input_ids)
+            input_ids.append(next_token_id)
+            if next_token_id in (
+                self.model.config.eos_token_id,
+                self.model.config.im_start_id,
+                self.model.config.im_end_id,
+            ):
+                break
+
+        output = self.tokenizer.decode(input_ids[n_ctx:])
+        return output

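For context, the new wrapper is typically driven as sketched below; the model and tiktoken paths are placeholders and the prompts are illustrative, not part of this commit:

```python
# Illustrative usage of the Pipeline wrapper added above; paths are placeholders.
from qwen_cpp import Pipeline

pipeline = Pipeline("/path/to/qwen7b-ggml.bin", "/path/to/qwen.tiktoken")

# Blocking generation: returns the whole reply as a single string.
reply = pipeline.chat(["What is the capital of France?"], temperature=0.8)
print(reply)

# Streaming generation: yields decoded text pieces as they are produced.
for piece in pipeline.chat(["Tell me a short joke."], stream=True):
    print(piece, end="", flush=True)
```
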
qwen_pybind.cpp

Lines changed: 18 additions & 0 deletions
@@ -40,6 +40,24 @@ PYBIND11_MODULE(_C, m) {
       .def("encode", &QwenTokenizer::encode)
       .def("decode", &QwenTokenizer::decode)
       .def("encode_history", &QwenTokenizer::encode_history);
+
+  py::class_<GenerationConfig>(m, "GenerationConfig")
+      .def(py::init<int, int, bool, int, float, float, float, int>(), "max_length"_a = 2048,
+           "max_context_length"_a = 512, "do_sample"_a = true, "top_k"_a = 0, "top_p"_a = 0.7, "temperature"_a = 0.95,
+           "repetition_penalty"_a = 1.0, "num_threads"_a = 0)
+      .def_readwrite("max_length", &GenerationConfig::max_length)
+      .def_readwrite("max_context_length", &GenerationConfig::max_context_length)
+      .def_readwrite("do_sample", &GenerationConfig::do_sample)
+      .def_readwrite("top_k", &GenerationConfig::top_k)
+      .def_readwrite("top_p", &GenerationConfig::top_p)
+      .def_readwrite("temperature", &GenerationConfig::temperature)
+      .def_readwrite("repetition_penalty", &GenerationConfig::repetition_penalty)
+      .def_readwrite("num_threads", &GenerationConfig::num_threads);
+
+  py::class_<Pipeline>(m, "Pipeline")
+      .def(py::init<const std::string &, const std::string &>())
+      .def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); })
+      .def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); });
 }
 
 } // namespace qwen

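The bindings above can also be used directly, without the high-level wrapper. A rough sketch follows, assuming valid model and tiktoken files at the placeholder paths and that `generate_next_token` is bound with the same signature the Python wrapper calls:

```python
# Sketch of driving the low-level _C bindings added in this hunk; paths are placeholders.
import qwen_cpp._C as _C

# Keyword arguments mirror the defaults declared in the GenerationConfig binding.
cfg = _C.GenerationConfig(max_length=1024, top_p=0.9, temperature=0.8)

pipeline = _C.Pipeline("/path/to/qwen7b-ggml.bin", "/path/to/qwen.tiktoken")

# model and tokenizer are the read-only properties exposed above.
input_ids = pipeline.tokenizer.encode_history(["Hello"], cfg.max_context_length)
next_token = pipeline.model.generate_next_token(input_ids, cfg, 0, len(input_ids))
print(pipeline.tokenizer.decode([next_token]))
```
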
setup.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def build_extension(self, ext: CMakeExtension) -> None:
 HERE = Path(__file__).resolve().parent
 
 setup(
-    version="0.1",
+    version="0.1.2",
     author="Shijie Wang",
     packages=find_packages(),
     ext_modules=[CMakeExtension("qwen_cpp._C")],

tiktoken.h

Lines changed: 2 additions & 3 deletions
@@ -164,10 +164,9 @@ class tiktoken {
   }
 
  private:
-  template <typename T>
   auto split_with_allowed_special_token(
     re2::StringPiece &input,
-    const T &allowed_special
+    const ankerl::unordered_dense::map<std::string, int> &allowed_special
   ) const -> std::pair<std::optional<std::string>, re2::StringPiece> {
     if (special_regex_ == nullptr) return { std::nullopt, input };
 
@@ -206,7 +205,7 @@ class tiktoken {
   auto _encode_native(
     const std::string &text,
     const ankerl::unordered_dense::map<std::string, int> &allowed_special
-  ) const -> const std::pair<std::vector<int>, int> {
+  ) const -> std::pair<std::vector<int>, int> {
     std::vector<int> ret;
     int last_piece_token_len = 0;
     re2::StringPiece input(text);
