diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h
index 8b09e1ac..b7bd3363 100644
--- a/include/infinicore_infer/models/jiuge.h
+++ b/include/infinicore_infer/models/jiuge.h
@@ -12,9 +12,14 @@ struct JiugeModel;
 typedef struct
 {
     infiniDtype_t dt_logits;
-    size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size;
+    size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size, dim_model_base;
     float epsilon, theta;
     uint32_t end_token;
+    // Longrope support
+    uint32_t rope_type; // 0 = standard, 1 = longrope
+    size_t original_max_position_embeddings;
+    const float *short_factor; // Array of dh/2 floats, nullptr if not longrope
+    const float *long_factor;  // Array of dh/2 floats, nullptr if not longrope
 } JiugeMeta;

 typedef struct
@@ -101,8 +106,9 @@ __C __export struct KVCache *createPagedKVCache(
 /// @param temperature 采样温度(0. 表示贪心采样)
 /// @param topk 采样 topk(1 表示贪心采样)
 /// @param topp 采样 topp
-/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill
-/// @param enable_paged_attn 是否启用 paged attention
+/// @param repetition_penalty 重复惩罚系数(1.0 表示无惩罚)
+/// @param previous_tokens_per_req 每个请求的唯一 token ID 数组指针(vLLM-style,用于高效重复惩罚)
+/// @param previous_tokens_len_per_req 每个请求的唯一 token 数量
 /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
 __C __export void
 inferBatchJiuge(struct JiugeModel *,
@@ -110,6 +116,9 @@ inferBatchJiuge(struct JiugeModel *,
                 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                 struct KVCache **kv_caches,
                 const float *temperature, const uint32_t *topk, const float *topp,
+                const float *repetition_penalty,
+                const uint32_t *const *previous_tokens_per_req,
+                const uint32_t *previous_tokens_len_per_req,
                 uint32_t *output);

 __C __export void
@@ -120,6 +129,9 @@ inferBatch(struct JiugeModel *,
            const int32_t *block_tables, const int32_t *slot_mapping,
            const float *temperature, const uint32_t *topk, const float *topp,
+           const float *repetition_penalty,
+           const uint32_t *const *previous_tokens_per_req,
+           const uint32_t *previous_tokens_len_per_req,
            const uint32_t is_prefill,
            const bool enable_paged_attn,
            uint32_t *output);

diff --git a/python/icinfer/engine/libinfinicore_infer.py b/python/icinfer/engine/libinfinicore_infer.py
index 75f6e025..2e44e198 100644
--- a/python/icinfer/engine/libinfinicore_infer.py
+++ b/python/icinfer/engine/libinfinicore_infer.py
@@ -48,9 +48,14 @@ class JiugeMetaCStruct(ctypes.Structure):
         ("dctx", c_size_t),
         ("dvoc", c_size_t),
         ("kvcache_block_size", c_size_t),
+        ("dim_model_base", c_size_t),
         ("epsilon", c_float),
         ("theta", c_float),
         ("end_token", c_uint),
+        ("rope_type", c_uint),
+        ("original_max_position_embeddings", c_size_t),
+        ("short_factor", POINTER(c_float)),
+        ("long_factor", POINTER(c_float)),
     ]


diff --git a/python/icinfer/models/jiuge.py b/python/icinfer/models/jiuge.py
index 34c7f60f..e132c21c 100644
--- a/python/icinfer/models/jiuge.py
+++ b/python/icinfer/models/jiuge.py
@@ -3,6 +3,7 @@
 from sympy import true
 from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
+import ctypes
 import os
 from pathlib import Path
 import safetensors
@@ -122,6 +123,45 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None):
             config["num_hidden_layers"]
         )

+        dim_model_base = (
+            config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"]
+        )
+
+        # Load longrope configuration
+        rope_type = 0  # 0 = standard, 1 = longrope
+        original_max_position_embeddings = 0
+        short_factor_ptr = None
+        long_factor_ptr = None
+        self._short_factor_array = None  # 
Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + rope_scaling = config.get("rope_scaling", {}) + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + half_dh = (config["hidden_size"] // config["num_attention_heads"]) // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + logger.warning( + f"Longrope factor arrays have wrong length: " + f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -138,10 +178,15 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], - block_size=config["block_size"], + kvcache_block_size=config["block_size"], + dim_model_base=dim_model_base, epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype @@ -206,7 +251,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) diff --git a/python/icinfer/utils/jiuge_weights_loader.py b/python/icinfer/utils/jiuge_weights_loader.py index 7cff7bf2..0de730fb 100644 --- a/python/icinfer/utils/jiuge_weights_loader.py +++ b/python/icinfer/utils/jiuge_weights_loader.py @@ -7,6 +7,7 @@ from typing import Tuple import math from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import ctypes import os from pathlib import Path import safetensors @@ -119,6 +120,54 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): self.scale_o = config.scale_depth / math.sqrt(config.num_hidden_layers) self.scale_down = config.scale_depth / math.sqrt(config.num_hidden_layers) + dim_model_base = ( + config.dim_model_base if hasattr(config, "dim_model_base") else config.hidden_size + ) + + # Load longrope configuration + rope_type = 0 # 0 = standard, 1 = longrope + original_max_position_embeddings = 0 + short_factor_ptr = None + long_factor_ptr = None + self._short_factor_array = None # Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + # Handle both dict and object config + if hasattr(config, "rope_scaling"): + rope_scaling = config.rope_scaling + elif isinstance(config, dict) and "rope_scaling" in 
config: + rope_scaling = config["rope_scaling"] + else: + rope_scaling = {} + + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + getattr(config, "original_max_position_embeddings", 0) if not isinstance(config, dict) else config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + half_dh = (config.hidden_size // config.num_attention_heads) // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + import logging + logger = logging.getLogger(__name__) + logger.warning( + f"Longrope factor arrays have wrong length: " + f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( dt_logits=dt_, nlayer=config.num_hidden_layers, @@ -134,9 +183,14 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): dctx=(config.max_position_embeddings if max_tokens is None else max_tokens), dvoc=config.vocab_size, kvcache_block_size=config.kvcache_block_size, + dim_model_base=dim_model_base, epsilon=config.rms_norm_eps, theta=(config.rope_theta if hasattr(config, "rope_theta") else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype @@ -201,7 +255,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..eb0137d5 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -1,5 +1,5 @@ class InferTask: - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0): self.id = id self.finish_reason = None self.tokens = tokens @@ -7,14 +7,28 @@ def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): self.temperature = temperature self.topk = topk self.topp = topp + self.repetition_penalty = repetition_penalty self.end_tokens = end_tokens self._kv_cache = None self.pos = 0 + # vLLM-style unique token tracking for efficient repetition penalty + # Track unique token IDs that have been generated (not the full sequence) + # Initialize with prompt tokens so they are also penalized + self._unique_generated_tokens = set(tokens) # Initialize with prompt tokens! 
+ self._unique_tokens_array = sorted(self._unique_generated_tokens) # Pre-sort for efficiency + self._unique_tokens_dirty = False # Already initialized, no need to rebuild + def bind_kvcache(self, kv_cache, pos=0): self._kv_cache = kv_cache self.pos = pos - self.tokens = self.tokens[pos:] + # Update tokens and add any new tokens to unique set + remaining_tokens = self.tokens[pos:] + for token in remaining_tokens: + if token not in self._unique_generated_tokens: + self._unique_generated_tokens.add(token) + self._unique_tokens_dirty = True + self.tokens = remaining_tokens def release_kvcache(self): cache = self._kv_cache @@ -34,6 +48,24 @@ def next(self, out_token): self.finish_reason = "length" else: self.tokens = [out_token] + # Incrementally update unique token set (vLLM-style) + # Only add if it's a new token (O(1) average) + if out_token not in self._unique_generated_tokens: + self._unique_generated_tokens.add(out_token) + self._unique_tokens_dirty = True + + def get_unique_previous_tokens(self): + """ + Returns a sorted list of unique token IDs that have been generated. + This is the vLLM-style "seen tokens" list for efficient repetition penalty. + + Returns: + tuple: (array, length) where array is sorted list of unique token IDs + """ + if self._unique_tokens_dirty: + self._unique_tokens_array = sorted(self._unique_generated_tokens) + self._unique_tokens_dirty = False + return self._unique_tokens_array, len(self._unique_tokens_array) class KVCache: diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 676d7de7..2dcda072 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -20,6 +20,7 @@ from infer_task import InferTask, KVCache from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import ctypes torch.set_default_device("cpu") @@ -113,6 +114,46 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["num_hidden_layers"] ) + dim_model_base = ( + config["dim_model_base"] if "dim_model_base" in config else config["hidden_size"] + ) + + # Load longrope configuration + rope_type = 0 # 0 = standard, 1 = longrope + original_max_position_embeddings = 0 + short_factor_ptr = None + long_factor_ptr = None + self._short_factor_array = None # Keep reference to prevent GC + self._long_factor_array = None # Keep reference to prevent GC + + rope_scaling = config.get("rope_scaling", {}) + if isinstance(rope_scaling, dict): + rope_scaling_type = rope_scaling.get("rope_type") or rope_scaling.get("type", "") + if rope_scaling_type == "longrope": + rope_type = 1 + original_max_position_embeddings = rope_scaling.get( + "original_max_position_embeddings", + config.get("original_max_position_embeddings", 0) + ) + + short_factor_list = rope_scaling.get("short_factor", []) + long_factor_list = rope_scaling.get("long_factor", []) + + if short_factor_list and long_factor_list: + # Convert to ctypes arrays + dh = config["head_dim"] if "head_dim" in config else config["hidden_size"] // config["num_attention_heads"] + half_dh = dh // 2 + if len(short_factor_list) == half_dh and len(long_factor_list) == half_dh: + self._short_factor_array = (c_float * half_dh)(*short_factor_list) + self._long_factor_array = (c_float * half_dh)(*long_factor_list) + short_factor_ptr = ctypes.cast(self._short_factor_array, POINTER(c_float)) + long_factor_ptr = ctypes.cast(self._long_factor_array, POINTER(c_float)) + else: + print( + f"Warning: Longrope factor arrays have wrong length: " + f"short={len(short_factor_list)}, long={len(long_factor_list)}, expected={half_dh}" + ) + super().__init__( 
dt_logits=dt_, nlayer=config["num_hidden_layers"], @@ -129,10 +170,15 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], - kvcache_block_size=0, + kvcache_block_size=config.get("block_size", 0), + dim_model_base=dim_model_base, epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, + rope_type=rope_type, + original_max_position_embeddings=original_max_position_embeddings, + short_factor=short_factor_ptr, + long_factor=long_factor_ptr, ) self.torch_dtype_logits = dtype @@ -197,7 +243,7 @@ def __init__( ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( - state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + state_dict[naming.output_norm()].to(torch_dt_norm) ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) @@ -395,11 +441,55 @@ def __init__(self, tasks: List[InferTask]): self.temperaturas_list = [t.temperature for t in tasks] self.topks_list = [t.topk for t in tasks] self.topps_list = [t.topp for t in tasks] + self.repetition_penalties_list = [t.repetition_penalty for t in tasks] # Flatten token lists flat_tokens = [tok for toks in token_lists for tok in toks] self.ntok = len(flat_tokens) + # Collect unique tokens per request (vLLM-style for efficient repetition penalty) + # Each request has its own list of unique token IDs + self.unique_tokens_arrays = [] # List of arrays, one per request + self.unique_tokens_lens = [] # List of lengths, one per request + self.unique_tokens_flat = [] # Flattened array for C API + self.unique_tokens_offsets = [0] # Offsets into flat array + + total_unique_tokens = 0 + for task in tasks: + tokens_array, tokens_len = task.get_unique_previous_tokens() + self.unique_tokens_arrays.append(tokens_array) + self.unique_tokens_lens.append(tokens_len) + self.unique_tokens_flat.extend(tokens_array) + total_unique_tokens += tokens_len + self.unique_tokens_offsets.append(total_unique_tokens) + + # Convert to C-compatible arrays + if total_unique_tokens > 0: + self.unique_tokens_c = (c_uint * total_unique_tokens)(*self.unique_tokens_flat) + # Create array of pointers, one per request + self.unique_tokens_ptrs = [] + for req_idx in range(self.nreq): + offset = self.unique_tokens_offsets[req_idx] + length = self.unique_tokens_lens[req_idx] + if length > 0: + # Create pointer to the start of this request's tokens in the flat array + ptr = ctypes.cast( + ctypes.addressof(self.unique_tokens_c) + offset * ctypes.sizeof(c_uint), + POINTER(c_uint) + ) + else: + ptr = None + self.unique_tokens_ptrs.append(ptr) + # Create array of pointers (use None for empty requests) + self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*self.unique_tokens_ptrs) + else: + self.unique_tokens_c = None + # All requests have no previous tokens + self.unique_tokens_ptrs_array = (POINTER(c_uint) * self.nreq)(*[None] * self.nreq) + + # Array of lengths per request + self.unique_tokens_lens_array = (c_uint * self.nreq)(*self.unique_tokens_lens) + # Convert to ctypes arrays in one pass self.tokens = (c_uint * self.ntok)(*flat_tokens) self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) @@ -408,6 +498,7 @@ def __init__(self, tasks: List[InferTask]): self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) self.topks = (c_uint * self.nreq)(*self.topks_list) self.topps = (c_float * 
self.nreq)(*self.topps_list) + self.repetition_penalties = (c_float * self.nreq)(*self.repetition_penalties_list) def input_args(self): return ( @@ -420,6 +511,9 @@ def input_args(self): self.temperaturas, self.topks, self.topps, + self.repetition_penalties, + self.unique_tokens_ptrs_array, # Array of pointers to unique tokens per request + self.unique_tokens_lens_array, # Array of lengths per request ) @@ -534,7 +628,7 @@ def load_all_safetensors_from_dir(dir_path_: str): else: raise ValueError("Unsupported model architecture") - + if "llama" == config["model_type"]: from tokenizers import decoders as _dec backend = getattr(self.tokenizer, "backend_tokenizer", None) @@ -593,9 +687,21 @@ def drop_kv_cache(self, kv_cache): def batch_infer_one_round(self, tasks: List[InferTask]): output = (c_uint * len(tasks))() batch_inputs = JiugeBatchedTask(tasks) + args = batch_inputs.input_args() self.jiuge_model.infer_batch( self.model_instance, - *(batch_inputs.input_args()), + args[0], # tokens + args[1], # ntok + args[2], # req_lens + args[3], # nreq + args[4], # req_pos + args[5], # kv_caches + args[6], # temperature + args[7], # topk + args[8], # topp + args[9], # repetition_penalty + args[10], # previous_tokens_per_req + args[11], # previous_tokens_len_per_req output, ) return list(output) @@ -616,6 +722,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. topk_, topp_, self.eos_token_id, + 1.0, # repetition_penalty default ) infer_task.bind_kvcache(KVCache(self)) @@ -648,7 +755,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): tasks = [ - InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id, 1.0) for i in range(batch_size) ] kv_caches = [KVCache(self) for _ in range(batch_size)] diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 115fbd0a..43098b60 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -7,7 +7,7 @@ import argparse import queue -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, HTTPException from fastapi.responses import StreamingResponse, JSONResponse import contextlib import uvicorn @@ -16,6 +16,10 @@ import json import threading import janus +import os +import signal +import asyncio +from pathlib import Path DEVICE_TYPE_MAP = { @@ -69,6 +73,36 @@ def parse_args(): action="store_true", help="Whether to use AWQ quantized model (default: False)", ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to run the server on (default: 8000)", + ) + parser.add_argument( + "--host", + type=str, + default="0.0.0.0", + help="Host to bind the server to (default: 0.0.0.0)", + ) + parser.add_argument( + "--request-timeout", + type=int, + default=30, + help="Request timeout in seconds. Process will exit if a request hangs longer than this (default: 30)", + ) + parser.add_argument( + "--model-name", + type=str, + default=None, + help="Model name to return in /models endpoint. If not specified, will use the directory name from --model-path (like vLLM/llama.cpp)", + ) + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. Requests exceeding this limit will wait in queue. 
Default: unlimited", + ) return parser.parse_args() @@ -79,12 +113,47 @@ def parse_args(): max_tokens = args.max_tokens USE_AWQ = args.awq MAX_BATCH = args.max_batch +SERVER_PORT = args.port +SERVER_HOST = args.host +REQUEST_TIMEOUT = args.request_timeout +MAX_CONCURRENCY = args.max_concurrency + +# Derive model name from model path directory name (like vLLM and llama.cpp) +# Use --model-name if explicitly provided, otherwise use directory name +if args.model_name: + MODEL_NAME = args.model_name +elif model_path: + # Extract directory name from model path + # This follows the same convention as vLLM and llama.cpp + model_path_obj = Path(model_path).resolve() # Resolve to absolute path + if model_path_obj.is_dir(): + MODEL_NAME = model_path_obj.name + elif model_path_obj.is_file(): + # If it's a file, use the parent directory name + MODEL_NAME = model_path_obj.parent.name + else: + # Path doesn't exist yet, but extract name from path string + # Use the last component of the path + MODEL_NAME = model_path_obj.name or model_path_obj.parent.name +else: + MODEL_NAME = None print( f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." ) +print( + f"Request timeout: {REQUEST_TIMEOUT}s. Process will exit if a request hangs longer than this." +) +if MAX_CONCURRENCY is not None and MAX_CONCURRENCY > 0: + print(f"Max concurrency: {MAX_CONCURRENCY}. Requests exceeding this limit will wait in queue.") +else: + print("Max concurrency: unlimited") -def chunk_json(id_, content=None, role=None, finish_reason=None): +def chunk_json(id_, content=None, role=None, finish_reason=None, model="jiuge"): + """ + Generate SSE chunk format for streaming responses. + Used for Server-Sent Events (SSE) streaming mode. + """ delta = {} if content: delta["content"] = content @@ -94,12 +163,11 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): "id": id_, "object": "chat.completion.chunk", "created": int(time.time()), - "model": "jiuge", + "model": model, "system_fingerprint": None, "choices": [ { "index": 0, - "text": content, "delta": delta, "logprobs": None, "finish_reason": finish_reason, @@ -108,17 +176,65 @@ def chunk_json(id_, content=None, role=None, finish_reason=None): } +def chat_completion_json(id_, content, role="assistant", finish_reason=None, prompt_tokens=0, completion_tokens=0, model="jiuge"): + """ + Generate OpenAI-compatible non-streaming chat completion response. + Used for non-streaming (stream=False) mode. 
+ """ + return { + "id": id_, + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "message": { + "role": role, + "content": content + }, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + + # A wrapper for InferTask that supports async output queue class AsyncInferTask(InferTask): - def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): - super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty=1.0, test_hang_seconds=0): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens, repetition_penalty) self.output_queue = janus.Queue() - print(f"[INFO] Create InferTask {self.id}") + self.last_activity_time = time.time() # Track when task was last active + self.test_hang_seconds = test_hang_seconds # Test parameter: sleep for this many seconds to simulate hang (set to 0 after first use) + self.timed_out = False # Flag to mark if task has timed out + self.initial_prompt_tokens = len(tokens) # Track initial prompt token count for usage statistics + self.generated_tokens = [] # Track generated token IDs for counting completion tokens + print(f"[INFO] Create InferTask {self.id}" + (f" (TEST: will hang for {test_hang_seconds}s once)" if test_hang_seconds > 0 else "")) def output(self, out_token): self.next(out_token) + self.last_activity_time = time.time() # Update activity time when output is generated + if out_token is not None: # Track non-None tokens for completion count + self.generated_tokens.append(out_token) self.output_queue.sync_q.put(out_token) + def signal_timeout(self): + """Signal that this task has timed out""" + self.timed_out = True + self.finish_reason = "timeout" + + def signal_internal_error(self): + """Signal that an internal error occurred (process will be killed)""" + self.timed_out = True # Reuse timed_out flag to trigger error response + self.finish_reason = "internal_error" + @contextlib.asynccontextmanager async def lifespan(app: FastAPI): @@ -133,9 +249,24 @@ async def lifespan(app: FastAPI): ) app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) app.state.request_queue = janus.Queue() + app.state.active_tasks = {} # Track active tasks: task_id -> task object + app.state.task_lock = threading.Lock() # Lock for accessing active_tasks + + # Initialize concurrency control semaphore + if MAX_CONCURRENCY is not None and MAX_CONCURRENCY > 0: + app.state.concurrency_semaphore = asyncio.Semaphore(MAX_CONCURRENCY) + print(f"Max concurrency: {MAX_CONCURRENCY}. 
Requests exceeding this limit will wait in queue.") + else: + app.state.concurrency_semaphore = None + print("Max concurrency: unlimited") + worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) worker_thread.start() + # Start timeout checker thread + timeout_checker_thread = threading.Thread(target=timeout_checker_loop, args=(app,), daemon=True) + timeout_checker_thread.start() + try: yield # The app runs here finally: @@ -151,6 +282,61 @@ async def lifespan(app: FastAPI): App = FastAPI(lifespan=lifespan) +# Timeout checker: monitors active tasks and kills process if any task hangs +def timeout_checker_loop(app): + """Monitor active tasks and kill the process if any task hangs beyond timeout""" + while True: + try: + time.sleep(5) # Check every 5 seconds + + current_time = time.time() + hung_tasks = [] + + with app.state.task_lock: + # Check all active tasks for timeout + for task_id, task in list(app.state.active_tasks.items()): + time_since_activity = current_time - task.last_activity_time + if time_since_activity > REQUEST_TIMEOUT: + hung_tasks.append((task_id, time_since_activity)) + + # If we found hung tasks, signal all active tasks and then kill the process + if hung_tasks: + print(f"[ERROR] Detected {len(hung_tasks)} hung task(s) exceeding timeout of {REQUEST_TIMEOUT}s:") + for task_id, hang_time in hung_tasks: + print(f" - Task {task_id}: hung for {hang_time:.1f}s") + + # Signal all active tasks (not just hung ones) to send error responses to clients + # This ensures all processing requests get error responses before process is killed + with app.state.task_lock: + all_active_tasks = list(app.state.active_tasks.items()) + print(f"[ERROR] Signaling {len(all_active_tasks)} active task(s) to send error responses...") + for task_id, task in all_active_tasks: + if task_id in [tid for tid, _ in hung_tasks]: + # Hung tasks get timeout error + task.signal_timeout() + print(f"[ERROR] Signaled timeout to hung task {task_id}") + else: + # Other active tasks get internal error (process will be killed) + task.signal_internal_error() + print(f"[ERROR] Signaled internal error to active task {task_id}") + + # Give a short time for error responses to be sent to clients + print(f"[ERROR] Waiting 2 seconds for error responses to be sent to clients...") + time.sleep(2) + + print(f"[ERROR] Killing process to trigger recovery mechanism...") + # Kill the process - this will be detected by the babysitter and trigger restart + os.kill(os.getpid(), signal.SIGTERM) + # If SIGTERM doesn't work, use SIGKILL as fallback after a delay + time.sleep(2) + os.kill(os.getpid(), signal.SIGKILL) + break + + except Exception as e: + print(f"[ERROR] Exception in timeout checker: {e}") + time.sleep(5) + + # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
def worker_loop(app): while True: @@ -162,22 +348,65 @@ def worker_loop(app): if task is None: return + # Register task as active + with app.state.task_lock: + app.state.active_tasks[task.id] = task + task.last_activity_time = time.time() + batch = [task] while len(batch) < MAX_BATCH: try: req = app.state.request_queue.sync_q.get_nowait() if req is not None: batch.append(req) + # Register additional tasks as active + with app.state.task_lock: + app.state.active_tasks[req.id] = req + req.last_activity_time = time.time() except queue.Empty: break + + # Update activity time before inference + batch_start_time = time.time() + with app.state.task_lock: + for t in batch: + t.last_activity_time = batch_start_time + + # Test hang simulation: if any task has test_hang_seconds > 0, sleep to simulate hang + # Only apply once per task by setting test_hang_seconds to 0 after use + tasks_needing_hang = [t for t in batch if t.test_hang_seconds > 0] + if tasks_needing_hang: + max_hang_time = max(t.test_hang_seconds for t in tasks_needing_hang) + print(f"[TEST] Simulating hang for {max_hang_time}s (task will exceed timeout if timeout < {max_hang_time}s)") + time.sleep(max_hang_time) + print(f"[TEST] Hang simulation complete, continuing with inference...") + # Reset test_hang_seconds to 0 for all tasks that used it (so it won't hang again) + for t in tasks_needing_hang: + t.test_hang_seconds = 0 + output_tokens = app.state.model.batch_infer_one_round(batch) + + # Update activity time after inference (critical: if batch_infer_one_round hangs, + # this won't execute, and timeout checker will detect it) + batch_end_time = time.time() + with app.state.task_lock: + for task, token in zip(batch, output_tokens): + task.last_activity_time = batch_end_time + task.output(token) + if task.finish_reason is None: + # Task continues, keep it tracked but update activity time + # It will be put back in queue and processed again + pass + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) + # Remove task from active tracking when finished + app.state.active_tasks.pop(task.id, None) + + # Put unfinished tasks back in queue (outside lock to avoid deadlock) for task, token in zip(batch, output_tokens): - task.output(token) if task.finish_reason is None: app.state.request_queue.sync_q.put(task) - else: - print(f"[INFO] Task {task.id} finished infer.") - app.state.kv_cache_pool.release_sync(task) def build_task(id_, request_data, request: Request): @@ -185,11 +414,16 @@ def build_task(id_, request_data, request: Request): if "messages" in request_data: # Chat format messages = request_data.get("messages", []) - input_content = request.app.state.model.tokenizer.apply_chat_template( - conversation=messages, - add_generation_prompt=True, - tokenize=False, - ) + # Get chat_template_kwargs from request, default to empty dict + chat_template_kwargs = request_data.get("chat_template_kwargs", {}) + # Merge with default parameters, allowing chat_template_kwargs to override + template_params = { + "conversation": messages, + "add_generation_prompt": True, + "tokenize": False, + **chat_template_kwargs # Allow override of defaults + } + input_content = request.app.state.model.tokenizer.apply_chat_template(**template_params) tokens = request.app.state.model.tokenizer.encode(input_content) max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len()) else: @@ -197,31 +431,79 @@ def build_task(id_, request_data, request: Request): prompt = request_data.get("prompt", "") 
tokens = request.app.state.model.tokenizer.encode(prompt) max_tokens = request_data.get("max_tokens", 0) - + + # Test parameter: test_hang_seconds - sleep for this many seconds to simulate hang + # This is useful for testing the timeout checker mechanism. + # Example: Set "test_hang_seconds": 350 in request to test timeout (if timeout is 300s) + # The sleep happens in the worker loop before batch_infer_one_round, simulating a hang + test_hang_seconds = request_data.get("test_hang_seconds", 0) + return AsyncInferTask( id_, tokens, max_tokens, request_data.get("temperature", 1.0), - request_data.get("top_k", 1), + request_data.get("top_k", 0), # Default to 0 (disabled) to consider all tokens, matching vLLM behavior request_data.get("top_p", 1.0), request.app.state.model.eos_token_id, + request_data.get("repetition_penalty", 1.0), + test_hang_seconds=test_hang_seconds, ) async def chat_stream(id_, request_data, request: Request): + # Acquire concurrency semaphore if configured (waits if max concurrency reached) + semaphore = request.app.state.concurrency_semaphore + if semaphore: + await semaphore.acquire() + try: + async for item in _chat_stream_impl(id_, request_data, request): + yield item + finally: + semaphore.release() + else: + async for item in _chat_stream_impl(id_, request_data, request): + yield item + + +async def _chat_stream_impl(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) + # Track task from creation + with request.app.state.task_lock: + request.app.state.active_tasks[infer_task.id] = infer_task + infer_task.last_activity_time = time.time() + await request.app.state.kv_cache_pool.acquire(infer_task) + # Check if task already timed out before starting stream + if infer_task.timed_out: + raise HTTPException( + status_code=504, + detail={ + "message": f"Request timeout: task exceeded {REQUEST_TIMEOUT}s timeout", + "type": "timeout_error", + "code": "timeout" + } + ) + + # Get model name from request or use global MODEL_NAME + model_name = request_data.get("model", MODEL_NAME or "jiuge") + # Initial empty content chunk = json.dumps( - chunk_json(id_, content="", role="assistant"), ensure_ascii=False + chunk_json(id_, content="", role="assistant", model=model_name), ensure_ascii=False ) yield f"data: {chunk}\n\n" request.app.state.request_queue.sync_q.put(infer_task) + # For streaming: accumulate tokens and decode incrementally to handle UTF-8 properly + # We maintain a buffer and decode the full buffer each time, only yielding new characters + # This ensures multi-byte UTF-8 sequences (emojis, etc.) are decoded correctly + token_buffer = [] + last_yielded_length = 0 + while True: if await request.is_disconnected(): print("Client disconnected. 
Aborting stream.") @@ -230,32 +512,218 @@ async def chat_stream(id_, request_data, request: Request): infer_task.finish_reason is not None and infer_task.output_queue.async_q.empty() ): - chunk = json.dumps( - chunk_json(id_, finish_reason=infer_task.finish_reason), - ensure_ascii=False, - ) + # Decode any remaining tokens in buffer before finishing + if token_buffer: + try: + decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + + # On final chunk, yield everything including any trailing replacement chars + # (they're no longer incomplete sequences since this is the end) + remaining_content = decoded_text[last_yielded_length:] + if remaining_content: + # Log if replacement chars present (for monitoring) + if '\ufffd' in remaining_content: + print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in final streaming chunk (Request: {id_})") + chunk = json.dumps(chunk_json(id_, content=remaining_content, model=model_name), ensure_ascii=False) + yield f"data: {chunk}\n\n" + except Exception: + pass + + # Check if timed out or internal error - yield error chunk instead of raising HTTPException + # (can't raise HTTPException after streaming has started) + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + # Yield error chunk in SSE format + error_chunk = { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": None + }], + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + } + chunk = json.dumps(error_chunk, ensure_ascii=False) + yield f"data: {chunk}\n\n" + yield "data: [DONE]\n\n" + else: + # Final chunk: empty delta with finish_reason (OpenAI API spec) + chunk = json.dumps( + chunk_json(id_, finish_reason=infer_task.finish_reason, model=model_name), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" + # Send [DONE] marker to indicate stream completion (OpenAI API spec) + yield "data: [DONE]\n\n" + break + + # Check for timeout or internal error before getting next token + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + # Yield error chunk in SSE format (can't raise HTTPException after streaming started) + error_chunk = { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": None + }], + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + } + chunk = json.dumps(error_chunk, ensure_ascii=False) yield f"data: {chunk}\n\n" + yield "data: [DONE]\n\n" break token = await infer_task.output_queue.async_q.get() - content = request.app.state.model.tokenizer.decode(token) - chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) - yield f"data: {chunk}\n\n" + # Skip EOS tokens - don't include them in the stream + # The finish_reason will be set in the final chunk + if token is None: + continue + # Handle end_tokens as list, tuple, or single value + if isinstance(infer_task.end_tokens, (list, tuple)): + if token in infer_task.end_tokens: + continue + elif token == infer_task.end_tokens: + continue + + # Accumulate token in buffer + 
token_buffer.append(token) + + # Decode the entire buffer each time to ensure proper UTF-8 handling + # The tokenizer handles multi-byte sequences correctly when decoding token lists + try: + decoded_text = request.app.state.model.tokenizer.decode(token_buffer, skip_special_tokens=False) + + # vLLM-style UTF-8 buffering: if text ends with replacement char (), + # it's likely an incomplete UTF-8 byte sequence - hold it back until more tokens arrive + # Only check the end, not the middle (middle replacement chars are real invalid tokens) + holds_back_incomplete_utf8 = False + if decoded_text and decoded_text.endswith('\ufffd'): + # Incomplete UTF-8 sequence - hold back this chunk + holds_back_incomplete_utf8 = True + else: + # Calculate new content by comparing current decode with what we've already yielded + if len(decoded_text) > last_yielded_length: + new_content = decoded_text[last_yielded_length:] + + # Only yield if we have new content + if new_content: + chunk = json.dumps(chunk_json(id_, content=new_content, model=model_name), ensure_ascii=False) + yield f"data: {chunk}\n\n" + last_yielded_length = len(decoded_text) + # Prevent buffer from growing too large by periodically flushing + # Keep last 5 tokens for multi-token character sequences + # CRITICAL: Don't trim if we're holding back incomplete UTF-8 (could break sequence) + # When trimming, we need to calculate how many characters from the kept tokens + # have already been yielded. We do this by: + # 1. Calculating what portion of the full decoded text corresponds to removed tokens + # 2. The remaining portion (kept tokens) may have already been partially yielded + # 3. When we decode only the kept tokens, we need to figure out which portion was already sent + if not holds_back_incomplete_utf8 and len(token_buffer) > 20: + # Calculate what portion of decoded_text corresponds to removed tokens + num_removed = len(token_buffer) - 5 + removed_tokens = token_buffer[:num_removed] + kept_tokens = token_buffer[-5:] + + # Decode removed and kept portions separately + try: + removed_text = request.app.state.model.tokenizer.decode(removed_tokens, skip_special_tokens=False) + # The removed portion starts from the beginning, so its length is what we've already fully processed + # Calculate how many characters from kept tokens were already yielded + # We know: decoded_text = removed_text + kept_portion_from_full_decode + # And we've yielded up to last_yielded_length characters + if last_yielded_length > len(removed_text): + # Some portion of the kept tokens was already yielded + # Calculate the offset into the kept portion that was already sent + chars_yielded_from_kept = last_yielded_length - len(removed_text) + else: + # Nothing from kept tokens was yielded yet + chars_yielded_from_kept = 0 + + # Decode the kept tokens to get their full text + decoded_kept = request.app.state.model.tokenizer.decode(kept_tokens, skip_special_tokens=False) + + # Update last_yielded_length to point to the character position in the new buffer + # that corresponds to what we've already sent + last_yielded_length = min(chars_yielded_from_kept, len(decoded_kept)) + token_buffer = kept_tokens + except Exception: + # If decoding fails during trimming, don't trim (safer to keep buffer) + # Log but continue with current buffer + print(f"[Warning] Failed to decode during buffer trim, keeping full buffer") + + except Exception as e: + # If decoding fails, skip this token and log (don't break the stream) + print(f"[Warning] Failed to decode token {token}: {e}") + if 
len(token_buffer) > 0: + token_buffer.pop() # Remove the problematic token + + except HTTPException: + # Re-raise HTTPException to propagate error status code + raise except Exception as e: print(f"[Error] ID : {id_} Exception: {e}") + raise HTTPException( + status_code=500, + detail={ + "message": str(e), + "type": "internal_error", + "code": "internal_error" + } + ) finally: - if infer_task.finish_reason is None: - infer_task.finish_reason = "cancel" + if infer_task: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + # Clean up task from active tracking + with request.app.state.task_lock: + request.app.state.active_tasks.pop(infer_task.id, None) async def chat(id_, request_data, request: Request): + # Acquire concurrency semaphore if configured (waits if max concurrency reached) + semaphore = request.app.state.concurrency_semaphore + if semaphore: + await semaphore.acquire() + try: + return await _chat_impl(id_, request_data, request) + finally: + semaphore.release() + else: + return await _chat_impl(id_, request_data, request) + + +async def _chat_impl(id_, request_data, request: Request): try: infer_task = build_task(id_, request_data, request) + # Track task from creation + with request.app.state.task_lock: + request.app.state.active_tasks[infer_task.id] = infer_task + infer_task.last_activity_time = time.time() + await request.app.state.kv_cache_pool.acquire(infer_task) request.app.state.request_queue.sync_q.put(infer_task) - output = [] + # Collect all tokens first, then decode at once to preserve UTF-8 sequences + tokens = [] while True: if ( infer_task.finish_reason is not None @@ -263,16 +731,134 @@ async def chat(id_, request_data, request: Request): ): break + # Check for timeout or internal error before getting next token + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + return JSONResponse( + content={ + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + }, + status_code=500 # Internal Server Error + ) + token = await infer_task.output_queue.async_q.get() - content = request.app.state.model.tokenizer.decode(token) - output.append(content) - output_text = "".join(output).strip() - response = chunk_json( + # Skip EOS tokens - don't include them in the output + if token is None: + continue + # Handle end_tokens as list, tuple, or single value + if isinstance(infer_task.end_tokens, (list, tuple)): + if token in infer_task.end_tokens: + continue + elif token == infer_task.end_tokens: + continue + + # Collect tokens - decode all at once to preserve multi-byte UTF-8 characters + tokens.append(token) + + # Check if timed out or internal error before returning response + if infer_task.timed_out: + # Both timeout and internal_error result in internal error response + # because the process will be killed and restarted + return JSONResponse( + content={ + "error": { + "message": "Internal server error: process will be restarted", + "type": "internal_error", + "code": "internal_error" + } + }, + status_code=500 # Internal Server Error + ) + + # Decode all tokens at once to preserve multi-byte UTF-8 sequences (emojis, etc.) 
+ # This is critical for proper handling of characters that span multiple tokens + # CRITICAL: Ensure proper UTF-8 handling to avoid replacement characters in markdown references + if tokens: + try: + # Decode tokens to string - tokenizer should handle UTF-8 correctly + decoded_output = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False) + # Handle both bytes and str return types from tokenizer + if isinstance(decoded_output, bytes): + # If tokenizer returns bytes, decode with UTF-8 + output_text = decoded_output.decode('utf-8', errors='strict').strip() + else: + # If tokenizer returns str, verify it's valid UTF-8 by re-encoding/decoding + # This catches any invalid surrogate pairs or other UTF-8 issues + # Use 'surrogatepass' to preserve surrogates, then 'strict' for final decode + try: + # Validate UTF-8 by encoding and decoding - if it fails, there's an issue + validated = decoded_output.encode('utf-8', errors='strict').decode('utf-8', errors='strict') + output_text = validated.strip() + except UnicodeError: + # If validation fails, try with error handling to recover + print(f"[Warning] UTF-8 validation failed, using replacement for invalid sequences") + output_text = decoded_output.encode('utf-8', errors='replace').decode('utf-8', errors='replace').strip() + except Exception as e: + print(f"[Error] Token decoding error: {e}") + # Last resort: try with error replacement + try: + decoded_output = request.app.state.model.tokenizer.decode(tokens, skip_special_tokens=False) + if isinstance(decoded_output, bytes): + output_text = decoded_output.decode('utf-8', errors='replace').strip() + else: + output_text = decoded_output.encode('utf-8', errors='replace').decode('utf-8', errors='replace').strip() + except Exception as e2: + print(f"[Error] Failed to decode tokens even with error handling: {e2}") + output_text = "" + else: + output_text = "" + + # Strip EOS token strings from the end of output (defensive check) + # Decode EOS tokens to get their string representations + eos_token_strings = [] + end_tokens_list = [] + if isinstance(infer_task.end_tokens, (list, tuple)): + end_tokens_list = list(infer_task.end_tokens) + else: + end_tokens_list = [infer_task.end_tokens] + + for eos_token_id in end_tokens_list: + try: + eos_str = request.app.state.model.tokenizer.decode([eos_token_id]) + if eos_str: + eos_token_strings.append(eos_str) + except Exception: + pass + + # Remove EOS token strings from the end of output + for eos_str in eos_token_strings: + if output_text.endswith(eos_str): + output_text = output_text[:-len(eos_str)].rstrip() + + # vLLM-style: Log replacement characters for monitoring + # Replacement chars at the end would have been incomplete UTF-8, but for non-streaming + # we decode the full sequence, so any replacement chars are real invalid tokens + if '\ufffd' in output_text: + # Log for monitoring/debugging + print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in non-streaming response (Request: {id_})") + + # Calculate token usage + prompt_tokens = infer_task.initial_prompt_tokens + completion_tokens = len(infer_task.generated_tokens) + + # Get model name from request data or use global MODEL_NAME + model_name = request_data.get("model", MODEL_NAME or "jiuge") + + # Use correct OpenAI-compatible format for non-streaming response + response = chat_completion_json( id_, content=output_text, role="assistant", finish_reason=infer_task.finish_reason or "stop", + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + model=model_name 
) return response @@ -282,6 +868,9 @@ async def chat(id_, request_data, request: Request): finally: if infer_task.finish_reason is None: infer_task.finish_reason = "cancel" + # Clean up task from active tracking + with request.app.state.task_lock: + request.app.state.active_tasks.pop(infer_task.id, None) @App.post("/chat/completions") @@ -300,11 +889,15 @@ async def chat_completions(request: Request): stream = data.get("stream", False) id_ = f"cmpl-{uuid.uuid4().hex}" if stream: + # FastAPI's exception handler will catch HTTPException raised from the generator return StreamingResponse( chat_stream(id_, data, request), media_type="text/event-stream" ) else: response = await chat(id_, data, request) + # If response is already a JSONResponse (error case), return it directly + if isinstance(response, JSONResponse): + return response return JSONResponse(content=response) @@ -318,16 +911,16 @@ async def completion(id_, request_data, request: Request): max_tokens = request_data.get("max_tokens", 0) if max_tokens > 0: return JSONResponse( - content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."}, + content={"error": "max_tokens > 0 is not supported yet. Please use max_tokens=0 for logprobs calculation."}, status_code=400 ) - + infer_task = build_task(id_, request_data, request) await request.app.state.kv_cache_pool.acquire(infer_task) - + output = [] logprobs = [] - + # Handle echo and logprobs calculation echo = request_data.get("echo", False) if echo: @@ -340,12 +933,12 @@ async def completion(id_, request_data, request: Request): .replace("<0x0A>", "\n") ) output.append(content) - + # Calculate logprobs for input tokens from jiuge import JiugeBatchedTask batch_inputs = JiugeBatchedTask([infer_task]) log_probs = torch.zeros( - (batch_inputs.ntok, request.app.state.model.meta.dvoc), + (batch_inputs.ntok, request.app.state.model.meta.dvoc), dtype=request.app.state.model.meta.torch_dtype_logits ) request.app.state.model.jiuge_model.forward_batch( @@ -358,39 +951,43 @@ async def completion(id_, request_data, request: Request): batch_inputs.kv_caches, log_probs.data_ptr(), ) - + log_probs = log_probs.float() - + # Calculate correct logprobs for input tokens token_logprobs = [] for i in range(len(infer_task.tokens) - 1): # Only up to second-to-last token next_token = infer_task.tokens[i+1] # Next token to predict logprob = log_probs[i, next_token].item() # Use position i logits to predict position i+1 token token_logprobs.append(logprob) - + # First token has no context, so logprob is None logprobs = [None] + token_logprobs else: # echo=false: don't calculate logprobs since user can't see input text logprobs = [] - + # For max_tokens=0, we need to manually release the KV cache since we don't go through worker await request.app.state.kv_cache_pool.release(infer_task) print(f"[DEBUG] {id_} Released KV cache for max_tokens=0") output_text = "".join(output).strip() - + + # vLLM-style: Log replacement characters for monitoring + if '\ufffd' in output_text: + print(f"[REPLACEMENT_CHAR_DETECTED] Found replacement char(s) in completion response (Request: {id_})") + # Prepare tokens list for logprobs tokens_list = [] text_offset_list = [] current_offset = 0 - + # Build tokens list and text offsets for i, content in enumerate(output): tokens_list.append(content) text_offset_list.append(current_offset) current_offset += len(content) - + # Build response according to DeepSeek API completion format response = { "id": id_, @@ -440,15 +1037,75 @@ async def 
completions(request: Request): id_ = f"cmpl-{uuid.uuid4().hex}" response = await completion(id_, data, request) - + # Check if response is already a JSONResponse (error case) if isinstance(response, JSONResponse): return response else: return JSONResponse(content=response) + +@App.get("/models") +async def list_models(request: Request): + """ + OpenAI-compatible /models endpoint. + Returns a list of available models (single model specified by --model-name argument). + """ + try: + # Check if model is loaded (server is ready) + if not hasattr(request.app.state, 'model') or request.app.state.model is None: + # Server not ready yet - return 503 Service Unavailable + return JSONResponse( + content={"error": "Service not ready yet, model still loading"}, + status_code=503 + ) + + # Use model name from argument/directory name, otherwise try to detect from config + model_id = MODEL_NAME + + if not model_id: + # Get model information from app state + model = request.app.state.model + model_id = "unknown" # Default model ID + + # Try to get model name from config if available + if hasattr(model, 'config') and model.config: + # Try model_type first + model_id = model.config.get("model_type", "unknown") + # If model_type is not informative, try architectures + if model_id == "unknown" and "architectures" in model.config: + architectures = model.config.get("architectures", []) + if architectures: + model_id = architectures[0].lower() + + return JSONResponse(content={ + "object": "list", + "data": [ + { + "id": model_id, + "object": "model", + "created": int(time.time()), + "owned_by": "infini", + "permission": [], + "root": model_id, + "parent": None + } + ] + }) + except AttributeError as e: + # Model not loaded yet + print(f"[Error] Model not loaded in /models: {e}") + return JSONResponse( + content={"error": "Service not ready yet, model still loading"}, + status_code=503 + ) + except Exception as e: + print(f"[Error] Exception in /models: {e}") + return JSONResponse(content={"error": str(e)}, status_code=500) + + if __name__ == "__main__": - uvicorn.run(App, host="0.0.0.0", port=8000) + uvicorn.run(App, host=SERVER_HOST, port=SERVER_PORT) """ curl -N -H "Content-Type: application/json" \ @@ -456,12 +1113,74 @@ async def completions(request: Request): -d '{ "model": "jiuge", "messages": [ - {"role": "user", "content": "山东最高的山是?"} + {"role": "user", "content": "介绍你自己"} ], - "temperature": 1.0, - "top_k": 50, - "top_p": 0.8, + "temperature": 0.7, + "top_p": 0.7, + "repetition_penalty": 1.02, + "stream": false, + "chat_template_kwargs": {"enable_thinking": false} + }' + + +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "system", "content": "你是一个由启元实验室开发的九格助手,你擅长中英文对话,能够理解并处理各种问题,提供安全、有帮助、准确的回答。当前时间:2025-12-24#注意:回复之前注意结合上下文和工具返回内容进行回复"}, + {"role": "user", "content": "怎么看待台海局势"} + ], + "temperature": 0.7, + "top_p": 0.7, "max_tokens": 512, - "stream": true + "repetition_penalty": 1.1, + "stream": false, + "chat_template_kwargs": {"enable_thinking": false} }' + +# Test timeout checker: simulate a hang that exceeds the timeout +# This will cause the process to be killed by the timeout checker +# (assuming --request-timeout is set to a value less than test_hang_seconds) +# Example: if --request-timeout=300, use test_hang_seconds=350 to trigger timeout +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": 
"user", "content": "Hello"} + ], + "temperature": 0.7, + "test_hang_seconds": 350, + "stream": false + }' + +# Test UTF-8 decoding fix: Markdown reference with special characters +# This validates that characters don't appear in markdown references +curl -s -X POST http://127.0.0.1:8000/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen3-32B", + "messages": [ + {"role": "user", "content": "请写一段包含markdown链接的文本,例如 [链接文本](https://example.com) 和引用 [^1]"} + ], + "temperature": 0.7, + "max_tokens": 2000, + "stream": false + }' | python3 -c "import sys, json; data = json.load(sys.stdin); content = data['choices'][0]['message']['content'] if 'choices' in data else ''; print('PASS' if '' not in content and '' not in content else 'FAIL: Found replacement characters'); print(content[:200])" + +# Test UTF-8 decoding: Chinese with emojis +curl -s -X POST http://127.0.0.1:8000/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "请用中文回答:什么是人工智能?使用一些表情符号。"} + ], + "temperature": 0.7, + "max_tokens": 150, + "stream": false + }' | python3 -c "import sys, json; data = json.load(sys.stdin); content = data['choices'][0]['message']['content'] if 'choices' in data else ''; print('PASS' if '' not in content and '' not in content else 'FAIL: Found replacement characters'); print(content[:200])" + + """ diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index 89553041..38cf536d 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -14,9 +14,14 @@ class JiugeMetaCStruct(Structure): ("dctx", c_size_t), ("dvoc", c_size_t), ("kvcache_block_size", c_size_t), + ("dim_model_base", c_size_t), ("epsilon", c_float), ("theta", c_float), ("end_token", c_uint), + ("rope_type", c_uint), + ("original_max_position_embeddings", c_size_t), + ("short_factor", POINTER(c_float)), + ("long_factor", POINTER(c_float)), ] @@ -86,6 +91,9 @@ def register_lib(cls, lib): POINTER(c_float), POINTER(c_uint), POINTER(c_float), + POINTER(c_float), + POINTER(POINTER(c_uint)), # previous_tokens_per_req: array of pointers + POINTER(c_uint), # previous_tokens_len_per_req: array of lengths POINTER(c_uint), ] @@ -128,6 +136,9 @@ def infer_batch( temperature, topk, topp, + repetition_penalty, + previous_tokens_per_req, + previous_tokens_len_per_req, output, ): self.lib.inferBatchJiuge( @@ -141,6 +152,9 @@ def infer_batch( temperature, topk, topp, + repetition_penalty, + previous_tokens_per_req, + previous_tokens_len_per_req, output, ) diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index 1cbc267a..1e5bd9bc 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -272,8 +272,11 @@ void InferenceContext::swiglu(std::shared_ptr out, } void InferenceContext::randomSample(std::shared_ptr out, - std::shared_ptr prob, - float random_val, float top_p, uint32_t top_k, float temperature) { + std::shared_ptr prob, + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty, + const uint32_t *previous_tokens, + size_t previous_tokens_len) { size_t key = CacheManager::createDescriptorKey(out, prob); infiniopRandomSampleDescriptor_t desc; @@ -288,10 +291,12 @@ void InferenceContext::randomSample(std::shared_ptr out, ensure_workspace(workspace_size); void *workspace = workspace_storage->memory(); + RUN_INFINI(infiniopRandomSample( desc, workspace, workspace_size, out->data(), 
prob->data(), - random_val, top_p, top_k, temperature, + random_val, top_p, top_k, temperature, repetition_penalty, + previous_tokens, previous_tokens_len, stream)); } @@ -374,7 +379,7 @@ void InferenceContext::pagedCaching(std::shared_ptr<Tensor> k, infiniopPagedCachingDescriptor_t desc; if (!cache_manager->getPagedCachingDescriptor(key, desc)) { RUN_INFINI(infiniopCreatePagedCachingDescriptor( - op_handle, &desc, k->desc(), v->desc(), + op_handle, &desc, k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc())); cache_manager->putPagedCachingDescriptor(key, desc); } @@ -416,7 +421,7 @@ void InferenceContext::pagedAttention(std::shared_ptr<Tensor> out, RUN_INFINI(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); ensure_workspace(workspace_size); void *workspace = workspace_storage->memory(); - + const void* alibi_data = alibi_slopes ? alibi_slopes->data() : nullptr; RUN_INFINI(infiniopPagedAttention( desc, workspace, workspace_size, @@ -424,10 +429,3 @@ void InferenceContext::pagedAttention(std::shared_ptr<Tensor> out, block_tables->data(), seq_lens->data(), alibi_data, stream)); } - - - - - - - diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 96d8601c..12709095 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -61,7 +61,10 @@ struct InferenceContext { std::shared_ptr<Tensor> gate); void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob, - float random_val, float top_p, uint32_t top_k, float temperature); + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty = 1.0f, + const uint32_t *previous_tokens = nullptr, + size_t previous_tokens_len = 0); void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, @@ -79,7 +82,7 @@ struct InferenceContext { std::shared_ptr<Tensor> k_cache, std::shared_ptr<Tensor> v_cache, std::shared_ptr<Tensor> slot_mapping); - + void pagedAttention(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> q, std::shared_ptr<Tensor> k_cache, @@ -180,8 +183,12 @@ inline void swiglu(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> up, } inline void randomSample(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> prob, - float random_val, float top_p, uint32_t top_k, float temperature) { - getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature); + float random_val, float top_p, uint32_t top_k, float temperature, + float repetition_penalty = 1.0f, + const uint32_t *previous_tokens = nullptr, + size_t previous_tokens_len = 0) { + getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature, repetition_penalty, + previous_tokens, previous_tokens_len); } inline void linear(std::shared_ptr<Tensor> c, std::shared_ptr<Tensor> a, @@ -211,5 +218,3 @@ inline void pagedAttention(std::shared_ptr<Tensor> out, std::shared_ptr<Tensor> std::shared_ptr<Tensor> alibi_slopes, float scale) { getInferenceContext().pagedAttention(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); } - - diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 8432bfee..e4764030 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -126,6 +126,9 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev;
@@ -276,14 +279,19 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, } // Sample and Output if (idev == 0) { + // Calculate output scale: if dim_model_base != d, scale by dim_model_base/d, else no scaling + float output_scale = 1.0f; + if (meta.dim_model_base > 0 && meta.dim_model_base != meta.d && meta.d > 0) { + output_scale = meta.dim_model_base / float(meta.d); + } if (last_logits != nullptr) { rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); - linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); - + linear(last_logits_buf, logits_out, rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); + auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, last_logits_buf); - + RUN_INFINI(infinirtStreamSynchronize(stream)); RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); } @@ -297,16 +305,26 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rsrc.w_out_norm, meta.epsilon); } - linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); std::random_device _rd; std::mt19937 gen(_rd()); token_offset = 0; for (uint32_t req = 0; req < nreq; req++) { auto seq_len = req_lens[req]; float random_val = std::uniform_real_distribution<float>(0, 1)(gen); + float rep_penalty_val = (repetition_penalty != nullptr) ? repetition_penalty[req] : 1.0f; + // Get unique tokens for this request (vLLM-style) + const uint32_t *prev_tokens = nullptr; + size_t prev_tokens_len = 0; + if (previous_tokens_per_req != nullptr && previous_tokens_per_req[req] != nullptr) { + prev_tokens = previous_tokens_per_req[req]; + prev_tokens_len = (previous_tokens_len_per_req != nullptr) ?
previous_tokens_len_per_req[req] : 0; + } randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), - random_val, topp[req], topk[req], temperature[req]); + random_val, topp[req], topk[req], temperature[req], + rep_penalty_val, + prev_tokens, prev_tokens_len); token_offset += seq_len; } RUN_INFINI(infinirtStreamSynchronize(stream)); @@ -327,6 +345,9 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, const int32_t *block_tables, const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; @@ -425,7 +446,7 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); - + // MLP buffers auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); @@ -451,7 +472,7 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto k = qkv_rope->slice({ {0, 0, ntok}, {1, nh, nkvh} }); auto v = qkv_rope->slice({ {0, 0, ntok}, {1, nh + nkvh, nkvh} }); - auto k_cache_pool = kv_caches[0]->k[idev][layer]; + auto k_cache_pool = kv_caches[0]->k[idev][layer]; auto v_cache_pool = kv_caches[0]->v[idev][layer]; pagedCaching(k, v, k_cache_pool, v_cache_pool, slot_mapping_buf); @@ -481,10 +502,10 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto o = o_buf->slice({{0, 0, ntok}})->view({ntok, nh, dh}); auto q_batch = qkv_rope->slice({ {0, 0, ntok}, {1, 0, nh} })->view({ntok, nh, dh}); float scale = 1.f / float(sqrt(dh)); - pagedAttention(o, q_batch, k_cache_pool, v_cache_pool, + pagedAttention(o, q_batch, k_cache_pool, v_cache_pool, block_tables_buf, seq_lens_buf, nullptr /* alibi_slopes */, scale); - - + + } } else { @@ -546,14 +567,19 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, // Sample and Output if (idev == 0) { + // Calculate output scale: if dim_model_base != d, scale by dim_model_base/d, else no scaling + float output_scale = 1.0f; + if (meta.dim_model_base > 0 && meta.dim_model_base != meta.d && meta.d > 0) { + output_scale = meta.dim_model_base / float(meta.d); + } if (last_logits != nullptr) { rmsnorm(logits_out, logits_in, rsrc.w_out_norm, meta.epsilon); auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); - linear(last_logits_buf, logits_out, rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); - + linear(last_logits_buf, logits_out, rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); + auto log_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool); logSoftmax(log_logits_buf, last_logits_buf); - + RUN_INFINI(infinirtStreamSynchronize(stream)); RUN_INFINI(infinirtMemcpy(last_logits, log_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H)); } @@ -567,16 +593,26 @@ void inferDeviceBatchPaged(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rsrc.w_out_norm, meta.epsilon); } - linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, 1.0, 0.0, nullptr, nullptr); + linear(prob_buf, logits_out->slice(0, 0, nreq), rsrc.w_out_embd, output_scale, 0.0, nullptr, nullptr); 
std::random_device _rd; std::mt19937 gen(_rd()); token_offset = 0; for (uint32_t req = 0; req < nreq; req++) { auto seq_len = req_lens[req]; float random_val = std::uniform_real_distribution<float>(0, 1)(gen); + float rep_penalty_val = (repetition_penalty != nullptr) ? repetition_penalty[req] : 1.0f; + // Get unique tokens for this request (vLLM-style) + const uint32_t *prev_tokens = nullptr; + size_t prev_tokens_len = 0; + if (previous_tokens_per_req != nullptr && previous_tokens_per_req[req] != nullptr) { + prev_tokens = previous_tokens_per_req[req]; + prev_tokens_len = (previous_tokens_len_per_req != nullptr) ? previous_tokens_len_per_req[req] : 0; + } randomSample(result_buf->slice(0, req, 1)->view_as({}, {}), prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}), - random_val, topp[req], topk[req], temperature[req]); + random_val, topp[req], topk[req], temperature[req], + rep_penalty_val, + prev_tokens, prev_tokens_len); token_offset += seq_len; } RUN_INFINI(infinirtStreamSynchronize(stream)); @@ -595,6 +631,9 @@ inferBatchJiuge(struct JiugeModel *model, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, uint32_t *output) { model->req.tokens = tokens; model->req.ntok = ntok; @@ -607,6 +646,9 @@ inferBatchJiuge(struct JiugeModel *model, model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->req.repetition_penalty = repetition_penalty; + model->req.previous_tokens_per_req = previous_tokens_per_req; + model->req.previous_tokens_len_per_req = previous_tokens_len_per_req; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -639,6 +681,7 @@ forwardBatchJiuge(struct JiugeModel *model, model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->req.repetition_penalty = nullptr; for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { std::unique_lock lock(model->states[idev].mtx); @@ -658,10 +701,13 @@ __C void inferBatch(struct JiugeModel *model, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, + struct KVCache **kv_caches, const int32_t *block_tables, const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const float *repetition_penalty, + const uint32_t *const *previous_tokens_per_req, + const uint32_t *previous_tokens_len_per_req, const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output) { model->req.tokens = tokens; @@ -677,6 +723,9 @@ inferBatch(struct JiugeModel *model, model->req.temperature = temperature; model->req.topk = topk; model->req.topp = topp; + model->req.repetition_penalty = repetition_penalty; + model->req.previous_tokens_per_req = previous_tokens_per_req; + model->req.previous_tokens_len_per_req = previous_tokens_len_per_req; model->req.is_prefill = is_prefill; model->req.enable_paged_attn = enable_paged_attn; @@ -716,6 +765,7 @@ forwardBatch(struct JiugeModel *model, model->req.temperature = nullptr; model->req.topk = nullptr; model->req.topp = nullptr; + model->req.repetition_penalty = nullptr; model->req.is_prefill = is_prefill; model->req.enable_paged_attn = enable_paged_attn; @@ -764,15 +814,18 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights
*weights, JiugeDevic if (enable_paged){ inferDeviceBatchPaged(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.block_tables, req.slot_mapping, - req.temperature, req.topk, req.topp, + req.block_tables, req.slot_mapping, + req.temperature, req.topk, req.topp, req.repetition_penalty, + req.previous_tokens_per_req, req.previous_tokens_len_per_req, req.is_prefill, req.enable_paged_attn, req.output, req.logits); } else{ inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.temperature, req.topk, req.topp, req.output, req.logits); + req.temperature, req.topk, req.topp, req.repetition_penalty, + req.previous_tokens_per_req, req.previous_tokens_len_per_req, + req.output, req.logits); } @@ -800,7 +853,7 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi } for (int i = 0; i < ndev; i++) { - + threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); } for (int i = 0; i < ndev; i++) { diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp index 5b1fd2f8..022e34da 100644 --- a/src/models/jiuge/jiuge_impl.hpp +++ b/src/models/jiuge/jiuge_impl.hpp @@ -50,6 +50,9 @@ struct InferRequest { const float *temperature; const uint32_t *topk; const float *topp; + const float *repetition_penalty; + const uint32_t *const *previous_tokens_per_req; // Array of pointers to unique tokens per request (vLLM-style) + const uint32_t *previous_tokens_len_per_req; // Array of lengths per request uint32_t *output; uint32_t is_prefill; bool enable_paged_attn; diff --git a/src/models/jiuge/jiuge_weight.hpp b/src/models/jiuge/jiuge_weight.hpp index 7ee10155..b36a513b 100644 --- a/src/models/jiuge/jiuge_weight.hpp +++ b/src/models/jiuge/jiuge_weight.hpp @@ -152,10 +152,30 @@ inline std::shared_ptr<Tensor> getSinTable(JiugeMeta const *meta) { auto unit = dsize(meta->dt_logits); void *table = std::malloc(meta->dctx * half_dh * unit); + bool is_longrope = (meta->rope_type == 1) && (meta->short_factor != nullptr) && (meta->long_factor != nullptr); + const float *factors = nullptr; + if (is_longrope) { + // Use long_factor if max context exceeds original_max_position_embeddings, otherwise short_factor + // Since we generate a single table at model creation time, choose based on dctx + if (meta->original_max_position_embeddings > 0 && meta->dctx > meta->original_max_position_embeddings) { + factors = meta->long_factor; + } else { + factors = meta->short_factor; + } + } + for (size_t i = 0; i < meta->dctx; i++) { for (size_t j = 0; j < half_dh; j++) { - float _sin = std::sin( - static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh)); + float angle; + if (is_longrope && factors != nullptr) { + // Longrope: apply per-frequency scaling factor + float inv_freq = 1.0f / (factors[j] * std::pow(meta->theta, static_cast<float>(j) / half_dh)); + angle = static_cast<float>(i) * inv_freq; + } else { + // Standard RoPE + angle = static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh); + } + float _sin = std::sin(angle); if (meta->dt_logits == INFINI_DTYPE_F16) { ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin); } else if (meta->dt_logits == INFINI_DTYPE_BF16) { @@ -179,10 +199,30 @@ inline std::shared_ptr<Tensor> getCosTable(JiugeMeta const *meta) { auto unit = dsize(meta->dt_logits); void *table = std::malloc(meta->dctx * half_dh * unit); + bool is_longrope =
(meta->rope_type == 1) && (meta->short_factor != nullptr) && (meta->long_factor != nullptr); + const float *factors = nullptr; + if (is_longrope) { + // Use long_factor if max context exceeds original_max_position_embeddings, otherwise short_factor + // Since we generate a single table at model creation time, choose based on dctx + if (meta->original_max_position_embeddings > 0 && meta->dctx > meta->original_max_position_embeddings) { + factors = meta->long_factor; + } else { + factors = meta->short_factor; + } + } + for (size_t i = 0; i < meta->dctx; i++) { for (size_t j = 0; j < half_dh; j++) { - float _cos = std::cos( - static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh)); + float angle; + if (is_longrope && factors != nullptr) { + // Longrope: apply per-frequency scaling factor + float inv_freq = 1.0f / (factors[j] * std::pow(meta->theta, static_cast<float>(j) / half_dh)); + angle = static_cast<float>(i) * inv_freq; + } else { + // Standard RoPE + angle = static_cast<float>(i) / std::pow(meta->theta, static_cast<float>(j) / half_dh); + } + float _cos = std::cos(angle); if (meta->dt_logits == INFINI_DTYPE_F16) { ((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos); } else if (meta->dt_logits == INFINI_DTYPE_BF16) {
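Note (illustration only, not part of the patch): the getSinTable/getCosTable changes above pick short_factor or long_factor once at table-build time, based on whether dctx exceeds original_max_position_embeddings, then scale each frequency by 1/factors[j]. The Python sketch below re-derives the same tables with plain math so the longrope arithmetic can be checked outside the C++ build; the function name rope_tables and its parameters are hypothetical, and theta, dctx, dh and the factor lists would come from JiugeMeta / rope_scaling in the model config.

import math

def rope_tables(dctx, dh, theta, short_factor=None, long_factor=None,
                original_max_position_embeddings=0):
    """Reference sin/cos RoPE tables mirroring the table-generation logic above."""
    half_dh = dh // 2
    factors = None
    if short_factor and long_factor:
        # Single table chosen at model-creation time: long_factor once dctx exceeds
        # the original training context, otherwise short_factor.
        use_long = (original_max_position_embeddings > 0
                    and dctx > original_max_position_embeddings)
        factors = long_factor if use_long else short_factor

    sin_table = [[0.0] * half_dh for _ in range(dctx)]
    cos_table = [[0.0] * half_dh for _ in range(dctx)]
    for i in range(dctx):
        for j in range(half_dh):
            base = theta ** (j / half_dh)
            if factors is not None:
                angle = i / (factors[j] * base)   # longrope: per-frequency scaling
            else:
                angle = i / base                  # standard RoPE
            sin_table[i][j] = math.sin(angle)
            cos_table[i][j] = math.cos(angle)
    return sin_table, cos_table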
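Note (illustration only, not part of the patch): the new inferBatchJiuge arguments expect one penalty float per request plus a pointer-of-pointers of de-duplicated token IDs and a per-request length array, matching the ctypes argtypes registered above. A minimal Python sketch of how a caller could pack these, assuming that layout; the helper name build_repetition_penalty_args and its inputs are hypothetical.

import ctypes
from ctypes import POINTER, c_float, c_uint

def build_repetition_penalty_args(generated_ids_per_req, penalty=1.1):
    """Deduplicate each request's previously generated token IDs (vLLM-style) and
    pack them into per-request ctypes arrays plus a pointer-of-pointers table."""
    nreq = len(generated_ids_per_req)
    penalties = (c_float * nreq)(*([penalty] * nreq))

    token_arrays = []                       # keep references so buffers are not GC'd
    token_ptrs = (POINTER(c_uint) * nreq)()
    token_lens = (c_uint * nreq)()
    for i, ids in enumerate(generated_ids_per_req):
        unique_ids = sorted(set(ids))
        arr = (c_uint * len(unique_ids))(*unique_ids)
        token_arrays.append(arr)
        token_ptrs[i] = ctypes.cast(arr, POINTER(c_uint))
        token_lens[i] = len(unique_ids)
    return penalties, token_ptrs, token_lens, token_arrays

# Example: two requests, one with repeated token IDs
penalties, ptrs, lens, _keep = build_repetition_penalty_args([[5, 5, 17, 42], [3]], penalty=1.02)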