diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 8cdd840e6..72172431f 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -1,3 +1,5 @@ +import copy +import functools import torch from typing import List, Tuple @@ -34,6 +36,22 @@ def init_custom(self): eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) eos_token_ids.extend(self.args.eos_id) + + @functools.lru_cache(maxsize=200) + def get_cached_grammar(type: str, grammar: str): + logger.info(f"grammar cache miss for {type}: '{grammar}'") + try: + if type == "grammar": + return self.xgrammar_compiler.compile_grammar(grammar) + elif type == "schema": + return self.xgrammar_compiler.compile_json_schema(grammar) + else: + raise ValueError(f"Unknown xgrammar type: {type}") + except Exception as e: + logger.error(f"Failed to compile {type}: {e}") + raise + + self.get_cached_grammar = get_cached_grammar return @calculate_time(show=False, min_cost_ms=300) @@ -149,10 +167,10 @@ def _init_req_xgrammer_matcher_infos(self, run_reqs: List[InferReq]): sample_params = run_obj.sampling_param if sample_params.guided_grammar is not None: if not hasattr(sample_params, "xgrammar_matcher"): - xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar) - sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) + ctx = self.get_cached_grammar("grammar", sample_params.guided_grammar) + sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx) elif sample_params.guided_json is not None: if not hasattr(sample_params, "xgrammar_matcher"): - xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json) - sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar) + ctx = self.get_cached_grammar("schema", sample_params.guided_json) + sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx) return