4 changes: 4 additions & 0 deletions include/infinicore_infer/models/jiuge.h
@@ -35,6 +35,10 @@ typedef struct
const void *const *attn_qkv;
// nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh]
const void *const *attn_qkv_b;
// nlayer * [dh]
const void *const *attn_q_norm;
// nlayer * [dh]
const void *const *attn_k_norm;
// nlayer * [ndev, d, nkvh / ndev * dh]
const void *const *attn_o;
// nlayer * [d]
65 changes: 57 additions & 8 deletions scripts/jiuge.py
@@ -15,7 +15,7 @@
forward_batch,
)
from infer_task import InferTask, KVCache

from tokenizers import decoders as _dec
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
@@ -64,6 +64,12 @@ def attn_k_b(self, i):
def attn_v_b(self, i):
return f"model.layers.{i}.self_attn.v_proj.bias"

def attn_q_norm(self, i):
return f"model.layers.{i}.self_attn.q_norm.weight"

def attn_k_norm(self, i):
return f"model.layers.{i}.self_attn.k_norm.weight"

def ffn_norm(self, i):
return f"model.layers.{i}.post_attention_layernorm.weight"

@@ -123,7 +129,11 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
if "num_key_value_heads" in config
else config["num_attention_heads"]
),
dh=config["hidden_size"] // config["num_attention_heads"],
dh=(
config["head_dim"]
if "head_dim" in config
else config["hidden_size"] // config["num_attention_heads"]
),
di=config["intermediate_size"],
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
@@ -281,6 +291,35 @@ def qkv_b_slices(_i):
else:
self.attn_qkv_b = None

if naming.attn_q_norm(0) in state_dict:
self.attn_q_norm_tensors = [
state_dict[naming.attn_q_norm(i)]
.reshape([2, dh // 2])
.transpose(0, 1)
.contiguous()
.to(torch_dt_norm)
for i in range(nlayer)
]
self.attn_q_norm_ptrs = [
self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
]
self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
self.attn_k_norm_tensors = [
state_dict[naming.attn_k_norm(i)]
.reshape([2, dh // 2])
.transpose(0, 1)
.contiguous()
.to(torch_dt_norm)
for i in range(nlayer)
]
self.attn_k_norm_ptrs = [
self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
]
self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
else:
self.attn_q_norm = None
self.attn_k_norm = None

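Note: the `reshape([2, dh // 2]).transpose(0, 1)` above interleaves the two halves of each `[dh]` norm vector, presumably to match the per-head element order the engine expects. A small sketch of the resulting order (dh = 8 is illustrative):

```python
import torch

dh = 8
w = torch.arange(dh)  # [0, 1, 2, 3, 4, 5, 6, 7]
w_interleaved = w.reshape([2, dh // 2]).transpose(0, 1).contiguous().flatten()
print(w_interleaved.tolist())  # [0, 4, 1, 5, 2, 6, 3, 7]
```
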
self.attn_o_tensor = [
(
state_dict[naming.attn_o(i)]
@@ -427,6 +466,20 @@ def load_all_safetensors_from_dir(dir_path_: str):
)
self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
backend = getattr(self.tokenizer, "backend_tokenizer", None)
Collaborator: It's not just llama that has this problem; 9g7b has it too. I suggest making the fix independent of model type: whenever the tokenizer's sequence normalizer is found to contain Prepend and Strip, apply the patch.

Author: OK.

target = getattr(backend, "_tokenizer", backend)
norm = getattr(target, "normalizer", None)
dec = getattr(target, "decoder", None)
sn = repr(norm)[:800] if norm is not None else ""
sd = repr(dec)[:800] if dec is not None else ""
has_prepend = "Prepend" in sn
has_strip = "Strip" in sd
if has_prepend and has_strip:
target.decoder = _dec.Sequence([
_dec.Replace("▁", " "),
_dec.ByteFallback(),
_dec.Fuse(),
])
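
Note: the stock SentencePiece-style decoder ends with a Strip step (the counterpart of the normalizer's Prepend("▁")) that drops one leading space. That is correct for whole-text decoding but eats the word-separating space when tokens are decoded one at a time for streaming. A hedged sketch of the replacement chain, assuming a recent tokenizers version with a fast backend:

```python
from tokenizers import decoders as _dec

# Same chain as above, minus the trailing Strip, so per-token decoding keeps
# the leading space carried by "▁".
patched = _dec.Sequence([
    _dec.Replace("▁", " "),
    _dec.ByteFallback(),
    _dec.Fuse(),
])
print(patched.decode(["▁hello"]))  # " hello" rather than "hello"
```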
self.weights = JiugeWeightsImpl(
self.meta,
LlamaWeightsNaming(),
@@ -484,7 +537,7 @@ def load_all_safetensors_from_dir(dir_path_: str):
)
else:
raise ValueError("Unsupported weight naming")
elif "qwen2" == config["model_type"]:
elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
@@ -564,11 +617,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
12 changes: 2 additions & 10 deletions scripts/launch_server.py
@@ -207,11 +207,7 @@ async def chat_stream(id_, request_data, request: Request):
break

token = await infer_task.output_queue.async_q.get()
content = (
request.app.state.model.tokenizer._tokenizer.id_to_token(token)
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
content = request.app.state.model.tokenizer.decode(token)
chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
yield f"data: {chunk}\n\n"

@@ -236,11 +232,7 @@ async def chat(id_, request_data, request: Request):
break

token = await infer_task.output_queue.async_q.get()
content = (
request.app.state.model.tokenizer._tokenizer.id_to_token(token)
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
content = request.app.state.model.tokenizer.decode(token)
output.append(content)

output_text = "".join(output).strip()
2 changes: 2 additions & 0 deletions scripts/libinfinicore_infer.py
@@ -66,6 +66,8 @@ class JiugeWeightsCStruct(ctypes.Structure):
("attn_norm", POINTER(c_void_p)),
("attn_qkv", POINTER(c_void_p)),
("attn_qkv_b", POINTER(c_void_p)),
("attn_q_norm", POINTER(c_void_p)),
("attn_k_norm", POINTER(c_void_p)),
("attn_o", POINTER(c_void_p)),
("ffn_norm", POINTER(c_void_p)),
("ffn_gate_up", POINTER(c_void_p)),
29 changes: 22 additions & 7 deletions src/models/jiuge/jiuge.cpp
@@ -21,7 +21,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
infinirtStream_t stream;
infinirtStreamCreate(&stream);

std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
for (size_t layer = 0; layer < meta->nlayer; layer++) {
w_attn_norm.push_back(
@@ -32,6 +32,12 @@
b_attn_qkv.push_back(
getAttnQKVBias(meta, weights, layer, idev, ndev));
}
if (weights->attn_q_norm != nullptr) {
w_attn_q_norm.push_back(
getAttnQNorm(meta, weights, layer));
w_attn_k_norm.push_back(
getAttnKNorm(meta, weights, layer));
}
w_attn_out.push_back(
getAttnO(meta, weights, layer, idev, ndev));
w_ffn_norm.push_back(
@@ -56,6 +62,8 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
w_attn_norm,
w_attn_qkv,
b_attn_qkv,
w_attn_q_norm,
w_attn_k_norm,
w_attn_out,
w_ffn_norm,
w_ffn_gate_up,
@@ -130,6 +138,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto dvoc = meta.dvoc;
auto stream = rsrc.stream;
bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0;

// Allocate buffers
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
@@ -141,7 +150,9 @@
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq);

auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh});
auto qkv_buf_view = qkv_buf->view({ntok, nh + nkvh * 2, dh});
auto q_buf = qkv_buf_view->slice(1, 0, nh);
auto k_buf = qkv_buf_view->slice(1, nh, nkvh);

// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
@@ -198,19 +209,23 @@
rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon);
// qkv_proj
linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr);
if (has_qk_norm) {
rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon);
rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon);
}
// rope
rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);
rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table);

size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
auto q = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3});
auto k = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
auto v = qkv_buf_view->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});

// self attention
// concat
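
Note: for reference, a conceptual PyTorch sketch of the operation order introduced in inferDeviceBatch: per-head RMSNorm on q and k over the dh dimension, applied before RoPE. Sizes are illustrative, and the engine's rope kernel is not reproduced here:

```python
import torch

def rmsnorm(x, w, eps=1e-6):
    # normalize over the last (dh) dimension, then scale by the [dh] weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

ntok, nh, nkvh, dh = 4, 8, 2, 16          # illustrative sizes
q = torch.randn(ntok, nh, dh)
k = torch.randn(ntok, nkvh, dh)
w_q_norm = torch.ones(dh)
w_k_norm = torch.ones(dh)

# Order used above: per-head RMSNorm first, then RoPE on the normalized q/k.
q = rmsnorm(q, w_q_norm)
k = rmsnorm(k, w_k_norm)
# rope(q, ...) and rope(k, ...) would follow here; omitted because the
# engine's rope kernel is not part of this sketch.
```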
2 changes: 1 addition & 1 deletion src/models/jiuge/jiuge_impl.hpp
@@ -20,7 +20,7 @@ struct DeviceResource {
// Weights
std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
cos_table;
std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out,
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
// Streams
infinirtStream_t stream;
16 changes: 16 additions & 0 deletions src/models/jiuge/jiuge_weight.hpp
@@ -70,6 +70,22 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
}

inline std::shared_ptr<Tensor> getAttnQNorm(
JiugeMeta const *meta,
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->dh});
return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape);
}

inline std::shared_ptr<Tensor> getAttnKNorm(
JiugeMeta const *meta,
JiugeWeights const *w,
size_t layer) {
auto shape = std::vector<size_t>({meta->dh});
return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape);
}

inline std::shared_ptr<Tensor> getAttnO(JiugeMeta const *meta,
JiugeWeights const *w, size_t layer,
size_t idev, size_t ndev) {