Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy
import functools
import torch
from typing import List, Tuple

Expand Down Expand Up @@ -34,6 +36,22 @@ def init_custom(self):
eos_token_ids = []
eos_token_ids.append(self.tokenizer.eos_token_id)
eos_token_ids.extend(self.args.eos_id)

@functools.lru_cache(maxsize=200)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The lru_cache decorator is used without typed=True. With the default typed=False, arguments that compare equal and hash equally but have different types (e.g., 1 vs 1.0) are normally treated as the same call and share one cache entry. Both parameters of dispatch_grammar are annotated as str, so this only matters if callers can pass values of other types; if they can, consider adding typed=True to the decorator so that cache entries are kept separate per argument type.

Suggested change
@functools.lru_cache(maxsize=200)
@functools.lru_cache(maxsize=200, typed=True)
def dispatch_grammar(type: str, grammar: str):

def get_cached_grammar(type: str, grammar: str):
    """Compile ``grammar`` with the enclosing object's xgrammar compiler.

    ``type`` selects the compiler entry point: ``"grammar"`` calls
    ``compile_grammar`` and ``"schema"`` calls ``compile_json_schema``.
    Any other value raises ``ValueError``.

    Every exception raised while compiling (including the ``ValueError``
    for an unknown ``type``) is logged at error level and re-raised
    unchanged.
    """
    # NOTE(review): the message says "cache miss" -- presumably this body
    # only runs when the lru_cache wrapper has no entry; confirm the
    # decorator is applied directly above this def.
    logger.info(f"grammar cache miss for {type}: '{grammar}'")
    try:
        # Reject unknown kinds first so only the two supported paths remain.
        if type not in ("grammar", "schema"):
            raise ValueError(f"Unknown xgrammar type: {type}")
        compiler = self.xgrammar_compiler
        if type == "grammar":
            compiled = compiler.compile_grammar(grammar)
        else:
            compiled = compiler.compile_json_schema(grammar)
    except Exception as err:
        logger.error(f"Failed to compile {type}: {err}")
        raise
    return compiled

self.get_cached_grammar = get_cached_grammar
return

@calculate_time(show=False, min_cost_ms=300)
Expand Down Expand Up @@ -149,10 +167,10 @@ def _init_req_xgrammer_matcher_infos(self, run_reqs: List[InferReq]):
sample_params = run_obj.sampling_param
if sample_params.guided_grammar is not None:
if not hasattr(sample_params, "xgrammar_matcher"):
xgrammar_compiled_grammar = self.xgrammar_compiler.compile_grammar(sample_params.guided_grammar)
sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
ctx = self.get_cached_grammar("grammar", sample_params.guided_grammar)
sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx)
elif sample_params.guided_json is not None:
if not hasattr(sample_params, "xgrammar_matcher"):
xgrammar_compiled_grammar = self.xgrammar_compiler.compile_json_schema(sample_params.guided_json)
sample_params.xgrammar_matcher = xgr.GrammarMatcher(xgrammar_compiled_grammar)
ctx = self.get_cached_grammar("schema", sample_params.guided_json)
sample_params.xgrammar_matcher = xgr.GrammarMatcher(ctx)
return