Skip to content

Commit cfbd875

Browse files
committed
feat:offer apis.
1 parent 3e78a65 commit cfbd875

File tree

9 files changed

+971
-33
lines changed

9 files changed

+971
-33
lines changed

cpp/grammar.cc

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,6 @@ Grammar Grammar::BuiltinJSONGrammar() {
158158
return grammar;
159159
}
160160

161-
Grammar Grammar::Union(const std::vector<Grammar>& grammars) {
162-
return GrammarUnionFunctor::Apply(grammars);
163-
}
164-
165-
Grammar Grammar::Concat(const std::vector<Grammar>& grammars) {
166-
return GrammarConcatFunctor::Apply(grammars);
167-
}
168-
169161
std::ostream& operator<<(std::ostream& os, const Grammar& grammar) {
170162
os << grammar.ToString();
171163
return os;

cpp/grammar_constructor.cc

Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
/*!
2+
* Copyright (c) 2025 by Contributors
3+
* \file xgrammar/grammar_constructor.cc
4+
* \brief The implementation for building the BNF AST.
5+
*/
6+
#include "grammar_constructor.h"
7+
8+
#include <xgrammar/grammar.h>
9+
10+
#include <cstdint>
11+
#include <string>
12+
13+
#include "grammar_functor.h"
14+
#include "support/utils.h"
15+
16+
namespace xgrammar {
17+
18+
/*!
19+
* \brief Implementation of grammar union operation.
20+
*
21+
* Creates a new grammar that accepts strings from any of the input grammars.
22+
* The resulting grammar has a new root rule that chooses between the root rules
23+
* of all input grammars.
24+
*/
25+
class GrammarUnionFunctorImpl : public GrammarMutator {
26+
public:
27+
GrammarUnionFunctorImpl() = default;
28+
29+
Grammar Apply(const std::vector<Grammar>& grammars) {
30+
InitGrammar();
31+
InitBuilder();
32+
auto root_rule_id = builder_->AddEmptyRule("root");
33+
34+
std::vector<int32_t> new_root_choices;
35+
new_root_choices.reserve(grammars.size());
36+
37+
for (const auto& grammar : grammars) {
38+
auto new_root_id_for_grammar = SubGrammarAdder().Apply(builder_, grammar);
39+
auto new_rule_ref = builder_->AddRuleRef(new_root_id_for_grammar);
40+
auto new_rule_ref_seq = builder_->AddSequence({new_rule_ref});
41+
new_root_choices.push_back(new_rule_ref_seq);
42+
}
43+
44+
builder_->UpdateRuleBody(root_rule_id, builder_->AddChoices(new_root_choices));
45+
return builder_->Get(root_rule_id);
46+
}
47+
48+
// Avoid hiding the original Apply(const Grammar&)
49+
Grammar Apply(const Grammar& grammar) final {
50+
XGRAMMAR_LOG(FATAL) << "Should not be called";
51+
XGRAMMAR_UNREACHABLE();
52+
}
53+
};
54+
55+
/*!
56+
* \brief Implementation of grammar concatenation operation.
57+
*
58+
* Creates a new grammar that accepts strings that are concatenations of strings
59+
* from the input grammars in order. The resulting grammar has a new root rule
60+
* that concatenates the root rules of all input grammars.
61+
*/
62+
class GrammarConcatFunctorImpl : public GrammarMutator {
63+
public:
64+
GrammarConcatFunctorImpl() = default;
65+
66+
Grammar Apply(const std::vector<Grammar>& grammars) {
67+
InitGrammar();
68+
InitBuilder();
69+
auto root_rule_id = builder_->AddEmptyRule("root");
70+
71+
std::vector<int32_t> new_root_sequence;
72+
new_root_sequence.reserve(grammars.size());
73+
74+
for (const auto& grammar : grammars) {
75+
auto new_root_id_for_grammar = SubGrammarAdder().Apply(builder_, grammar);
76+
auto new_rule_ref = builder_->AddRuleRef(new_root_id_for_grammar);
77+
new_root_sequence.push_back(new_rule_ref);
78+
}
79+
80+
auto new_root_seq = builder_->AddSequence(new_root_sequence);
81+
builder_->UpdateRuleBody(root_rule_id, builder_->AddChoices({new_root_seq}));
82+
83+
return builder_->Get(root_rule_id);
84+
}
85+
86+
// Avoid hiding the original Apply(const Grammar&)
87+
Grammar Apply(const Grammar& grammar) final {
88+
XGRAMMAR_LOG(FATAL) << "Should not be called";
89+
XGRAMMAR_UNREACHABLE();
90+
}
91+
};
92+
93+
class StructuralTagGrammarCreatorImpl : public GrammarMutator {
94+
public:
95+
Grammar Apply(
96+
const std::vector<std::string>& triggers,
97+
const std::vector<std::vector<std::pair<StructuralTagItem, Grammar>>>& tag_groups
98+
) {
99+
XGRAMMAR_CHECK(triggers.size() == tag_groups.size())
100+
<< "Number of triggers must match number of tag groups";
101+
102+
InitGrammar();
103+
InitBuilder();
104+
105+
auto root_rule_id = builder_->AddEmptyRule("root");
106+
107+
Grammar::Impl::TagDispatch tag_dispatch{
108+
/* tag_rule_pairs = */ {},
109+
/* stop_eos = */ true,
110+
/* stop_str = */ {},
111+
/* loop_after_dispatch = */ true,
112+
};
113+
tag_dispatch.tag_rule_pairs.reserve(triggers.size());
114+
115+
// Create rules for each trigger group
116+
for (size_t i = 0; i < triggers.size(); i++) {
117+
// Skip empty trigger groups
118+
if (tag_groups[i].empty()) {
119+
continue;
120+
}
121+
122+
auto rule_name = "trigger_rule_" + std::to_string(i);
123+
auto rule_id = builder_->AddEmptyRule(rule_name);
124+
125+
// Create choices for each tag in this trigger group
126+
std::vector<int32_t> choices;
127+
choices.reserve(tag_groups[i].size());
128+
for (const auto& [tag, schema_grammar] : tag_groups[i]) {
129+
// Create sequence: start_suffix + schema + end
130+
std::vector<int32_t> seq_elements;
131+
seq_elements.reserve(3);
132+
133+
// Add begin suffix (everything after trigger)
134+
XGRAMMAR_DCHECK(tag.begin.size() >= triggers[i].size())
135+
<< "Tag begin must be at least as long as trigger";
136+
if (tag.begin.size() > triggers[i].size()) {
137+
seq_elements.push_back(builder_->AddByteString(tag.begin.substr(triggers[i].size())));
138+
}
139+
140+
// Create and visit schema grammar for this tag
141+
auto schema_rule_id = SubGrammarAdder().Apply(builder_, schema_grammar);
142+
seq_elements.push_back(builder_->AddRuleRef(schema_rule_id));
143+
144+
// Add end string
145+
if (!tag.end.empty()) {
146+
seq_elements.push_back(builder_->AddByteString(tag.end));
147+
}
148+
149+
choices.push_back(builder_->AddSequence(seq_elements));
150+
}
151+
152+
builder_->UpdateRuleBody(rule_id, builder_->AddChoices(choices));
153+
tag_dispatch.tag_rule_pairs.emplace_back(triggers[i], rule_id);
154+
}
155+
156+
// Create root TagDispatch rule
157+
auto tag_dispatch_id = builder_->AddTagDispatch(tag_dispatch);
158+
builder_->UpdateRuleBody(root_rule_id, tag_dispatch_id);
159+
return builder_->Get(root_rule_id);
160+
}
161+
162+
// Avoid hiding the original Apply(const Grammar&)
163+
Grammar Apply(const Grammar& grammar) final {
164+
XGRAMMAR_LOG(FATAL) << "Should not be called";
165+
XGRAMMAR_UNREACHABLE();
166+
}
167+
};
168+
169+
class TagDispatchGrammarCreatorImpl : public GrammarMutator {
170+
public:
171+
Grammar Apply(
172+
const std::vector<std::string>& triggers,
173+
const std::vector<Grammar>& tags,
174+
bool stop_eos,
175+
bool loop_after_dispatch,
176+
std::vector<std::string> stop_strs
177+
) {
178+
InitGrammar();
179+
InitBuilder();
180+
181+
auto root_rule_id = builder_->AddEmptyRule("root");
182+
183+
Grammar::Impl::TagDispatch tag_dispatch{
184+
/* tag_rule_pairs = */ {},
185+
/* stop_eos = */ stop_eos,
186+
/* stop_str = */ stop_strs,
187+
/* loop_after_dispatch = */ loop_after_dispatch,
188+
};
189+
tag_dispatch.tag_rule_pairs.reserve(triggers.size());
190+
191+
// Create rules for each trigger group
192+
for (size_t i = 0; i < triggers.size(); i++) {
193+
auto rule_name = "trigger_rule_" + std::to_string(i);
194+
auto rule_id = builder_->AddEmptyRule(rule_name);
195+
196+
// Create choices for each tag in this trigger group
197+
std::vector<int32_t> choices;
198+
std::vector<int32_t> seq_elements;
199+
seq_elements.reserve(1);
200+
201+
// Create and visit schema grammar for this tag
202+
auto new_rule_id = SubGrammarAdder().Apply(builder_, tags[i]);
203+
seq_elements.push_back(builder_->AddRuleRef(new_rule_id));
204+
205+
choices.push_back(builder_->AddSequence(seq_elements));
206+
207+
builder_->UpdateRuleBody(rule_id, builder_->AddChoices(choices));
208+
tag_dispatch.tag_rule_pairs.emplace_back(triggers[i], rule_id);
209+
}
210+
211+
auto tag_dispatch_id = builder_->AddTagDispatch(tag_dispatch);
212+
builder_->UpdateRuleBody(root_rule_id, tag_dispatch_id);
213+
214+
return builder_->Get(root_rule_id);
215+
}
216+
217+
// Avoid hiding the original Apply(const Grammar&)
218+
Grammar Apply(const Grammar& grammar) final {
219+
XGRAMMAR_LOG(FATAL) << "Should not be called";
220+
XGRAMMAR_UNREACHABLE();
221+
}
222+
};
223+
224+
class StarGrammarCreatorImpl : public GrammarMutator {
225+
public:
226+
Grammar Apply(const Grammar& grammar) {
227+
// Initialize the grammar and builder.
228+
InitGrammar();
229+
InitBuilder();
230+
231+
// Add a new empty rule for the root.
232+
auto root_rule_id = builder_->AddEmptyRule("root");
233+
234+
// Add the original grammar as a subgrammar.
235+
auto original_root_rule_id = SubGrammarAdder().Apply(builder_, grammar);
236+
237+
// Get a rule reference for root_original.
238+
auto original_root_rule_ref = builder_->AddRuleRef(original_root_rule_id);
239+
240+
// Get a rule reference for the new root rule.
241+
auto root_rule_ref = builder_->AddRuleRef(root_rule_id);
242+
243+
// We get root_original root.
244+
auto new_root_seq = builder_->AddSequence({original_root_rule_ref, root_rule_ref});
245+
246+
// root ::= "" | root_original root
247+
auto new_root_choice = builder_->AddChoices({builder_->AddEmptyStr(), new_root_seq});
248+
builder_->UpdateRuleBody(root_rule_id, new_root_choice);
249+
return builder_->Get(root_rule_id);
250+
}
251+
};
252+
253+
/**************************************** Grammar Functions ***************************************/
254+
255+
/*!
 * \brief Create the grammar parsed from an EBNF root rule whose body is the empty string literal.
 */
Grammar Grammar::Empty() {
  const std::string empty_root_ebnf = "root ::= \"\"";
  return Grammar::FromEBNF(empty_root_ebnf);
}
256+
257+
/*!
 * \brief Create a grammar from a string by wrapping it in an EBNF string literal.
 *
 * The text is embedded as `root ::= "<escaped text>"` and parsed with Grammar::FromEBNF.
 * Quotes, the question mark, the backslash, the listed control characters, NUL and ESC are
 * replaced by their backslash escapes; every other character is copied unchanged.
 *
 * \param str The text to embed.
 * \return The grammar returned by Grammar::FromEBNF for the generated rule.
 */
Grammar Grammar::String(const std::string& str) {
  std::string ebnf_rule = "root ::= \"";
  for (const char ch : str) {
    switch (ch) {
      case '\'':
        ebnf_rule += "\\\'";
        break;
      case '\"':
        ebnf_rule += "\\\"";
        break;
      case '\?':
        ebnf_rule += "\\?";
        break;
      case '\\':
        ebnf_rule += "\\\\";
        break;
      case '\a':
        ebnf_rule += "\\a";
        break;
      case '\b':
        ebnf_rule += "\\b";
        break;
      case '\f':
        ebnf_rule += "\\f";
        break;
      case '\n':
        ebnf_rule += "\\n";
        break;
      case '\r':
        ebnf_rule += "\\r";
        break;
      case '\t':
        ebnf_rule += "\\t";
        break;
      case '\v':
        ebnf_rule += "\\v";
        break;
      case '\0':
        ebnf_rule += "\\0";
        break;
      case '\x1B':
        ebnf_rule += "\\e";
        break;
      default:
        // No escape needed: copy the character as is.
        ebnf_rule += ch;
        break;
    }
  }
  ebnf_rule += "\"";
  return Grammar::FromEBNF(ebnf_rule);
}
284+
285+
/*!
 * \brief Create a grammar from a character class expression.
 * \param str The expression; it is passed unchanged to Grammar::FromRegex.
 * \return The grammar returned by Grammar::FromRegex.
 */
Grammar Grammar::CharacterClass(const std::string& str) {
  Grammar char_class_grammar = Grammar::FromRegex(str);
  return char_class_grammar;
}
286+
287+
/*!
 * \brief Create a tag dispatch grammar.
 *
 * All arguments are forwarded unchanged to TagDispatchGrammarCreator::Apply.
 *
 * \param triggers The trigger strings.
 * \param tags The grammar associated with each trigger.
 * \param stop_eos Forwarded to the creator.
 * \param loop_after_dispatch Forwarded to the creator.
 * \param stop_strs Forwarded to the creator.
 * \return The grammar produced by TagDispatchGrammarCreator::Apply.
 */
Grammar Grammar::TagDispatch(
    const std::vector<std::string>& triggers,
    const std::vector<Grammar>& tags,
    bool stop_eos,
    bool loop_after_dispatch,
    const std::vector<std::string>& stop_strs
) {
  Grammar dispatch_grammar =
      TagDispatchGrammarCreator::Apply(triggers, tags, stop_eos, loop_after_dispatch, stop_strs);
  return dispatch_grammar;
}
296+
297+
/*!
 * \brief Create the union of the given grammars.
 * \param grammars The grammars to combine; forwarded to GrammarUnionFunctor::Apply.
 * \return The grammar produced by GrammarUnionFunctor::Apply.
 */
Grammar Grammar::Union(const std::vector<Grammar>& grammars) {
  Grammar union_grammar = GrammarUnionFunctor::Apply(grammars);
  return union_grammar;
}
300+
301+
/*!
 * \brief Create the concatenation of the given grammars.
 * \param grammars The grammars to chain, in order; forwarded to GrammarConcatFunctor::Apply.
 * \return The grammar produced by GrammarConcatFunctor::Apply.
 */
Grammar Grammar::Concat(const std::vector<Grammar>& grammars) {
  Grammar concat_grammar = GrammarConcatFunctor::Apply(grammars);
  return concat_grammar;
}
304+
305+
/*!
 * \brief Create the repetition (star) of a grammar.
 * \param grammar The grammar to repeat; forwarded to StarGrammarCreator::Apply.
 * \return The grammar produced by StarGrammarCreator::Apply.
 */
Grammar Grammar::Star(const Grammar& grammar) {
  Grammar star_grammar = StarGrammarCreator::Apply(grammar);
  return star_grammar;
}
306+
307+
/*!
 * \brief Create the "one or more" repetition of a grammar.
 *
 * Built as the concatenation of the grammar with Grammar::Star of the same grammar.
 *
 * \param grammar The grammar to repeat.
 * \return The grammar produced by Grammar::Concat.
 */
Grammar Grammar::Plus(const Grammar& grammar) {
  // One mandatory occurrence followed by the star of the same grammar.
  const std::vector<Grammar> parts{grammar, Grammar::Star(grammar)};
  return Grammar::Concat(parts);
}
310+
311+
/*!
 * \brief Create the optional form of a grammar.
 *
 * Built as the union of the grammar with Grammar::Empty().
 *
 * \param grammar The grammar to make optional.
 * \return The grammar produced by Grammar::Union.
 */
Grammar Grammar::Optional(const Grammar& grammar) {
  // Either the grammar itself or the empty grammar.
  const std::vector<Grammar> alternatives{grammar, Grammar::Empty()};
  return Grammar::Union(alternatives);
}
314+
315+
/*************************** Forward grammar Constructors to their impl ***************************/
316+
317+
/*! \brief Forwards to a freshly constructed GrammarUnionFunctorImpl. */
Grammar GrammarUnionFunctor::Apply(const std::vector<Grammar>& grammars) {
  GrammarUnionFunctorImpl union_impl{};
  return union_impl.Apply(grammars);
}
320+
321+
/*! \brief Forwards to a freshly constructed GrammarConcatFunctorImpl. */
Grammar GrammarConcatFunctor::Apply(const std::vector<Grammar>& grammars) {
  GrammarConcatFunctorImpl concat_impl{};
  return concat_impl.Apply(grammars);
}
324+
325+
/*! \brief Forwards to a freshly constructed StructuralTagGrammarCreatorImpl. */
Grammar StructuralTagGrammarCreator::Apply(
    const std::vector<std::string>& triggers,
    const std::vector<std::vector<std::pair<StructuralTagItem, Grammar>>>& tag_groups
) {
  StructuralTagGrammarCreatorImpl structural_tag_impl{};
  return structural_tag_impl.Apply(triggers, tag_groups);
}
331+
332+
/*! \brief Forwards to a freshly constructed TagDispatchGrammarCreatorImpl. */
Grammar TagDispatchGrammarCreator::Apply(
    const std::vector<std::string>& triggers,
    const std::vector<Grammar>& tags,
    bool stop_eos,
    bool loop_after_dispatch,
    const std::vector<std::string>& stop_strs
) {
  TagDispatchGrammarCreatorImpl tag_dispatch_impl{};
  return tag_dispatch_impl.Apply(triggers, tags, stop_eos, loop_after_dispatch, stop_strs);
}
343+
344+
/*! \brief Forwards to a freshly constructed StarGrammarCreatorImpl. */
Grammar StarGrammarCreator::Apply(const Grammar& grammar) {
  StarGrammarCreatorImpl star_impl{};
  return star_impl.Apply(grammar);
}
347+
348+
} // namespace xgrammar

0 commit comments

Comments
 (0)