Skip to content

Commit 59fa487

Browse files
committed
refactor.
Signed-off-by: Yuchuan <[email protected]>
1 parent bfc7400 commit 59fa487

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

cpp/grammar_compiler.cc

Lines changed: 21 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -629,18 +629,22 @@ class GrammarCompilerNoCache {
629629
/*! \brief The main logic. Compile the grammar with multi-threading. */
630630
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar);
631631
/*! \brief Optimization for TagDispatch: Precompute the definitely accepted tokens. */
632-
void TagDispatchOptimization(decltype(std::make_shared<CompiledGrammar::Impl>()
633-
) compiled_grammar_impl);
632+
void TagDispatchOptimization(
633+
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl,
634+
std::unordered_map<int32_t, DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
635+
);
634636
/*! \brief Generate the token mask cache for all scannable states. */
635-
void GenerateTokenMaskCacheForScannableStates(decltype(std::make_shared<CompiledGrammar::Impl>()
636-
) compiled_grammar_impl);
637+
void GenerateTokenMaskCacheForScannableStates(
638+
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl,
639+
const std::unordered_map<int32_t, DynamicBitset>&
640+
tag_dispatch_rule_id_to_second_slicing_bitset
641+
);
637642

638643
/*! \brief The vocabulary associated with this storage class. */
639644
const TokenizerInfo tokenizer_info_;
640645
/*! \brief The maximum number of threads to use. */
641646
const int max_threads_;
642647
/*! \brief Mapping from the rule_id to the definite accepted token mask. */
643-
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
644648
};
645649

646650
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) {
@@ -650,16 +654,20 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
650654
if (tokenizer_info_.GetVocabSize() == 0) {
651655
return CompiledGrammar(compiled_grammar_impl);
652656
}
653-
TagDispatchOptimization(compiled_grammar_impl);
654-
GenerateTokenMaskCacheForScannableStates(compiled_grammar_impl);
657+
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
658+
TagDispatchOptimization(compiled_grammar_impl, &tag_dispatch_rule_id_to_second_slicing_bitset);
659+
GenerateTokenMaskCacheForScannableStates(
660+
compiled_grammar_impl, tag_dispatch_rule_id_to_second_slicing_bitset
661+
);
655662
return CompiledGrammar(compiled_grammar_impl);
656663
}
657664

658665
void GrammarCompilerNoCache::TagDispatchOptimization(
659-
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl
666+
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl,
667+
std::unordered_map<int32_t, DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
660668
) {
661669
using GrammarExprType = Grammar::Impl::GrammarExprType;
662-
tag_dispatch_rule_id_to_second_slicing_bitset.clear();
670+
tag_dispatch_rule_id_to_second_slicing_bitset->clear();
663671

664672
// Optimization for TagDispatch: Precompute the definitely accepted tokens.
665673
for (int i = 0; i < compiled_grammar_impl->grammar->NumRules(); i++) {
@@ -700,12 +708,14 @@ void GrammarCompilerNoCache::TagDispatchOptimization(
700708
definite_accepted_tokens_since_second_char.Set(i);
701709
}
702710
}
703-
tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char;
711+
(*tag_dispatch_rule_id_to_second_slicing_bitset)[i] =
712+
definite_accepted_tokens_since_second_char;
704713
}
705714
}
706715

707716
void GrammarCompilerNoCache::GenerateTokenMaskCacheForScannableStates(
708-
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl
717+
decltype(std::make_shared<CompiledGrammar::Impl>()) compiled_grammar_impl,
718+
const std::unordered_map<int32_t, DynamicBitset>& tag_dispatch_rule_id_to_second_slicing_bitset
709719
) {
710720
using GrammarExprType = Grammar::Impl::GrammarExprType;
711721

0 commit comments

Comments
 (0)