Skip to content

Commit ba63066

Browse files
[Fix] Fix slow FSM optimization methods SimplifyEpsilon and MergeEquivalentSuccessors (#407)
This PR refactors part of the fsm functions. In general, this PR: - Fix the `MergeEquivalentSuccessors` and `SimplifyEpsilon` to speed up. - Set proper limitations of number of states for the fsm functions. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 853c7ac commit ba63066

File tree

8 files changed

+322
-264
lines changed

8 files changed

+322
-264
lines changed

cpp/fsm.cc

Lines changed: 181 additions & 159 deletions
Large diffs are not rendered by default.

cpp/fsm.h

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
#include <cstdint>
1717
#include <functional>
1818
#include <string>
19-
#include <unordered_map>
2019
#include <unordered_set>
2120
#include <vector>
2221

2322
#include "support/compact_2d_array.h"
23+
#include "support/logging.h"
2424
#include "support/reflection.h"
2525
#include "support/utils.h"
2626
#include "xgrammar/exception.h"
@@ -302,9 +302,10 @@ class FSM {
302302
* \brief Add a whole FSM to the current FSM.
303303
* \param fsm The FSM to be added.
304304
* \param state_mapping The mapping from the state ids of the added FSM to the new ids in the
305-
* current FSM. The result is cleared at the beginning.
305+
* current FSM. The result is cleared at the beginning. If the fsm's state id starts from 0, use
306+
* it for efficiency.
306307
*/
307-
void AddFSM(const FSM& fsm, std::unordered_map<int, int>* state_mapping = nullptr);
308+
void AddFSM(const FSM& fsm, std::vector<int>* state_mapping = nullptr);
308309

309310
/****************** FSM Construction Methods ******************/
310311

@@ -319,7 +320,7 @@ class FSM {
319320
* \param new_num_states The new number of states.
320321
* \return The rebuilt FSM.
321322
*/
322-
FSM RebuildWithMapping(std::unordered_map<int, int>& state_mapping, int new_num_states) const;
323+
FSM RebuildWithMapping(const std::vector<int>& state_mapping, int new_num_states) const;
323324

324325
/*!
325326
* \brief Sort the edges of the FSM by their min, max and target.
@@ -483,20 +484,6 @@ class FSMWithStartEndBase {
483484
// For serialization only
484485
FSMWithStartEndBase() = default;
485486

486-
/*! \brief Constructs an FSMWithStartEnd with a given FSM, start state, and end states. */
487-
FSMWithStartEndBase(
488-
const FSMType& fsm, int start, const std::unordered_set<int>& ends, bool is_dfa = false
489-
)
490-
: fsm_(fsm), start_(start), is_dfa_(is_dfa) {
491-
ends_.resize(fsm.NumStates(), false);
492-
for (const auto& end : ends) {
493-
XGRAMMAR_DCHECK(end < fsm.NumStates())
494-
<< "End state " << end << " is out of bounds for FSM with " << fsm.NumStates()
495-
<< " states.";
496-
ends_[end] = true;
497-
}
498-
}
499-
500487
FSMWithStartEndBase(
501488
const FSMType& fsm, int start, const std::vector<bool>& ends, bool is_dfa = false
502489
)
@@ -653,9 +640,8 @@ class FSMWithStartEnd : public FSMWithStartEndBase<FSM> {
653640
* \param state_mapping The mapping from old state ids to new state ids.
654641
* \param new_num_states The new number of states.
655642
*/
656-
FSMWithStartEnd RebuildWithMapping(
657-
std::unordered_map<int, int>& state_mapping, int new_num_states
658-
);
643+
FSMWithStartEnd RebuildWithMapping(const std::vector<int>& state_mapping, int new_num_states)
644+
const;
659645

660646
/*!
661647
* \brief Add the underlying FSM to another complete FSM that could contain multiple FSMs.
@@ -666,7 +652,7 @@ class FSMWithStartEnd : public FSMWithStartEndBase<FSM> {
666652
* cleared at the beginning. Should not be nullptr.
667653
* \return The FSMWithStartEnd that points to the complete FSM.
668654
*/
669-
FSMWithStartEnd AddToCompleteFSM(FSM* complete_fsm, std::unordered_map<int, int>* state_mapping);
655+
FSMWithStartEnd AddToCompleteFSM(FSM* complete_fsm, std::vector<int>* state_mapping);
670656

671657
/*!
672658
* \brief Transform the FSMWithStartEnd to a CompactFSMWithStartEnd.
@@ -735,29 +721,29 @@ class FSMWithStartEnd : public FSMWithStartEndBase<FSM> {
735721
* \details If a --\epsilon--> b, and either 1) b doesn't have any other inward edges, or
736722
* 2) a doesn't have any other outward edges, we can merge a and b.
737723
*/
738-
FSMWithStartEnd SimplifyEpsilon() const;
724+
FSMWithStartEnd SimplifyEpsilon(int max_num_states = 1e8) const;
739725

740726
/*!
741727
* \brief Merge equivalent states in the FSM.
742728
* \details If two states are 1) pointed to by edges with the same label from the same state, and
743729
* 2) they are not pointed to by other edges, then we can merge them.
744730
* \example n0 --(c)--> n1, n0 --(c)--> n2, then we can merge n1 and n2.
745731
*/
746-
FSMWithStartEnd MergeEquivalentSuccessors() const;
732+
FSMWithStartEnd MergeEquivalentSuccessors(int max_num_states = 1e5) const;
747733

748734
/*!
749735
* \brief Transform the FSM to a DFA.
750736
* \param max_result_num_states The maximum number of states in the DFA.
751737
* \return The DFA.
752738
*/
753-
Result<FSMWithStartEnd> ToDFA(int max_result_num_states = 1e6) const;
739+
Result<FSMWithStartEnd> ToDFA(int max_num_states = 1e3) const;
754740

755741
/*!
756742
* \brief Minimize the DFA.
757743
* \param max_result_num_states The maximum number of states in the DFA.
758744
* \return The minimized DFA.
759745
*/
760-
Result<FSMWithStartEnd> MinimizeDFA(int max_result_num_states = 1e6) const;
746+
Result<FSMWithStartEnd> MinimizeDFA(int max_num_states = 1e3) const;
761747
};
762748

763749
/*!

cpp/fsm_builder.cc

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ Result<std::pair<int, int>> RegexIR::CheckRepeat(const std::string& regex, int&
192192
Result<FSMWithStartEnd> RegexIR::Build() const {
193193
if (states.empty()) {
194194
FSM empty_fsm(1);
195-
FSMWithStartEnd result(empty_fsm, 0, std::unordered_set<int>{0}, false);
195+
FSMWithStartEnd result(empty_fsm, 0, {true}, false);
196196
return ResultOk(std::move(result));
197197
}
198198
std::vector<FSMWithStartEnd> fsm_list;
@@ -336,7 +336,7 @@ Result<FSMWithStartEnd> RegexIR::visit(const RegexIR::Repeat& state) const {
336336

337337
FSMWithStartEnd RegexIR::BuildLeafFSMFromRegex(const std::string& regex) {
338338
FSM empty_fsm(0);
339-
FSMWithStartEnd result(empty_fsm, 0, std::unordered_set<int>{}, true);
339+
FSMWithStartEnd result(empty_fsm, 0, {}, true);
340340
// Handle the regex string.
341341
if (!(regex[0] == '[' && regex[regex.size() - 1] == ']')) {
342342
result.AddState();
@@ -495,7 +495,9 @@ FSMWithStartEnd RegexIR::BuildLeafFSMFromRegex(const std::string& regex) {
495495
new_fsm.AddEdge(0, 1, last, 0xFF);
496496
}
497497
}
498-
result = FSMWithStartEnd(new_fsm, 0, std::unordered_set<int>{1}, false);
498+
std::vector<bool> ends(new_fsm.NumStates(), false);
499+
ends[1] = true;
500+
result = FSMWithStartEnd(new_fsm, 0, ends, false);
499501
} else {
500502
// TODO: The support for rules.
501503
XGRAMMAR_LOG(WARNING) << "rule is not supported yet.";
@@ -830,7 +832,13 @@ std::optional<FSMWithStartEnd> TrieFSMBuilderImpl::Build(
830832
if (add_back_edges) {
831833
AddBackEdges(&fsm, start, ends);
832834
}
833-
return FSMWithStartEnd(fsm, start, ends);
835+
836+
std::vector<bool> is_end_state(fsm.NumStates(), false);
837+
for (const auto& end : ends) {
838+
is_end_state[end] = true;
839+
}
840+
841+
return FSMWithStartEnd(fsm, start, is_end_state);
834842
}
835843

836844
void TrieFSMBuilderImpl::AddBackEdges(FSM* fsm, int start, const std::unordered_set<int>& ends) {

cpp/grammar_functor.cc

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,10 +1070,18 @@ class StructuralTagGrammarCreatorImpl : public GrammarMutator {
10701070

10711071
class GrammarFSMBuilderImpl {
10721072
public:
1073+
const static uint32_t kMax1ByteUnicode = 0x7F;
1074+
const static uint32_t kMin2BytesUnicode = 0xC080;
1075+
const static uint32_t kMax2BytesUnicode = 0xDFBF;
1076+
const static uint32_t kMin3BytesUnicode = 0xE08080;
1077+
const static uint32_t kMax3BytesUnicode = 0xEFBFBF;
1078+
const static uint32_t kMin4BytesUnicode = 0xF0808080;
1079+
const static uint32_t kMax4BytesUnicode = 0xF7BFBFBF;
1080+
10731081
void Apply(Grammar* grammar) {
10741082
FSM complete_fsm;
10751083
std::vector<std::optional<FSMWithStartEnd>> per_rule_fsms((*grammar)->NumRules());
1076-
std::unordered_map<int, int> state_mapping;
1084+
std::vector<int> state_mapping;
10771085

10781086
for (int i = 0; i < (*grammar)->NumRules(); ++i) {
10791087
auto rule = (*grammar)->GetRule(i);
@@ -1108,13 +1116,13 @@ class GrammarFSMBuilderImpl {
11081116
}
11091117

11101118
/* Basic Building functions.*/
1111-
static std::optional<FSMWithStartEnd> RuleRef(const GrammarExpr& expr);
1112-
static std::optional<FSMWithStartEnd> CharacterClass(const GrammarExpr& expr);
1113-
static std::optional<FSMWithStartEnd> ByteString(const GrammarExpr& expr);
1119+
static FSMWithStartEnd RuleRef(const GrammarExpr& expr);
1120+
static FSMWithStartEnd CharacterClass(const GrammarExpr& expr);
1121+
static FSMWithStartEnd ByteString(const GrammarExpr& expr);
11141122
static std::optional<FSMWithStartEnd> Sequence(const GrammarExpr& expr, const Grammar& grammar);
11151123
static std::optional<FSMWithStartEnd> Choices(const GrammarExpr& expr, const Grammar& grammar);
11161124
static std::optional<FSMWithStartEnd> TagDispatch(const Grammar::Impl::TagDispatch& tag_dispatch);
1117-
1125+
static void AddCharacterRange(FSMWithStartEnd& fsm, int from, int to, uint32_t min, uint32_t max);
11181126
/* Building tool funtions.*/
11191127
static std::optional<FSMWithStartEnd> BuildTagDispatchWithEOSStop(
11201128
const std::vector<std::pair<std::string, int>>& tag_dispatch_rules, bool loop_after_dispatch
@@ -1124,17 +1132,9 @@ class GrammarFSMBuilderImpl {
11241132
const std::vector<std::string>& stop_strings,
11251133
bool loop_after_dispatch
11261134
);
1127-
static std::optional<FSMWithStartEnd> BuildNegativeCharacterClass(const GrammarExpr& expr);
1135+
static FSMWithStartEnd BuildNegativeCharacterClass(const GrammarExpr& expr);
11281136
};
11291137

1130-
const static uint32_t kMax1ByteUnicode = 0x7F;
1131-
const static uint32_t kMin2BytesUnicode = 0xC080;
1132-
const static uint32_t kMax2BytesUnicode = 0xDFBF;
1133-
const static uint32_t kMin3BytesUnicode = 0xE08080;
1134-
const static uint32_t kMax3BytesUnicode = 0xEFBFBF;
1135-
const static uint32_t kMin4BytesUnicode = 0xF0808080;
1136-
const static uint32_t kMax4BytesUnicode = 0xF7BFBFBF;
1137-
11381138
// This function will add a range [min, max] of characters to the FSM, and the length
11391139
// of the characters are the same.
11401140
void AddSameLengthCharacterRange(
@@ -1268,7 +1268,9 @@ void AddSameLengthCharacterRange(
12681268
}
12691269

12701270
// This function will add a range [min, max] of unicode characters to the FSM.
1271-
void AddCharacterRange(FSMWithStartEnd& fsm, int from, int to, uint32_t min, uint32_t max) {
1271+
void GrammarFSMBuilderImpl::AddCharacterRange(
1272+
FSMWithStartEnd& fsm, int from, int to, uint32_t min, uint32_t max
1273+
) {
12721274
XGRAMMAR_CHECK(min <= max) << "Invalid character range: min (" << min << ") > max (" << max
12731275
<< ")";
12741276
// Ensure max and min are valid unicode value.
@@ -1346,9 +1348,7 @@ void AddCharacterRange(FSMWithStartEnd& fsm, int from, int to, uint32_t min, uin
13461348
return;
13471349
}
13481350

1349-
std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildNegativeCharacterClass(
1350-
const GrammarExpr& expr
1351-
) {
1351+
FSMWithStartEnd GrammarFSMBuilderImpl::BuildNegativeCharacterClass(const GrammarExpr& expr) {
13521352
XGRAMMAR_DCHECK(
13531353
expr.type == ExprType::kCharacterClass || expr.type == ExprType::kCharacterClassStar
13541354
);
@@ -1400,15 +1400,12 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildNegativeCharacterClas
14001400
return result_fsm;
14011401
}
14021402

1403-
std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::CharacterClass(const GrammarExpr& expr) {
1403+
FSMWithStartEnd GrammarFSMBuilderImpl::CharacterClass(const GrammarExpr& expr) {
14041404
bool is_negative = expr[0];
14051405
FSMWithStartEnd result_fsm;
14061406
if (is_negative) {
1407-
auto optional_fsm = BuildNegativeCharacterClass(expr);
1408-
if (!optional_fsm.has_value()) {
1409-
return std::nullopt;
1410-
}
1411-
return result_fsm = std::move(optional_fsm.value());
1407+
result_fsm = BuildNegativeCharacterClass(expr);
1408+
return result_fsm;
14121409
}
14131410
int start_state = result_fsm.AddState();
14141411
result_fsm.SetStartState(start_state);
@@ -1438,28 +1435,16 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::Sequence(
14381435
const auto& sequence_expr = grammar->GetGrammarExpr(sequence_id);
14391436
switch (sequence_expr.type) {
14401437
case (ExprType::kByteString): {
1441-
auto fsm = ByteString(sequence_expr);
1442-
if (!fsm.has_value()) {
1443-
return std::nullopt;
1444-
}
1445-
fsm_lists.push_back(std::move(fsm.value()));
1438+
fsm_lists.push_back(ByteString(sequence_expr));
14461439
break;
14471440
}
14481441
case (ExprType::kRuleRef): {
1449-
auto fsm = RuleRef(sequence_expr);
1450-
if (!fsm.has_value()) {
1451-
return std::nullopt;
1452-
}
1453-
fsm_lists.push_back(std::move(fsm.value()));
1442+
fsm_lists.push_back(RuleRef(sequence_expr));
14541443
break;
14551444
}
14561445
case (ExprType::kCharacterClass):
14571446
case (ExprType::kCharacterClassStar): {
1458-
auto fsm = CharacterClass(sequence_expr);
1459-
if (!fsm.has_value()) {
1460-
return std::nullopt;
1461-
}
1462-
fsm_lists.push_back(std::move(fsm.value()));
1447+
fsm_lists.push_back(CharacterClass(sequence_expr));
14631448
break;
14641449
}
14651450
default: {
@@ -1480,7 +1465,7 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::Sequence(
14801465
return FSMWithStartEnd::Concat(fsm_lists);
14811466
}
14821467

1483-
std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::RuleRef(const GrammarExpr& expr) {
1468+
FSMWithStartEnd GrammarFSMBuilderImpl::RuleRef(const GrammarExpr& expr) {
14841469
FSMWithStartEnd result_fsm;
14851470
result_fsm.AddState();
14861471
result_fsm.AddState();
@@ -1490,7 +1475,7 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::RuleRef(const GrammarExpr&
14901475
return result_fsm;
14911476
}
14921477

1493-
std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::ByteString(const GrammarExpr& expr) {
1478+
FSMWithStartEnd GrammarFSMBuilderImpl::ByteString(const GrammarExpr& expr) {
14941479
XGRAMMAR_DCHECK(expr.type == ExprType::kByteString);
14951480
FSMWithStartEnd result_fsm;
14961481
int current_state = result_fsm.AddState();
@@ -1580,13 +1565,13 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildTagDispatchWithStopSt
15801565
old_ends.insert(end);
15811566
}
15821567
}
1583-
std::unordered_set<int> ends;
1568+
std::vector<bool> ends(trie_fsm.NumStates(), false);
15841569

15851570
// The final end states are the end of each stop string.
15861571
for (int i = static_cast<int>(tag_dispatch_rules.size());
15871572
i < static_cast<int>(trie_end_states.size());
15881573
i++) {
1589-
ends.insert(trie_end_states[i]);
1574+
ends[trie_end_states[i]] = true;
15901575
}
15911576

15921577
if (loop_after_dispatch) {
@@ -1611,11 +1596,12 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildTagDispatchWithStopSt
16111596
}
16121597
}
16131598

1614-
std::unordered_map<int, int> stop_trie_to_trie_map;
1599+
std::vector<int> stop_trie_to_trie_map;
16151600
trie_fsm.AddFSM(stop_trie_fsm, &stop_trie_to_trie_map);
1601+
ends.resize(trie_fsm.NumStates(), false);
16161602
int start_of_stop_trie = stop_trie_to_trie_map[stop_trie_start];
16171603
for (auto state : stop_trie_ends) {
1618-
ends.insert(stop_trie_to_trie_map[state]);
1604+
ends[stop_trie_to_trie_map[state]] = true;
16191605
}
16201606

16211607
for (int i = 0; i < static_cast<int>(tag_dispatch_rules.size()); i++) {
@@ -1642,7 +1628,7 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildTagDispatchWithEOSSto
16421628
auto trie_fsm = trie_result->GetFsm();
16431629
auto start = trie_result->GetStart();
16441630
std::unordered_set<int> old_ends;
1645-
std::unordered_set<int> ends;
1631+
std::vector<bool> ends(trie_fsm.NumStates(), false);
16461632
for (int end = 0; end < trie_result->NumStates(); end++) {
16471633
if (trie_result->IsEndState(end)) {
16481634
old_ends.insert(end);
@@ -1652,7 +1638,7 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildTagDispatchWithEOSSto
16521638
// The final end states are all but old_ends.
16531639
for (int i = 0; i < trie_fsm.NumStates(); i++) {
16541640
if (old_ends.count(i) == 0) {
1655-
ends.insert(i);
1641+
ends[i] = true;
16561642
}
16571643
}
16581644

@@ -1663,7 +1649,7 @@ std::optional<FSMWithStartEnd> GrammarFSMBuilderImpl::BuildTagDispatchWithEOSSto
16631649
next_state = start;
16641650
} else {
16651651
next_state = trie_fsm.AddState();
1666-
ends.insert(next_state);
1652+
ends.push_back(true);
16671653
}
16681654
trie_fsm.AddRuleEdge(end_states[i], next_state, tag_dispatch_rules[i].second);
16691655
}
@@ -1758,15 +1744,15 @@ void GrammarFSMBuilder::Apply(Grammar* grammar) { GrammarFSMBuilderImpl().Apply(
17581744

17591745
void RepetitionNormalizer::Apply(Grammar* grammar) { RepetitionNormalizerImpl().Apply(grammar); }
17601746

1761-
std::optional<FSMWithStartEnd> GrammarFSMBuilder::RuleRef(const GrammarExpr& expr) {
1747+
FSMWithStartEnd GrammarFSMBuilder::RuleRef(const GrammarExpr& expr) {
17621748
return GrammarFSMBuilderImpl::RuleRef(expr);
17631749
}
17641750

1765-
std::optional<FSMWithStartEnd> GrammarFSMBuilder::CharacterClass(const GrammarExpr& expr) {
1751+
FSMWithStartEnd GrammarFSMBuilder::CharacterClass(const GrammarExpr& expr) {
17661752
return GrammarFSMBuilderImpl::CharacterClass(expr);
17671753
}
17681754

1769-
std::optional<FSMWithStartEnd> GrammarFSMBuilder::ByteString(const GrammarExpr& expr) {
1755+
FSMWithStartEnd GrammarFSMBuilder::ByteString(const GrammarExpr& expr) {
17701756
return GrammarFSMBuilderImpl::ByteString(expr);
17711757
}
17721758

cpp/grammar_functor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,9 +347,9 @@ class GrammarFSMBuilder {
347347

348348
public:
349349
static void Apply(Grammar* grammar);
350-
static std::optional<FSMWithStartEnd> RuleRef(const GrammarExpr& expr);
351-
static std::optional<FSMWithStartEnd> CharacterClass(const GrammarExpr& expr);
352-
static std::optional<FSMWithStartEnd> ByteString(const GrammarExpr& expr);
350+
static FSMWithStartEnd RuleRef(const GrammarExpr& expr);
351+
static FSMWithStartEnd CharacterClass(const GrammarExpr& expr);
352+
static FSMWithStartEnd ByteString(const GrammarExpr& expr);
353353
static std::optional<FSMWithStartEnd> Sequence(const GrammarExpr& expr, const Grammar& grammar);
354354
static std::optional<FSMWithStartEnd> Choices(const GrammarExpr& expr, const Grammar& grammar);
355355
static std::optional<FSMWithStartEnd> TagDispatch(const Grammar::Impl::TagDispatch& tag_dispatch);

0 commit comments

Comments
 (0)