Skip to content

Commit 3e78a65

Browse files
authored
Support Structural Tag (#420)
This PR introduces the structural tag; see the docs for details. It also refactors GrammarCompiler, the printing methods in EarleyParser and GrammarMatcher, the hashing utilities, and the exception library. Although it introduces the new structural tag API, the API remains backward compatible. Signed-off-by: Ubospica <[email protected]> --------- Signed-off-by: Ubospica <[email protected]>
1 parent b8f42e2 commit 3e78a65

34 files changed

+4602
-600
lines changed

cpp/earley_parser.cc

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ std::pair</* scanable */ bool, /* completable */ bool> EarleyParser::Predict(
152152
return std::make_pair(true, false); // The element is scanable, but not completable.
153153
}
154154
default: {
155-
return std::make_pair(false, false);
155+
XGRAMMAR_LOG(FATAL) << "The element type is not supported! The type is: "
156+
<< int(element_expr.type);
157+
XGRAMMAR_UNREACHABLE();
156158
}
157159
}
158160
}
@@ -178,6 +180,7 @@ void EarleyParser::Scan(const ParserState& state, const uint8_t ch) {
178180
default: {
179181
XGRAMMAR_LOG(FATAL) << "The element type is not supported! The type is: "
180182
<< int(element_expr.type);
183+
XGRAMMAR_UNREACHABLE();
181184
}
182185
}
183186
} else {
@@ -269,19 +272,9 @@ void EarleyParser::PushStateAndExpand(const ParserState& state) {
269272
tmp_states_visited_in_queue_.Clear();
270273
tmp_accept_stop_token_ = false;
271274
tmp_states_to_be_added_.clear();
272-
if (state.IsInvalid()) {
273-
ExpandAndEnqueueUnexpandedState(ParserState{
274-
grammar_->GetRootRuleId(),
275-
ParserState::kUnexpandedRuleStartSequenceId,
276-
0,
277-
ParserState::kNoPrevInputPos,
278-
0
279-
});
280-
} else {
281-
// If the rule can't be expanded, we need to add it to the queue.
282-
if (!ExpandAndEnqueueUnexpandedState(state)) {
283-
Enqueue(state);
284-
}
275+
// If the rule can't be expanded, we need to add it to the queue.
276+
if (!ExpandAndEnqueueUnexpandedState(state)) {
277+
Enqueue(state);
285278
}
286279
rule_id_to_completable_states_.PushBack(std::vector<std::pair<int32_t, ParserState>>());
287280
while (!tmp_process_state_queue_.empty()) {

cpp/earley_parser.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,7 @@ struct ParserState {
111111
}
112112

113113
friend std::ostream& operator<<(std::ostream& os, const ParserState& state) {
114-
os << "ParserState(rule_id=" << state.rule_id << ", sequence_id=" << state.sequence_id
115-
<< ", element_id=" << state.element_id << ", rule_start_pos=" << state.rule_start_pos
116-
<< ", sub_element_id=" << state.sub_element_id << ", repeat_count=" << state.repeat_count
117-
<< ")";
114+
os << state.ToString();
118115
return os;
119116
}
120117

@@ -456,10 +453,12 @@ class EarleyParser {
456453

457454
std::string PrintStates() const {
458455
std::string result;
459-
result += "There are " + std::to_string(scanable_state_history_.size()) + " scanable states:\n";
456+
result += "There are " + std::to_string(scanable_state_history_.size()) +
457+
" steps in history. Last step: [\n";
460458
for (const auto& state : scanable_state_history_[scanable_state_history_.size() - 1]) {
461-
result += state.ToString() + "\n";
459+
result += state.ToString() + ", \n";
462460
}
461+
result += "]";
463462
return result;
464463
}
465464
};

cpp/fsm.h

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -125,21 +125,9 @@ XGRAMMAR_MEMBER_ARRAY(FSMEdge, &FSMEdge::min, &FSMEdge::max, &FSMEdge::target);
125125

126126
} // namespace xgrammar
127127

128-
namespace std {
129-
130-
/*!
131-
* \brief Hash function for FSMEdge.
132-
*/
133-
template <>
134-
struct hash<xgrammar::FSMEdge> {
135-
size_t operator()(const xgrammar::FSMEdge& edge) const {
136-
return std::hash<std::tuple<int16_t, int16_t, int32_t>>()(
137-
std::make_tuple(edge.min, edge.max, edge.target)
138-
);
139-
}
140-
};
141-
142-
} // namespace std
128+
XGRAMMAR_HASH_BY_MEMBERS(
129+
xgrammar::FSMEdge, &xgrammar::FSMEdge::min, &xgrammar::FSMEdge::max, &xgrammar::FSMEdge::target
130+
);
143131

144132
namespace xgrammar {
145133

cpp/grammar.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,10 @@ Grammar Grammar::FromRegex(const std::string& regex, bool print_converted_ebnf)
6565
return FromEBNF(ebnf_string);
6666
}
6767

68-
Grammar Grammar::FromStructuralTag(
69-
const std::vector<StructuralTagItem>& tags, const std::vector<std::string>& triggers
68+
std::variant<Grammar, StructuralTagError> Grammar::FromStructuralTag(
69+
const std::string& structural_tag_json
7070
) {
71-
Grammar grammar = StructuralTagToGrammar(tags, triggers);
72-
return grammar;
71+
return StructuralTagToGrammar(structural_tag_json).ToVariant();
7372
}
7473

7574
// Optimized json grammar for the speed of the grammar matcher

cpp/grammar_builder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class GrammarBuilder {
202202
);
203203
}
204204

205-
int32_t AddRepeat(const int32_t ref_rule_id, int32_t min_repeat_count, int32_t max_repeat_count) {
205+
int32_t AddRepeat(int32_t ref_rule_id, int32_t min_repeat_count, int32_t max_repeat_count) {
206206
std::vector<int32_t> data({ref_rule_id, min_repeat_count, max_repeat_count});
207207
return AddGrammarExpr({GrammarExprType::kRepeat, data.data(), static_cast<int32_t>(data.size())}
208208
);
@@ -249,6 +249,10 @@ class GrammarBuilder {
249249
*/
250250
int32_t AddEmptyRule(const std::string& name) { return AddRule({name, -1}); }
251251

252+
int32_t AddEmptyRuleWithHint(const std::string& name_hint) {
253+
return AddRule({GetNewRuleName(name_hint), -1});
254+
}
255+
252256
/*!
253257
* \brief Update the rule body of the given rule, specified by rule id. Can be used to set the
254258
* rule body of a rule inserted by GrammarBuilder::AddEmptyRule.

0 commit comments

Comments
 (0)