Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 69 additions & 47 deletions cpp/grammar_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cstddef>
#include <cstdint>
#include <optional>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
Expand All @@ -30,6 +31,12 @@

namespace xgrammar {

/*!
 * \brief A stateless placeholder type that can be constructed from any argument list.
 *
 * All constructor arguments are accepted and discarded; the object stores nothing.
 */
struct EmptyHolder {
  /*! \brief Construct the placeholder without arguments. */
  constexpr EmptyHolder() noexcept = default;

  /*!
   * \brief Construct the placeholder from arbitrary arguments, all of which are ignored.
   * \note The arguments are never read, moved from, or stored, so construction cannot throw
   * and is usable in constant expressions. The constructor stays explicit so that no implicit
   * conversion to EmptyHolder is introduced.
   */
  template <typename... Args>
  explicit constexpr EmptyHolder(Args&&...) noexcept {}
};

/************** AdaptiveTokenMaskCache Generator **************/

/*! \brief The concrete implementation of GrammarMatcherNode. */
Expand Down Expand Up @@ -57,7 +64,7 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
const std::vector<int32_t>& subtree_nodes_range,
bool is_root_rule
);
) &&;

/*!
* \brief Get the token mask for the given ParserState.
Expand Down Expand Up @@ -531,7 +538,7 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab,
const std::vector<int32_t>& subtree_nodes_range,
bool is_root_rule
) {
) && {
tmp_accepted_indices_.clear();
tmp_rejected_indices_.clear();
tmp_uncertain_indices_.clear();
Expand Down Expand Up @@ -590,17 +597,20 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
sorted_decoded_vocab, first_character_mask, subtree_nodes_range, is_root_rule
);
if (rejected_indices_are_filled) {
return AdaptiveTokenMask(
return AdaptiveTokenMask{
vocab_size,
sorted_decoded_vocab,
tmp_accepted_indices_,
tmp_rejected_indices_,
tmp_uncertain_indices_
);
std::move(tmp_accepted_indices_),
std::move(tmp_rejected_indices_),
std::move(tmp_uncertain_indices_),
};
} else {
return AdaptiveTokenMask(
vocab_size, sorted_decoded_vocab, tmp_accepted_indices_, tmp_uncertain_indices_
);
return AdaptiveTokenMask{
vocab_size,
sorted_decoded_vocab,
std::move(tmp_accepted_indices_),
std::move(tmp_uncertain_indices_),
};
}
}

Expand All @@ -611,8 +621,10 @@ AdaptiveTokenMask GrammarMatcherForTokenMaskCache::GetAdaptiveTokenMask(
*/
class GrammarCompilerNoCache {
public:
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, int max_threads)
: tokenizer_info_(tokenizer_info), max_threads_(max_threads) {}
GrammarCompilerNoCache(const TokenizerInfo& tokenizer_info, const int max_threads)
    : tokenizer_info_(tokenizer_info) {
  // A thread pool is only constructed when more than one thread is requested;
  // otherwise thread_pool_ is left empty.
  const bool wants_thread_pool = max_threads > 1;
  if (wants_thread_pool) {
    thread_pool_.emplace(max_threads);
  }
}

CompiledGrammar CompileBuiltinJSONGrammar();

Expand All @@ -635,17 +647,28 @@ class GrammarCompilerNoCache {

private:
/*! \brief The main logic. Compile the grammar with multi-threading. */
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar);
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar) {
  // Pick the template instantiation from the runtime state of thread_pool_:
  // no pool -> <false> variant, pool present -> <true> variant.
  if (!thread_pool_.has_value()) {
    return MultiThreadCompileGrammarImpl<false>(std::move(grammar));
  }
  return MultiThreadCompileGrammarImpl<true>(std::move(grammar));
}

template <bool kUseMultiThread>
CompiledGrammar MultiThreadCompileGrammarImpl(Grammar grammar);

/*! \brief The vocabulary associated with this storage class. */
const TokenizerInfo tokenizer_info_;
/*! \brief The maximum number of threads to use. */
const int max_threads_;
std::optional<ThreadPool> thread_pool_;
/*! \brief Mapping from the rule_id to the definite accepted token mask. */
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
};

CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) {
template <bool kUseMultiThread>
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammarImpl(Grammar grammar_unoptimized) {
using GrammarExprType = Grammar::Impl::GrammarExprType;

auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
Expand Down Expand Up @@ -702,47 +725,46 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
// 2. All byte strings (with element_in_string=0, 1, 2, ...)
// since other positions will be expanded to the above positions
[[maybe_unused]] auto task_counter = [&] {
if constexpr (kUseMultiThread) {
return thread_pool_->CreateTaskCounter();
} else {
return 0;
}
}();

// TODO(Charlie): Figure out how to support ThreadPool and std::mutex in WebAssembly.
// Only declare ThreadPool and mutex if max_threads > 1, so when max_threads = 1, we do
// not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
std::optional<ThreadPool> thread_pool;
std::optional<std::mutex> adaptive_token_mask_cache_mutex;
if (max_threads_ > 1) {
thread_pool.emplace(max_threads_);
adaptive_token_mask_cache_mutex.emplace();
}

auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
const auto get_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
auto grammar_matcher = GrammarMatcherForTokenMaskCache(
compiled_grammar_impl->grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
);
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
tokenizer_info_.GetVocabSize(),
tokenizer_info_.GetSortedDecodedVocab(),
tokenizer_info_.GetTrieSubtreeNodesRange(),
is_root_rule
);
if (max_threads_ > 1) {
std::lock_guard<std::mutex> lock(adaptive_token_mask_cache_mutex.value());
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
} else {
compiled_grammar_impl->adaptive_token_mask_cache[state] = cur_adaptive_token_mask_cache;
}
return std::move(grammar_matcher)
.GetAdaptiveTokenMask(
tokenizer_info_.GetVocabSize(),
tokenizer_info_.GetSortedDecodedVocab(),
tokenizer_info_.GetTrieSubtreeNodesRange(),
is_root_rule
);
};

auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
const auto add_task_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
// Execute depending on whether we use thread_pool
if (max_threads_ > 1) {
thread_pool->Execute([add_adaptive_token_mask, state, is_root_rule]() {
add_adaptive_token_mask(state, is_root_rule);
});
if constexpr (kUseMultiThread) {
task_counter.Submit(
// parallel part: construct the mask
[=] { return get_adaptive_token_mask(state, is_root_rule); },
// protected region, insert into the cache
[=](AdaptiveTokenMask&& cache) {
compiled_grammar_impl->adaptive_token_mask_cache.try_emplace(state, std::move(cache));
}
);
} else {
add_adaptive_token_mask(state, is_root_rule);
compiled_grammar_impl->adaptive_token_mask_cache.try_emplace(
state, get_adaptive_token_mask(state, is_root_rule)
);
}
};

auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();
const auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();

for (int32_t rule_id = 0; rule_id < static_cast<int>(compiled_grammar_impl->grammar->NumRules());
++rule_id) {
Expand Down Expand Up @@ -796,8 +818,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
}
}

if (max_threads_ > 1) {
thread_pool->Join();
if constexpr (kUseMultiThread) {
task_counter.WaitUntilComplete();
}

return CompiledGrammar(compiled_grammar_impl);
Expand Down
40 changes: 19 additions & 21 deletions cpp/grammar_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ class BatchGrammarMatcher::Impl {
bool debug_print
);

/*!
 * \brief Create the thread pool if it does not exist yet and more than one thread is configured.
 *
 * Does nothing when thread_pool_ already holds a value or when max_threads_ is 1 or less.
 */
void InitThreadPoolOnce() {
  // Already created: keep the existing pool.
  if (thread_pool_.has_value()) {
    return;
  }
  // No pool is created unless more than one thread is configured.
  if (max_threads_ <= 1) {
    return;
  }
  thread_pool_.emplace(max_threads_);
}

private:
std::optional<ThreadPool> thread_pool_ = std::nullopt;
int32_t max_threads_ = 1;
Expand Down Expand Up @@ -912,33 +918,25 @@ void BatchGrammarMatcher::Impl::BatchFillNextTokenBitmask(
XGRAMMAR_CHECK(!indices.has_value() || indices->size() == matchers->size())
<< "The size of indices (" << (indices.has_value() ? indices->size() : 0)
<< ") should be the same as the size of matchers (" << matchers->size() << ").";
// Initialize the thread pool if needed. It should be initialized each time,
// because ThreadPool cannot be reused after Join().
if (max_threads_ > 1) {
thread_pool_.emplace(max_threads_);
}
this->InitThreadPoolOnce();
auto fill_next_token_mask = [&](int i) {
auto& matcher = (*matchers)[i];
int index = indices.has_value() ? (*indices)[i] : i;
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
<< ") for batch_id " << i << ".";
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
};
if (!thread_pool_.has_value()) {
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
auto& matcher = (*matchers)[i];
int index = indices.has_value() ? (*indices)[i] : i;
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
<< ") for batch_id " << i << ".";
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
fill_next_token_mask(i);
}
} else {
auto fill_next_token_mask = [&](int32_t batch_id) {
auto& matcher = (*matchers)[batch_id];
int index = indices.has_value() ? (*indices)[batch_id] : batch_id;
XGRAMMAR_CHECK(index >= 0 && index < next_token_bitmask->shape[0])
<< "The index " << index << " is out of range [0, " << next_token_bitmask->shape[0]
<< ") for batch_id " << batch_id << ".";
matcher->FillNextTokenBitmask(next_token_bitmask, index, debug_print);
};
auto task_counter = thread_pool_->CreateTaskCounter();
for (int i = 0; i < static_cast<int32_t>(matchers->size()); i++) {
thread_pool_->Execute([fill_next_token_mask, i]() { fill_next_token_mask(i); });
task_counter.Submit([fill_next_token_mask, i] { fill_next_token_mask(i); });
}
thread_pool_->Join();
task_counter.WaitUntilComplete();
}
}

Expand Down
Loading
Loading