|
| 1 | +/*! |
| 2 | + * Copyright (c) 2025 by Contributors |
| 3 | + * \file xgrammar/grammar_constructor.cc |
| 4 | + * \brief The implementation for building the BNF AST. |
| 5 | + */ |
| 6 | +#include "grammar_constructor.h" |
| 7 | + |
| 8 | +#include <xgrammar/grammar.h> |
| 9 | + |
| 10 | +#include <cstdint> |
| 11 | +#include <string> |
| 12 | + |
| 13 | +#include "grammar_functor.h" |
| 14 | +#include "support/utils.h" |
| 15 | + |
| 16 | +namespace xgrammar { |
| 17 | + |
| 18 | +/*! |
| 19 | + * \brief Implementation of grammar union operation. |
| 20 | + * |
| 21 | + * Creates a new grammar that accepts strings from any of the input grammars. |
| 22 | + * The resulting grammar has a new root rule that chooses between the root rules |
| 23 | + * of all input grammars. |
| 24 | + */ |
| 25 | +class GrammarUnionFunctorImpl : public GrammarMutator { |
| 26 | + public: |
| 27 | + GrammarUnionFunctorImpl() = default; |
| 28 | + |
| 29 | + Grammar Apply(const std::vector<Grammar>& grammars) { |
| 30 | + InitGrammar(); |
| 31 | + InitBuilder(); |
| 32 | + auto root_rule_id = builder_->AddEmptyRule("root"); |
| 33 | + |
| 34 | + std::vector<int32_t> new_root_choices; |
| 35 | + new_root_choices.reserve(grammars.size()); |
| 36 | + |
| 37 | + for (const auto& grammar : grammars) { |
| 38 | + auto new_root_id_for_grammar = SubGrammarAdder().Apply(builder_, grammar); |
| 39 | + auto new_rule_ref = builder_->AddRuleRef(new_root_id_for_grammar); |
| 40 | + auto new_rule_ref_seq = builder_->AddSequence({new_rule_ref}); |
| 41 | + new_root_choices.push_back(new_rule_ref_seq); |
| 42 | + } |
| 43 | + |
| 44 | + builder_->UpdateRuleBody(root_rule_id, builder_->AddChoices(new_root_choices)); |
| 45 | + return builder_->Get(root_rule_id); |
| 46 | + } |
| 47 | + |
| 48 | + // Avoid hiding the original Apply(const Grammar&) |
| 49 | + Grammar Apply(const Grammar& grammar) final { |
| 50 | + XGRAMMAR_LOG(FATAL) << "Should not be called"; |
| 51 | + XGRAMMAR_UNREACHABLE(); |
| 52 | + } |
| 53 | +}; |
| 54 | + |
| 55 | +/*! |
| 56 | + * \brief Implementation of grammar concatenation operation. |
| 57 | + * |
| 58 | + * Creates a new grammar that accepts strings that are concatenations of strings |
| 59 | + * from the input grammars in order. The resulting grammar has a new root rule |
| 60 | + * that concatenates the root rules of all input grammars. |
| 61 | + */ |
| 62 | +class GrammarConcatFunctorImpl : public GrammarMutator { |
| 63 | + public: |
| 64 | + GrammarConcatFunctorImpl() = default; |
| 65 | + |
| 66 | + Grammar Apply(const std::vector<Grammar>& grammars) { |
| 67 | + InitGrammar(); |
| 68 | + InitBuilder(); |
| 69 | + auto root_rule_id = builder_->AddEmptyRule("root"); |
| 70 | + |
| 71 | + std::vector<int32_t> new_root_sequence; |
| 72 | + new_root_sequence.reserve(grammars.size()); |
| 73 | + |
| 74 | + for (const auto& grammar : grammars) { |
| 75 | + auto new_root_id_for_grammar = SubGrammarAdder().Apply(builder_, grammar); |
| 76 | + auto new_rule_ref = builder_->AddRuleRef(new_root_id_for_grammar); |
| 77 | + new_root_sequence.push_back(new_rule_ref); |
| 78 | + } |
| 79 | + |
| 80 | + auto new_root_seq = builder_->AddSequence(new_root_sequence); |
| 81 | + builder_->UpdateRuleBody(root_rule_id, builder_->AddChoices({new_root_seq})); |
| 82 | + |
| 83 | + return builder_->Get(root_rule_id); |
| 84 | + } |
| 85 | + |
| 86 | + // Avoid hiding the original Apply(const Grammar&) |
| 87 | + Grammar Apply(const Grammar& grammar) final { |
| 88 | + XGRAMMAR_LOG(FATAL) << "Should not be called"; |
| 89 | + XGRAMMAR_UNREACHABLE(); |
| 90 | + } |
| 91 | +}; |
| 92 | + |
| 93 | +class StructuralTagGrammarCreatorImpl : public GrammarMutator { |
| 94 | + public: |
| 95 | + Grammar Apply( |
| 96 | + const std::vector<std::string>& triggers, |
| 97 | + const std::vector<std::vector<std::pair<StructuralTagItem, Grammar>>>& tag_groups |
| 98 | + ) { |
| 99 | + XGRAMMAR_CHECK(triggers.size() == tag_groups.size()) |
| 100 | + << "Number of triggers must match number of tag groups"; |
| 101 | + |
| 102 | + InitGrammar(); |
| 103 | + InitBuilder(); |
| 104 | + |
| 105 | + auto root_rule_id = builder_->AddEmptyRule("root"); |
| 106 | + |
| 107 | + Grammar::Impl::TagDispatch tag_dispatch{ |
| 108 | + /* tag_rule_pairs = */ {}, |
| 109 | + /* stop_eos = */ true, |
| 110 | + /* stop_str = */ {}, |
| 111 | + /* loop_after_dispatch = */ true, |
| 112 | + }; |
| 113 | + tag_dispatch.tag_rule_pairs.reserve(triggers.size()); |
| 114 | + |
| 115 | + // Create rules for each trigger group |
| 116 | + for (size_t i = 0; i < triggers.size(); i++) { |
| 117 | + // Skip empty trigger groups |
| 118 | + if (tag_groups[i].empty()) { |
| 119 | + continue; |
| 120 | + } |
| 121 | + |
| 122 | + auto rule_name = "trigger_rule_" + std::to_string(i); |
| 123 | + auto rule_id = builder_->AddEmptyRule(rule_name); |
| 124 | + |
| 125 | + // Create choices for each tag in this trigger group |
| 126 | + std::vector<int32_t> choices; |
| 127 | + choices.reserve(tag_groups[i].size()); |
| 128 | + for (const auto& [tag, schema_grammar] : tag_groups[i]) { |
| 129 | + // Create sequence: start_suffix + schema + end |
| 130 | + std::vector<int32_t> seq_elements; |
| 131 | + seq_elements.reserve(3); |
| 132 | + |
| 133 | + // Add begin suffix (everything after trigger) |
| 134 | + XGRAMMAR_DCHECK(tag.begin.size() >= triggers[i].size()) |
| 135 | + << "Tag begin must be at least as long as trigger"; |
| 136 | + if (tag.begin.size() > triggers[i].size()) { |
| 137 | + seq_elements.push_back(builder_->AddByteString(tag.begin.substr(triggers[i].size()))); |
| 138 | + } |
| 139 | + |
| 140 | + // Create and visit schema grammar for this tag |
| 141 | + auto schema_rule_id = SubGrammarAdder().Apply(builder_, schema_grammar); |
| 142 | + seq_elements.push_back(builder_->AddRuleRef(schema_rule_id)); |
| 143 | + |
| 144 | + // Add end string |
| 145 | + if (!tag.end.empty()) { |
| 146 | + seq_elements.push_back(builder_->AddByteString(tag.end)); |
| 147 | + } |
| 148 | + |
| 149 | + choices.push_back(builder_->AddSequence(seq_elements)); |
| 150 | + } |
| 151 | + |
| 152 | + builder_->UpdateRuleBody(rule_id, builder_->AddChoices(choices)); |
| 153 | + tag_dispatch.tag_rule_pairs.emplace_back(triggers[i], rule_id); |
| 154 | + } |
| 155 | + |
| 156 | + // Create root TagDispatch rule |
| 157 | + auto tag_dispatch_id = builder_->AddTagDispatch(tag_dispatch); |
| 158 | + builder_->UpdateRuleBody(root_rule_id, tag_dispatch_id); |
| 159 | + return builder_->Get(root_rule_id); |
| 160 | + } |
| 161 | + |
| 162 | + // Avoid hiding the original Apply(const Grammar&) |
| 163 | + Grammar Apply(const Grammar& grammar) final { |
| 164 | + XGRAMMAR_LOG(FATAL) << "Should not be called"; |
| 165 | + XGRAMMAR_UNREACHABLE(); |
| 166 | + } |
| 167 | +}; |
| 168 | + |
| 169 | +class TagDispatchGrammarCreatorImpl : public GrammarMutator { |
| 170 | + public: |
| 171 | + Grammar Apply( |
| 172 | + const std::vector<std::string>& triggers, |
| 173 | + const std::vector<Grammar>& tags, |
| 174 | + bool stop_eos, |
| 175 | + bool loop_after_dispatch, |
| 176 | + std::vector<std::string> stop_strs |
| 177 | + ) { |
| 178 | + InitGrammar(); |
| 179 | + InitBuilder(); |
| 180 | + |
| 181 | + auto root_rule_id = builder_->AddEmptyRule("root"); |
| 182 | + |
| 183 | + Grammar::Impl::TagDispatch tag_dispatch{ |
| 184 | + /* tag_rule_pairs = */ {}, |
| 185 | + /* stop_eos = */ stop_eos, |
| 186 | + /* stop_str = */ stop_strs, |
| 187 | + /* loop_after_dispatch = */ loop_after_dispatch, |
| 188 | + }; |
| 189 | + tag_dispatch.tag_rule_pairs.reserve(triggers.size()); |
| 190 | + |
| 191 | + // Create rules for each trigger group |
| 192 | + for (size_t i = 0; i < triggers.size(); i++) { |
| 193 | + auto rule_name = "trigger_rule_" + std::to_string(i); |
| 194 | + auto rule_id = builder_->AddEmptyRule(rule_name); |
| 195 | + |
| 196 | + // Create choices for each tag in this trigger group |
| 197 | + std::vector<int32_t> choices; |
| 198 | + std::vector<int32_t> seq_elements; |
| 199 | + seq_elements.reserve(1); |
| 200 | + |
| 201 | + // Create and visit schema grammar for this tag |
| 202 | + auto new_rule_id = SubGrammarAdder().Apply(builder_, tags[i]); |
| 203 | + seq_elements.push_back(builder_->AddRuleRef(new_rule_id)); |
| 204 | + |
| 205 | + choices.push_back(builder_->AddSequence(seq_elements)); |
| 206 | + |
| 207 | + builder_->UpdateRuleBody(rule_id, builder_->AddChoices(choices)); |
| 208 | + tag_dispatch.tag_rule_pairs.emplace_back(triggers[i], rule_id); |
| 209 | + } |
| 210 | + |
| 211 | + auto tag_dispatch_id = builder_->AddTagDispatch(tag_dispatch); |
| 212 | + builder_->UpdateRuleBody(root_rule_id, tag_dispatch_id); |
| 213 | + |
| 214 | + return builder_->Get(root_rule_id); |
| 215 | + } |
| 216 | + |
| 217 | + // Avoid hiding the original Apply(const Grammar&) |
| 218 | + Grammar Apply(const Grammar& grammar) final { |
| 219 | + XGRAMMAR_LOG(FATAL) << "Should not be called"; |
| 220 | + XGRAMMAR_UNREACHABLE(); |
| 221 | + } |
| 222 | +}; |
| 223 | + |
| 224 | +class StarGrammarCreatorImpl : public GrammarMutator { |
| 225 | + public: |
| 226 | + Grammar Apply(const Grammar& grammar) { |
| 227 | + // Initialize the grammar and builder. |
| 228 | + InitGrammar(); |
| 229 | + InitBuilder(); |
| 230 | + |
| 231 | + // Add a new empty rule for the root. |
| 232 | + auto root_rule_id = builder_->AddEmptyRule("root"); |
| 233 | + |
| 234 | + // Add the original grammar as a subgrammar. |
| 235 | + auto original_root_rule_id = SubGrammarAdder().Apply(builder_, grammar); |
| 236 | + |
| 237 | + // Get a rule reference for root_original. |
| 238 | + auto original_root_rule_ref = builder_->AddRuleRef(original_root_rule_id); |
| 239 | + |
| 240 | + // Get a rule reference for the new root rule. |
| 241 | + auto root_rule_ref = builder_->AddRuleRef(root_rule_id); |
| 242 | + |
| 243 | + // We get root_original root. |
| 244 | + auto new_root_seq = builder_->AddSequence({original_root_rule_ref, root_rule_ref}); |
| 245 | + |
| 246 | + // root ::= "" | root_original root |
| 247 | + auto new_root_choice = builder_->AddChoices({builder_->AddEmptyStr(), new_root_seq}); |
| 248 | + builder_->UpdateRuleBody(root_rule_id, new_root_choice); |
| 249 | + return builder_->Get(root_rule_id); |
| 250 | + } |
| 251 | +}; |
| 252 | + |
| 253 | +/**************************************** Grammar Functions ***************************************/ |
| 254 | + |
| 255 | +Grammar Grammar::Empty() { return Grammar::FromEBNF("root ::= \"\""); } |
| 256 | + |
| 257 | +Grammar Grammar::String(const std::string& str) { |
| 258 | + static const std::unordered_map<char, std::string> kCodepointToEscape = { |
| 259 | + {'\'', "\\\'"}, |
| 260 | + {'\"', "\\\""}, |
| 261 | + {'\?', "\\?"}, |
| 262 | + {'\\', "\\\\"}, |
| 263 | + {'\a', "\\a"}, |
| 264 | + {'\b', "\\b"}, |
| 265 | + {'\f', "\\f"}, |
| 266 | + {'\n', "\\n"}, |
| 267 | + {'\r', "\\r"}, |
| 268 | + {'\t', "\\t"}, |
| 269 | + {'\v', "\\v"}, |
| 270 | + {'\0', "\\0"}, |
| 271 | + {'\x1B', "\\e"} |
| 272 | + }; |
| 273 | + std::string ebnf_string = "root ::= \""; |
| 274 | + for (auto ch : str) { |
| 275 | + if (kCodepointToEscape.find(ch) != kCodepointToEscape.end()) { |
| 276 | + ebnf_string += kCodepointToEscape.at(ch); |
| 277 | + } else { |
| 278 | + ebnf_string += ch; |
| 279 | + } |
| 280 | + } |
| 281 | + ebnf_string += "\""; |
| 282 | + return Grammar::FromEBNF(ebnf_string); |
| 283 | +} |
| 284 | + |
| 285 | +Grammar Grammar::CharacterClass(const std::string& str) { return Grammar::FromRegex(str); } |
| 286 | + |
| 287 | +Grammar Grammar::TagDispatch( |
| 288 | + const std::vector<std::string>& triggers, |
| 289 | + const std::vector<Grammar>& tags, |
| 290 | + bool stop_eos, |
| 291 | + bool loop_after_dispatch, |
| 292 | + const std::vector<std::string>& stop_strs |
| 293 | +) { |
| 294 | + return TagDispatchGrammarCreator::Apply(triggers, tags, stop_eos, loop_after_dispatch, stop_strs); |
| 295 | +} |
| 296 | + |
| 297 | +Grammar Grammar::Union(const std::vector<Grammar>& grammars) { |
| 298 | + return GrammarUnionFunctor::Apply(grammars); |
| 299 | +} |
| 300 | + |
| 301 | +Grammar Grammar::Concat(const std::vector<Grammar>& grammars) { |
| 302 | + return GrammarConcatFunctor::Apply(grammars); |
| 303 | +} |
| 304 | + |
| 305 | +Grammar Grammar::Star(const Grammar& grammar) { return StarGrammarCreator::Apply(grammar); } |
| 306 | + |
| 307 | +Grammar Grammar::Plus(const Grammar& grammar) { |
| 308 | + return Grammar::Concat({grammar, Grammar::Star(grammar)}); |
| 309 | +} |
| 310 | + |
| 311 | +Grammar Grammar::Optional(const Grammar& grammar) { |
| 312 | + return Grammar::Union({grammar, Grammar::Empty()}); |
| 313 | +} |
| 314 | + |
| 315 | +/*************************** Forward grammar Constructors to their impl ***************************/ |
| 316 | + |
| 317 | +Grammar GrammarUnionFunctor::Apply(const std::vector<Grammar>& grammars) { |
| 318 | + return GrammarUnionFunctorImpl().Apply(grammars); |
| 319 | +} |
| 320 | + |
| 321 | +Grammar GrammarConcatFunctor::Apply(const std::vector<Grammar>& grammars) { |
| 322 | + return GrammarConcatFunctorImpl().Apply(grammars); |
| 323 | +} |
| 324 | + |
| 325 | +Grammar StructuralTagGrammarCreator::Apply( |
| 326 | + const std::vector<std::string>& triggers, |
| 327 | + const std::vector<std::vector<std::pair<StructuralTagItem, Grammar>>>& tag_groups |
| 328 | +) { |
| 329 | + return StructuralTagGrammarCreatorImpl().Apply(triggers, tag_groups); |
| 330 | +} |
| 331 | + |
| 332 | +Grammar TagDispatchGrammarCreator::Apply( |
| 333 | + const std::vector<std::string>& triggers, |
| 334 | + const std::vector<Grammar>& tags, |
| 335 | + bool stop_eos, |
| 336 | + bool loop_after_dispatch, |
| 337 | + const std::vector<std::string>& stop_strs |
| 338 | +) { |
| 339 | + return TagDispatchGrammarCreatorImpl().Apply( |
| 340 | + triggers, tags, stop_eos, loop_after_dispatch, stop_strs |
| 341 | + ); |
| 342 | +} |
| 343 | + |
| 344 | +Grammar StarGrammarCreator::Apply(const Grammar& grammar) { |
| 345 | + return StarGrammarCreatorImpl().Apply(grammar); |
| 346 | +} |
| 347 | + |
| 348 | +} // namespace xgrammar |
0 commit comments