1919#include " fsm.h"
2020#include " grammar_functor.h"
2121#include " grammar_impl.h"
22+ #include " grammar_matcher_for_cache.h"
2223#include " support/logging.h"
2324#include " support/thread_pool.h"
2425#include " support/thread_safe_cache.h"
@@ -44,73 +45,6 @@ namespace xgrammar {
4445
4546/* ************* Use GrammarMatcher to generate the AdaptiveTokenMaskCache **************/
4647
47- /* ! \brief The concrete implementation of GrammarMatcherNode. */
48- class GrammarMatcherForTokenMaskCache : public EarleyParser {
49- public:
50- GrammarMatcherForTokenMaskCache (
51- const Grammar& grammar, const ParserState& init_state, const bool & need_expand = true
52- )
53- : EarleyParser(grammar, init_state),
54- init_rule_id (init_state.rule_id),
55- initial_state(init_state) {}
56- /* !
57- * \brief Get the adaptive token mask for the given ParserState.
58- * \param is_root_rule Whether to consider the parent rule. If false, there will be
59- * no uncertain tokens. Useful for the root rule.
60- */
61- AdaptiveTokenMask GetAdaptiveTokenMask (
62- size_t vocab_size,
63- const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab,
64- const std::vector<int32_t >& subtree_nodes_range,
65- bool is_root_rule
66- );
67-
68- /* !
69- * \brief Get the token mask for the given ParserState.
70- * \param sorted_decoded_vocab The sorted decoded vocabulary.
71- * \param first_char_mask The first character mask.
72- * \param is_root_rule Whether to consider the parent rule. If false, there will be
73- * no uncertain tokens. Useful for the root rule.
74- * \returns True if the rejected indices are filled as usual, False otherwise.
75- * It's used to determine which construction function will be used.
76- */
77- bool GetTokenMaskWithFirstCharacterCheck (
78- const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab,
79- const std::bitset<256 >& first_char_mask,
80- const std::vector<int >& subtree_nodes_range,
81- bool is_root_rule
82- );
83-
84- private:
85- /* ! \brief Check if a token can pass the lookahead assertion. */
86- std::pair</* acceptable*/ bool , /* can reach end*/ bool > IsTokenPassLookaheadAssertion (
87- const std::string& token, const std::vector<bool >& can_reach_end_stack
88- );
89-
90- /* !
91- * \brief Check if speculative calculation will be applied.
92- * \return first: whether speculative calculation is applicable.
93- * \return second: part of the first character mask,
94- * which can be used in speculative calculation.
95- */
96- std::pair<bool , std::bitset<256 >> GetSpeculativeCalculation (
97- const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab
98- );
99-
100- // The id of the initial rule.
101- int32_t init_rule_id;
102-
103- // The initial state of the parser.
104- ParserState initial_state;
105-
106- // Temporary data for GetAdaptiveTokenMask.
107- std::vector<int32_t > tmp_accepted_indices_;
108- std::vector<int32_t > tmp_rejected_indices_;
109- std::vector<int32_t > tmp_uncertain_indices_;
110- std::vector<bool > tmp_can_reach_end_stack_;
111- std::vector<bool > tmp_can_reach_end_prefix_or_stack_;
112- };
113-
11448std::pair<bool , bool > GrammarMatcherForTokenMaskCache::IsTokenPassLookaheadAssertion (
11549 const std::string& token, const std::vector<bool >& can_reach_end_stack
11650) {
@@ -565,11 +499,13 @@ class GrammarCompiler::Impl {
565499 const TokenizerInfo& tokenizer_info,
566500 int max_threads,
567501 bool cache_enabled,
568- long long max_memory_bytes
502+ long long max_memory_bytes,
503+ bool is_jit = false
569504 )
570505 : tokenizer_info_(tokenizer_info),
571506 max_threads_ (max_threads),
572507 cache_enabled_(cache_enabled),
508+ is_jit_(is_jit),
573509 compile_builtin_json_grammar_cache_([&] { return CompileJson (); }),
574510 compile_cache_(static_cast <std::size_t >(max_memory_bytes), *this) {}
575511
@@ -642,6 +578,8 @@ class GrammarCompiler::Impl {
642578 const int max_threads_;
643579 /* ! \brief Whether the cache is enabled. */
644580 const bool cache_enabled_;
581+ /* ! \brief Whether the jit mode is enabled.*/
582+ const bool is_jit_;
645583
646584 ThreadSafeCache<CompiledGrammar> compile_builtin_json_grammar_cache_;
647585 ThreadSafeLRUCache<MultipleKey, CompiledGrammar, Computer, SizeEstimator> compile_cache_;
@@ -658,6 +596,9 @@ CompiledGrammar GrammarCompiler::Impl::MultiThreadCompileGrammar(Grammar grammar
658596 if (tokenizer_info_.GetVocabSize () == 0 ) {
659597 return CompiledGrammar (compiled_grammar_impl);
660598 }
599+ if (is_jit_) {
600+ return CompiledGrammar (compiled_grammar_impl);
601+ }
661602 // Step 3. Compute the adaptive token mask cache
662603 // The token mask cache is computed for these positions in the grammar:
663604 // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -862,9 +803,12 @@ GrammarCompiler::GrammarCompiler(
862803 const TokenizerInfo& tokenizer_info,
863804 int max_threads,
864805 bool cache_enabled,
865- long long max_memory_bytes
806+ long long max_memory_bytes,
807+ bool is_jit
866808)
867- : pimpl_(std::make_shared<Impl>(tokenizer_info, max_threads, cache_enabled, max_memory_bytes)) {
809+ : pimpl_(std::make_shared<Impl>(
810+ tokenizer_info, max_threads, cache_enabled, max_memory_bytes, is_jit
811+ )) {
868812 if (max_memory_bytes < -1 ) {
869813 XGRAMMAR_LOG (FATAL) << " Invalid max_memory_bytes: " << max_memory_bytes << " . "
870814 << " It should be -1 (unlimited) or a non-negative integer." ;
0 commit comments