
Commit 85184b6

feat: support jit.
Signed-off-by: Yuchuan <[email protected]>
Parent: 3e78a65

13 files changed: 382 additions, 82 deletions
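In short, the commit threads a new `is_jit` flag from the Python `GrammarCompiler` down to the C++ compiler: when it is set, compilation stops before building the adaptive token mask cache, and the per-state token masks are computed lazily the first time `FillNextTokenBitmask` needs them. A minimal usage sketch from Python (an illustration, not part of the diff; the tokenizer name is only an example):

# Sketch only: assumes a build of xgrammar that includes this commit.
from transformers import AutoTokenizer
import xgrammar as xgr

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)

# is_jit=True skips the eager token-mask cache, so compile_grammar returns quickly;
# the masks are filled in on demand when the matcher is used.
compiler = xgr.GrammarCompiler(tokenizer_info, is_jit=True)
compiled = compiler.compile_grammar(xgr.Grammar.builtin_json_grammar())

matcher = xgr.GrammarMatcher(compiled)
bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
matcher.fill_next_token_bitmask(bitmask)  # first call pays the deferred mask cost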

cpp/grammar_compiler.cc

Lines changed: 21 additions & 74 deletions
@@ -19,81 +19,15 @@
 #include "fsm.h"
 #include "grammar_functor.h"
 #include "grammar_impl.h"
+#include "grammar_matcher_for_cache.h"
 #include "support/logging.h"
 #include "support/thread_pool.h"
 #include "support/thread_safe_cache.h"
 #include "support/utils.h"
 
 namespace xgrammar {
 
-/************** AdaptiveTokenMaskCache Generator **************/
-
-/*! \brief The concrete implementation of GrammarMatcherNode. */
-class GrammarMatcherForTokenMaskCache : public EarleyParser {
- public:
-  GrammarMatcherForTokenMaskCache(
-      const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
-  )
-      : EarleyParser(grammar, init_state),
-        init_rule_id(init_state.rule_id),
-        initial_state(init_state) {}
-  /*!
-   * \brief Get the adaptive token mask for the given ParserState.
-   * \param is_root_rule Whether to consider the parent rule. If false, there will be
-   * no uncertain tokens. Useful for the root rule.
-   */
-  AdaptiveTokenMask GetAdaptiveTokenMask(
-      size_t vocab_size,
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
-      const std::vector<int32_t>& subtree_nodes_range,
-      bool is_root_rule
-  );
-
-  /*!
-   * \brief Get the token mask for the given ParserState.
-   * \param sorted_decoded_vocab The sorted decoded vocabulary.
-   * \param first_char_mask The first character mask.
-   * \param is_root_rule Whether to consider the parent rule. If false, there will be
-   * no uncertain tokens. Useful for the root rule.
-   * \returns True if the rejected indices are filled as usual, False otherwise.
-   * It's used to determine which construction function will be used.
-   */
-  bool GetTokenMaskWithFirstCharacterCheck(
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
-      const std::bitset<256>& first_char_mask,
-      const std::vector<int>& subtree_nodes_range,
-      bool is_root_rule
-  );
-
- private:
-  /*! \brief Check if a token can pass the lookahead assertion. */
-  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
-      const std::string& token, const std::vector<bool>& can_reach_end_stack
-  );
-
-  /*!
-   * \brief Check if speculative calculation will be applied.
-   * \return first: whether speculative calculation is applicable.
-   * \return second: part of the first character mask,
-   * which can be used in speculative calculation.
-   */
-  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
-      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
-  );
-
-  // The id of the initial rule.
-  int32_t init_rule_id;
-
-  // The initial state of the parser.
-  ParserState initial_state;
-
-  // Temporary data for GetAdaptiveTokenMask.
-  std::vector<int32_t> tmp_accepted_indices_;
-  std::vector<int32_t> tmp_rejected_indices_;
-  std::vector<int32_t> tmp_uncertain_indices_;
-  std::vector<bool> tmp_can_reach_end_stack_;
-  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
-};
+/************** Use GrammarMatcher to generate the AdaptiveTokenMaskCache **************/
 
 std::pair<bool, bool> GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion(
     const std::string& token, const std::vector<bool>& can_reach_end_stack
@@ -543,8 +477,8 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
  */
 class GrammarCompilerNoCache {
  public:
-  GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads)
-      : tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
+  GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads, bool is_jit)
+      : tokenizer_info_(tokenizer_info), max_threads_(max_threads), is_jit_(is_jit) {}
 
   CompiledGrammar CompileBuiltinJSONGrammar();
 
@@ -573,6 +507,8 @@ class GrammarCompilerNoCache {
   const TokenizerInfo tokenizer_info_;
   /*! \brief The maximum number of threads to use. */
   const int max_threads_;
+  /*! \brief Whether the jit mode is enabled. */
+  const bool is_jit_;
 };
 
 CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
@@ -588,6 +524,9 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
   if (tokenizer_info_.GetVocabSize() == 0) {
     return CompiledGrammar(compiled_grammar_impl);
   }
+  if (is_jit_) {
+    return CompiledGrammar(compiled_grammar_impl);
+  }
   // Step 3. Compute the adaptive token mask cache
   // The token mask cache is computed for these positions in the grammar:
   // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -827,9 +766,10 @@ class GrammarCompiler::Impl {
       const TokenizerInfo& tokenizer_info,
       int max_threads,
       bool cache_enabled,
-      int64_t max_memory_bytes
+      int64_t max_memory_bytes,
+      bool is_jit
   )
-      : no_cache_compiler_(tokenizer_info, max_threads),
+      : no_cache_compiler_(tokenizer_info, max_threads, is_jit),
        cache_enabled_(cache_enabled),
        compile_cache_(static_cast<std::size_t>(max_memory_bytes), Computer(*this)) {
    if (max_memory_bytes < -1) {
@@ -997,9 +937,16 @@ GrammarCompiler::GrammarCompiler(
     const TokenizerInfo& tokenizer_info,
     int max_threads,
     bool cache_enabled,
-    int64_t max_memory_bytes
+    int64_t max_memory_bytes,
+    bool is_jit
 )
-    : pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads, cache_enabled, max_memory_bytes)) {
+    : pimpl_(std::make_shared<Impl>(
+          tokenizer_info, max_threads, cache_enabled, max_memory_bytes, is_jit
+      )) {
+  if (max_memory_bytes < -1) {
+    XGRAMMAR_LOG(FATAL) << "Invalid max_memory_bytes: " << max_memory_bytes << ". "
+                        << "It should be -1 (unlimited) or a non-negative integer.";
+  }
 }
 
 CompiledGrammar GrammarCompiler::CompileJSONSchema(

cpp/grammar_matcher.cc

Lines changed: 25 additions & 1 deletion
@@ -15,6 +15,7 @@
 #include "compiled_grammar_impl.h"
 #include "earley_parser.h"
 #include "grammar_impl.h"
+#include "grammar_matcher_for_cache.h"
 #include "support/dynamic_bitset.h"
 #include "support/encoding.h"
 #include "support/int_set.h"
@@ -509,7 +510,7 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
       CheckAndGetBitmaskPtr(*next_token_bitmask, tokenizer_info_.GetVocabSize(), index);
   const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
   const auto& subtree_range = tokenizer_info_.GetTrieSubtreeNodesRange();
-  const auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
+  auto& adaptive_token_mask_cache = compiled_grammar_->adaptive_token_mask_cache;
   // We need to have a copy, because scanable_state_history_ will be modified during the
   // FillNextTokenBitmask process, which can lead to undefined behavior.
   auto latest_states = GetLatestScanableStates();
@@ -532,8 +533,31 @@ bool GrammarMatcher::Impl::FillNextTokenBitmask(
   std::vector<std::pair<ParserState, decltype(adaptive_token_mask_cache.cbegin())>>
       latest_states_with_masks;
 
+  auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
+    auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar_, state, false);
+    auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
+        tokenizer_info_.GetVocabSize(),
+        tokenizer_info_.GetSortedDecodedVocab(),
+        tokenizer_info_.GetTrieSubtreeNodesRange(),
+        is_root_rule
+    );
+    return adaptive_token_mask_cache.emplace(state, std::move(cur_adaptive_token_mask_cache)).first;
+  };
+
   for (const auto& state : latest_states) {
     auto adaptive_token_mask_it = adaptive_token_mask_cache.find(state);
+    if (adaptive_token_mask_it == adaptive_token_mask_cache.end()) {
+      // A cache miss means the grammar was compiled in jit mode; compute the mask on demand.
+      bool is_root_rule = state.rule_id == grammar_->GetRootRuleId();
+      ParserState state_to_check = ParserState{
+          state.rule_id,
+          state.sequence_id,
+          state.element_id,
+          ParserState::kNoPrevInputPos,
+          state.sub_element_id
+      };
+      adaptive_token_mask_it = add_adaptive_token_mask(state_to_check, is_root_rule);
+    }
     XGRAMMAR_CHECK(adaptive_token_mask_it != adaptive_token_mask_cache.end()) << state;
     const auto& adaptive_token_mask = adaptive_token_mask_it->second;
     latest_states_with_masks.push_back(std::make_pair(state, adaptive_token_mask_it));
cpp/grammar_matcher_for_cache.h

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+/*!
+ * Copyright (c) 2025 by Contributors
+ * \file xgrammar/grammar_matcher_for_cache.h
+ * \brief The header for the grammar matcher for the cache.
+ */
+
+#ifndef XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
+#define XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_
+
+#include <bitset>
+
+#include "compiled_grammar_impl.h"
+#include "earley_parser.h"
+
+namespace xgrammar {
+/*! \brief The concrete implementation of GrammarMatcherNode. */
+class GrammarMatcherForTokenMaskCache : public EarleyParser {
+ public:
+  GrammarMatcherForTokenMaskCache(
+      const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
+  )
+      : EarleyParser(grammar, init_state),
+        init_rule_id(init_state.rule_id),
+        initial_state(init_state) {}
+  /*!
+   * \brief Get the adaptive token mask for the given ParserState.
+   * \param is_root_rule Whether to consider the parent rule. If false, there will be
+   * no uncertain tokens. Useful for the root rule.
+   */
+  AdaptiveTokenMask GetAdaptiveTokenMask(
+      size_t vocab_size,
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
+      const std::vector<int32_t>& subtree_nodes_range,
+      bool is_root_rule
+  );
+
+  /*!
+   * \brief Get the token mask for the given ParserState.
+   * \param sorted_decoded_vocab The sorted decoded vocabulary.
+   * \param first_char_mask The first character mask.
+   * \param is_root_rule Whether to consider the parent rule. If false, there will be
+   * no uncertain tokens. Useful for the root rule.
+   * \returns True if the rejected indices are filled as usual, False otherwise.
+   * It's used to determine which construction function will be used.
+   */
+  bool GetTokenMaskWithFirstCharacterCheck(
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
+      const std::bitset<256>& first_char_mask,
+      const std::vector<int>& subtree_nodes_range,
+      bool is_root_rule
+  );
+
+ private:
+  /*! \brief Check if a token can pass the lookahead assertion. */
+  std::pair</*acceptable*/ bool, /*can reach end*/ bool> IsTokenPassLookaheadAssertion(
+      const std::string& token, const std::vector<bool>& can_reach_end_stack
+  );
+
+  /*!
+   * \brief Check if speculative calculation will be applied.
+   * \return first: whether speculative calculation is applicable.
+   * \return second: part of the first character mask,
+   * which can be used in speculative calculation.
+   */
+  std::pair<bool, std::bitset<256>> GetSpeculativeCalculation(
+      const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
+  );
+
+  // The id of the initial rule.
+  int32_t init_rule_id;
+
+  // The initial state of the parser.
+  ParserState initial_state;
+
+  // Temporary data for GetAdaptiveTokenMask.
+  std::vector<int32_t> tmp_accepted_indices_;
+  std::vector<int32_t> tmp_rejected_indices_;
+  std::vector<int32_t> tmp_uncertain_indices_;
+  std::vector<bool> tmp_can_reach_end_stack_;
+  std::vector<bool> tmp_can_reach_end_prefix_or_stack_;
+};
+} // namespace xgrammar
+
+#endif  // XGRAMMAR_GRAMMAR_MATCHER_FOR_CACHE_H_

cpp/nanobind/nanobind.cc

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ NB_MODULE(xgrammar_bindings, m) {
       .def_static("deserialize_json", &CompiledGrammar_DeserializeJSON);
 
   auto pyGrammarCompiler = nb::class_<GrammarCompiler>(m, "GrammarCompiler");
-  pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, int64_t>())
+  pyGrammarCompiler.def(nb::init<const TokenizerInfo&, int, bool, int64_t, bool>())
      .def(
          "compile_json_schema",
          &GrammarCompiler::CompileJSONSchema,

include/xgrammar/compiler.h

Lines changed: 2 additions & 1 deletion
@@ -68,7 +68,8 @@ class GrammarCompiler {
       const TokenizerInfo& tokenizer_info,
       int max_threads = 8,
       bool cache_enabled = true,
-      int64_t max_memory_bytes = -1  // unlimited
+      int64_t max_memory_bytes = -1,  // unlimited
+      bool is_jit = false
   );
 
   /*! \brief Get the compiled grammar for a JSON schema string. */

python/xgrammar/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -111,6 +111,7 @@ def __init__(
         max_threads: int = 8,
         cache_enabled: bool = True,
         cache_limit_bytes: int = -1,
+        is_jit: bool = False,
     ):
         """Construct the compiler.
 
@@ -137,7 +138,7 @@ def __init__(
 
         self._init_handle(
             _core.GrammarCompiler(
-                tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes
+                tokenizer_info._handle, max_threads, cache_enabled, cache_limit_bytes, is_jit
            )
        )
 
python/xgrammar/testing.py

Lines changed: 9 additions & 4 deletions
@@ -111,7 +111,9 @@ def _ebnf_to_grammar_no_normalization(ebnf_string: str, root_rule_name: str = "r
     )
 
 
-def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> GrammarMatcher:
+def _get_matcher_from_grammar(
+    grammar: Union[Grammar, str], is_jit: bool = False, **kwargs
+) -> GrammarMatcher:
     """Create a GrammarMatcher from a grammar. The tokenizer info will be set to an empty
     TokenizerInfo. The result matcher can only accept strings, and cannot accept tokens.
 
@@ -127,7 +129,7 @@ def _get_matcher_from_grammar(grammar: Union[Grammar, str], **kwargs) -> GrammarMatcher:
         The created grammar matcher.
     """
     tokenizer_info = TokenizerInfo([])
-    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
+    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
     compiled_grammar = grammar_compiler.compile_grammar(grammar)
     return GrammarMatcher(compiled_grammar, terminate_without_stop_token=True, **kwargs)
 
@@ -263,7 +265,7 @@ def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
 
 
 def _get_matcher_from_grammar_and_tokenizer_info(
-    grammar: Union[Grammar, str], tokenizer_info: Optional[TokenizerInfo] = None, **kwargs
+    grammar: Union[Grammar, str],
+    tokenizer_info: Optional[TokenizerInfo] = None,
+    is_jit: bool = False,
+    **kwargs,
 ) -> GrammarMatcher:
     """Create a GrammarMatcher from a grammar and tokenizer info.
 
@@ -285,7 +290,7 @@ def _get_matcher_from_grammar_and_tokenizer_info(
     """
     if tokenizer_info is None:
         tokenizer_info = TokenizerInfo([])
-    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False)
+    grammar_compiler = GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
     compiled_grammar = grammar_compiler.compile_grammar(grammar)
     return GrammarMatcher(compiled_grammar, **kwargs)
tests/python/test_grammar_compiler.py

Lines changed: 11 additions & 0 deletions
@@ -329,5 +329,16 @@ def make_schema(name_str: str):
     assert grammar_compiler.get_cache_size_bytes() == 0
 
 
+@pytest.mark.hf_token_required
+def test_grammar_compiler_jit():
+    grammar = xgr.Grammar.builtin_json_grammar()
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+    compiler = xgr.GrammarCompiler(xgr.TokenizerInfo.from_huggingface(tokenizer), is_jit=True)
+    time_start = time.monotonic_ns()
+    _ = compiler.compile_grammar(grammar)
+    time_end = time.monotonic_ns()
+    print(f"JIT compilation time: {(time_end - time_start) / 1e6} ms")
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
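The new test only times the JIT compile. A natural extension, sketched below and not part of this commit, would time both modes and check that the first token bitmask is identical; it reuses the same imports as the test file and assumes `torch` is available:

# Sketch of a possible follow-up test (not in this commit).
import time
import torch

def _compile_and_first_mask(tokenizer_info, grammar, is_jit):
    compiler = xgr.GrammarCompiler(tokenizer_info, cache_enabled=False, is_jit=is_jit)
    start = time.monotonic_ns()
    compiled = compiler.compile_grammar(grammar)
    elapsed_ms = (time.monotonic_ns() - start) / 1e6
    matcher = xgr.GrammarMatcher(compiled)
    bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
    matcher.fill_next_token_bitmask(bitmask)
    return elapsed_ms, bitmask

def test_grammar_compiler_jit_matches_eager():
    grammar = xgr.Grammar.builtin_json_grammar()
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    eager_ms, eager_mask = _compile_and_first_mask(tokenizer_info, grammar, is_jit=False)
    jit_ms, jit_mask = _compile_and_first_mask(tokenizer_info, grammar, is_jit=True)
    print(f"eager compile: {eager_ms:.1f} ms, jit compile: {jit_ms:.1f} ms")
    assert torch.equal(eager_mask, jit_mask)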
