
Commit 6eceaa5

feat: support jit.
Signed-off-by: Yuchuan <[email protected]>
1 parent 6a01260

13 files changed, +375 -78 lines

cpp/grammar_compiler.cc

Lines changed: 14 additions & 70 deletions
@@ -19,6 +19,7 @@
 #include "fsm.h"
 #include "grammar_functor.h"
 #include "grammar_impl.h"
+#include "grammar_matcher_for_cache.h"
 #include "support/logging.h"
 #include "support/thread_pool.h"
 #include "support/thread_safe_cache.h"
@@ -44,73 +45,6 @@ namespace xgrammar {
 
 /************** Use GrammarMatcher to generate the AdaptiveTokenMaskCache **************/
 
-/*! \brief The concrete implementation of GrammarMatcherNode. */
-class GrammarMatcherForTokenMaskCache : public EarleyParser {
- public:
-  GrammarMatcherForTokenMaskCache(
-      const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
-  )
-      : EarleyParser(grammar, init_state),
-        init_rule_id(init_state.rule_id),
-        initial_state(init_state) {}
-  /*!
-   * \brief Get the adaptive token mask for the given ParserState.
-   * \param is_root_rule Whether to consider the parent rule. If false, there will be
-   * no uncertain tokens. Useful for the root rule.
-   */
-  AdaptiveTokenMask GetAdaptiveTokenMask(
-      size_t vocab_size,
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
-      const std::vector<int32_t>& subtree_nodes_range,
-      bool is_root_rule
-  );
-
-  /*!
-   * \brief Get the token mask for the given ParserState.
-   * \param sorted_decoded_vocab The sorted decoded vocabulary.
-   * \param first_char_mask The first character mask.
-   * \param is_root_rule Whether to consider the parent rule. If false, there will be
-   * no uncertain tokens. Useful for the root rule.
-   * \returns True if the rejected indices are filled as usual, False otherwise.
-   * It's used to determine which construction function will be used.
-   */
-  bool GetTokenMaskWithFirstCharacterCheck(
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
-      const std::bitset<256>& first_char_mask,
-      const std::vector<int>& subtree_nodes_range,
-      bool is_root_rule
-  );
-
- private:
-  /*! \brief Check if a token can pass the lookahead assertion. */
-  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
-      const std::string& token, const std::vector<bool>& can_reach_end_stack
-  );
-
-  /*!
-   * \brief Check if speculative calculation will be applied.
-   * \return first: whether speculative calculation is applicable.
-   * \return second: part of the first character mask,
-   * which can be used in speculative calculation.
-   */
-  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
-  );
-
-  // The id of the initial rule.
-  int32_t init_rule_id;
-
-  // The initial state of the parser.
-  ParserState initial_state;
-
-  // Temporary data for GetAdaptiveTokenMask.
-  std::vector<int32_t> tmp_accepted_indices_;
-  std::vector<int32_t> tmp_rejected_indices_;
-  std::vector<int32_t> tmp_uncertain_indices_;
-  std::vector<bool> tmp_can_reach_end_stack_;
-  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
-};
-
 std::pair<bool, bool> GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
     const std::string& token, const std::vector<bool>& can_reach_end_stack
 ) {
@@ -565,11 +499,13 @@ class GrammarCompiler::Impl {
       const TokenizerInfo& tokenizer_info,
       int max_threads,
      bool cache_enabled,
-      long long max_memory_bytes
+      long long max_memory_bytes,
+      bool is_jit = false
   )
       : tokenizer_info_(tokenizer_info),
        max_threads_(max_threads),
        cache_enabled_(cache_enabled),
+        is_jit_(is_jit),
        compile_builtin_json_grammar_cache_([&] { return CompileJson(); }),
        compile_cache_(static_cast<std::size_t>(max_memory_bytes), *this) {}
 
@@ -642,6 +578,8 @@ class GrammarCompiler::Impl {
   const int max_threads_;
   /*! \brief Whether the cache is enabled. */
   const bool cache_enabled_;
+  /*! \brief Whether the jit mode is enabled.*/
+  const bool is_jit_;
 
   ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
   ThreadSafeLRUCache<MultipleKey, CompiledGrammar, Computer, SizeEstimator> compile_cache_;
@@ -658,6 +596,9 @@ CompiledGrammar GrammarCompiler::Impl::MultiThreadCompileGrammar(Grammar grammar
   if (tokenizer_info_.GetVocabSize() == 0) {
     return CompiledGrammar(compiled_grammar_impl);
   }
+  if (is_jit_) {
+    return CompiledGrammar(compiled_grammar_impl);
+  }
   // Step 3. Compute the adaptive token mask cache
   // The token mask cache is computed for these positions in the grammar:
   // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -862,9 +803,12 @@ GrammarCompiler::GrammarCompiler(
     const TokenizerInfo& tokenizer_info,
     int max_threads,
     bool cache_enabled,
-    long long max_memory_bytes
+    long long max_memory_bytes,
+    bool is_jit
 )
-    : pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads, cache_enabled, max_memory_bytes)) {
+    : pimpl_(std::make_shared<Impl>(
+          tokenizer_info, max_threads, cache_enabled, max_memory_bytes, is_jit
+      )) {
   if (max_memory_bytes < -1) {
     XGRAMMAR_LOG(FATAL) << "Invalid max_memory_bytes: " << max_memory_bytes << ". "
                         << "It should be -1 (unlimited) or a non-negative integer.";

cpp/grammar_matcher.cc

Lines changed: 25 additions & 1 deletion
@@ -15,6 +15,7 @@
 #include "compiled_grammar_impl.h"
 #include "earley_parser.h"
 #include "grammar_impl.h"
+#include "grammar_matcher_for_cache.h"
 #include "support/dynamic_bitset.h"
 #include "support/encoding.h"
 #include "support/int_set.h"
@@ -506,7 +507,7 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
       CheckAndGetBitmaskPtr(*next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
   const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
   const auto& subtree_range = tokenizer_info_.GetTrieSubtreeNodesRange();
-  const auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
+  auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
   // We need to have a copy, because scanable_state_history_ will be modified during the
   // FillNextTokenBitmask process, which can lead to undefined behavior.
   auto latest_states = GetLatestScanableStates();
@@ -529,8 +530,31 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
   std::vector<std::pair<ParserState, decltype(adaptive_token_mask_cache.cbegin())>>
       latest_states_with_masks;
 
+  auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
+    auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar_, state, false);
+    auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
+        tokenizer_info_.GetVocabSize(),
+        tokenizer_info_.GetSortedDecodedVocab(),
+        tokenizer_info_.GetTrieSubtreeNodesRange(),
+        is_root_rule
+    );
+    return adaptive_token_mask_cache.emplace(state, std::move(cur_adaptive_token_mask_cache)).first;
+  };
+
   for (const auto& state : latest_states) {
     auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
+    if (adaptive_token_mask_it == adaptive_token_mask_cache.end()) {
+      // Cache miss: the grammar was compiled in JIT mode, so compute the mask on demand.
+      bool is_root_rule = state.rule_id == grammar_->GetRootRuleId();
+      ParserState state_to_check = ParserState{
+          state.rule_id,
+          state.sequence_id,
+          state.element_id,
+          ParserState::kNoPrevInputPos,
+          state.sub_element_id
+      };
+      adaptive_token_mask_it = add_adaptive_token_mask(state_to_check, is_root_rule);
+    }
     XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
     const auto& adaptive_token_mask = adaptive_token_mask_it->second;
     latest_states_with_masks.push_back(std::make_pair(state, adaptive_token_mask_it));
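
The cache-miss branch added above is a compute-on-miss memoization: when the mask for a parser state is not found (because JIT compilation skipped the precomputation), it is built with GrammarMatcherForTokenMaskCache and inserted into the cache before use. A minimal Python sketch of the same control flow; the names `latest_states`, `mask_cache`, and `compute_mask` are illustrative stand-ins, not part of the xgrammar API:

    def fill_with_lazy_masks(latest_states, mask_cache, compute_mask):
        """Pair each state with its token mask, computing and caching it on a miss."""
        states_with_masks = []
        for state in latest_states:
            mask = mask_cache.get(state)
            if mask is None:
                # JIT path: the mask was not precomputed at compile time,
                # so build it now and remember it for subsequent calls.
                mask = compute_mask(state)
                mask_cache[state] = mask
            states_with_masks.append((state, mask))
        return states_with_masks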

cpp/grammar_matcher_for_cache.h

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+/*!
+ * Copyright (c) 2025 by Contributors
+ * \file xgrammar/grammar_matcher_for_cache.h
+ * \brief The header for the grammar matcher for the cache.
+ */
+
+#ifndef XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
+#define XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
+
+#include <bitset>
+
+#include "compiled_grammar_impl.h"
+#include "earley_parser.h"
+
+namespace xgrammar {
+/*! \brief The concrete implementation of GrammarMatcherNode. */
+class GrammarMatcherForTokenMaskCache : public EarleyParser {
+ public:
+  GrammarMatcherForTokenMaskCache(
+      const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
+  )
+      : EarleyParser(grammar, init_state),
+        init_rule_id(init_state.rule_id),
+        initial_state(init_state) {}
+  /*!
+   * \brief Get the adaptive token mask for the given ParserState.
+   * \param is_root_rule Whether to consider the parent rule. If false, there will be
+   * no uncertain tokens. Useful for the root rule.
+   */
+  AdaptiveTokenMask GetAdaptiveTokenMask(
+      size_t vocab_size,
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
+      const std::vector<int32_t>& subtree_nodes_range,
+      bool is_root_rule
+  );
+
+  /*!
+   * \brief Get the token mask for the given ParserState.
+   * \param sorted_decoded_vocab The sorted decoded vocabulary.
+   * \param first_char_mask The first character mask.
+   * \param is_root_rule Whether to consider the parent rule. If false, there will be
+   * no uncertain tokens. Useful for the root rule.
+   * \returns True if the rejected indices are filled as usual, False otherwise.
+   * It's used to determine which construction function will be used.
+   */
+  bool GetTokenMaskWithFirstCharacterCheck(
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
+      const std::bitset<256>& first_char_mask,
+      const std::vector<int>& subtree_nodes_range,
+      bool is_root_rule
+  );
+
+ private:
+  /*! \brief Check if a token can pass the lookahead assertion. */
+  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
+      const std::string& token, const std::vector<bool>& can_reach_end_stack
+  );
+
+  /*!
+   * \brief Check if speculative calculation will be applied.
+   * \return first: whether speculative calculation is applicable.
+   * \return second: part of the first character mask,
+   * which can be used in speculative calculation.
+   */
+  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
+  );
+
+  // The id of the initial rule.
+  int32_t init_rule_id;
+
+  // The initial state of the parser.
+  ParserState initial_state;
+
+  // Temporary data for GetAdaptiveTokenMask.
+  std::vector<int32_t> tmp_accepted_indices_;
+  std::vector<int32_t> tmp_rejected_indices_;
+  std::vector<int32_t> tmp_uncertain_indices_;
+  std::vector<bool> tmp_can_reach_end_stack_;
+  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
+};
+}  // namespace xgrammar
+
+#endif  // XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_

cpp/nanobind/nanobind.cc

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ NB_MODULE(xgrammar_bindings, m) {
       .def_static("deserialize_json", &CompiledGrammar_DeserializeJSON);
 
   auto pyGrammarCompiler = nb::class_<GrammarCompiler>(m, "GrammarCompiler");
-  pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, long long>())
+  pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, long long, bool>())
      .def(
          "compile_json_schema",
          &GrammarCompiler::CompileJSONSchema,

include/xgrammar/compiler.h

Lines changed: 2 additions & 1 deletion
@@ -68,7 +68,8 @@ class GrammarCompiler {
       const TokenizerInfo& tokenizer_info,
       int max_threads = 8,
       bool cache_enabled = true,
-      long long max_memory_bytes = -1  // unlimited
+      long long max_memory_bytes = -1,  // unlimited
+      bool is_jit = false
   );
 
   /*! \brief Get the compiled grammar for a JSON schema string. */

python/xgrammar/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -104,6 +104,7 @@ def __init__(
         max_threads: int = 8,
         cache_enabled: bool = True,
         cache_limit_bytes: int = -1,
+        is_jit: bool = False,
     ):
         """Construct the compiler.
 
@@ -130,7 +131,7 @@ def __init__(
 
         self._init_handle(
             _core.GrammarCompiler(
-                tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes
+                tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes, is_jit
             )
         )
 
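
For reference, the new flag is exposed directly on the Python constructor. A minimal usage sketch, using the empty `TokenizerInfo([])` that the testing helpers below also rely on (real callers would build one via `TokenizerInfo.from_huggingface`); with `is_jit=True`, `compile_grammar` skips precomputing the adaptive token mask cache, and the masks are computed lazily the first time the matcher needs them:

    import xgrammar as xgr

    # Minimal tokenizer info; real callers would use xgr.TokenizerInfo.from_huggingface(...).
    tokenizer_info = xgr.TokenizerInfo([])

    # JIT mode: compilation returns without precomputing token masks.
    compiler = xgr.GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=True)
    compiled = compiler.compile_grammar(xgr.Grammar.builtin_json_grammar())

    # Masks are computed on demand during matching.
    matcher = xgr.GrammarMatcher(compiled, terminate_without_stop_token=True)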

python/xgrammar/testing.py

Lines changed: 9 additions & 4 deletions
@@ -104,7 +104,9 @@ def _ebnf_to_grammar_no_normalization(ebnf_string: str, root_rule_name: str = "r
     )
 
 
-def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> GrammarMatcher:
+def _get_matcher_from_grammar(
+    grammar: Union[Grammar, str], is_jit: bool = False, **kwargs
+) -> GrammarMatcher:
     """Create a GrammarMatcher from a grammar. The tokenizer info will be set to an empty
     TokenizerInfo. The result matcher can only accept strings, and cannot accept tokens.
 
@@ -120,7 +122,7 @@ def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> Grammar
         The created grammar matcher.
     """
     tokenizer_info = TokenizerInfo([])
-    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
+    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
     compiled_grammar = grammar_compiler.compile_grammar(grammar)
     return GrammarMatcher(compiled_grammar, terminate_without_stop_token=True, **kwargs)
 
@@ -256,7 +258,10 @@ def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
 
 
 def _get_matcher_from_grammar_and_tokenizer_info(
-    grammar: Union[Grammar, str], tokenizer_info: Optional[TokenizerInfo] = None, **kwargs
+    grammar: Union[Grammar, str],
+    tokenizer_info: Optional[TokenizerInfo] = None,
+    is_jit: bool = False,
+    **kwargs,
 ) -> GrammarMatcher:
     """Create a GrammarMatcher from a grammar and tokenizer info.
 
@@ -278,7 +283,7 @@ def _get_matcher_from_grammar_and_tokenizer_info(
     """
     if tokenizer_info is None:
         tokenizer_info = TokenizerInfo([])
-    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
+    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
     compiled_grammar = grammar_compiler.compile_grammar(grammar)
     return GrammarMatcher(compiled_grammar, **kwargs)
 

tests/python/test_grammar_compiler.py

Lines changed: 11 additions & 0 deletions
@@ -329,5 +329,16 @@ def make_schema(name_str: str):
     assert grammar_compiler.get_cache_size_bytes() == 0
 
 
+@pytest.mark.hf_token_required
+def test_grammar_compiler_jit():
+    grammar = xgr.Grammar.builtin_json_grammar()
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+    compiler = xgr.GrammarCompiler(xgr.TokenizerInfo.from_huggingface(tokenizer), is_jit=True)
+    time_start = time.monotonic_ns()
+    _ = compiler.compile_grammar(grammar)
+    time_end = time.monotonic_ns()
+    print(f"JIT compilation time: {(time_end - time_start) / 1e6} ms")
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
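
To make the effect of the flag visible, a variant of the test above could time both modes side by side. This is only a sketch, assuming the same imports already used in this test module (`time`, `AutoTokenizer`, `xgr`) and the same HF-token setup; the helper name `_compile_once` is illustrative:

    def _compile_once(is_jit: bool) -> float:
        """Return the wall-clock time in milliseconds to compile the builtin JSON grammar."""
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
        compiler = xgr.GrammarCompiler(tokenizer_info, is_jit=is_jit)
        start = time.monotonic_ns()
        compiler.compile_grammar(xgr.Grammar.builtin_json_grammar())
        return (time.monotonic_ns() - start) / 1e6

    print(f"eager: {_compile_once(False):.1f} ms, jit: {_compile_once(True):.1f} ms")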
