Skip to content

Commit 206db55

Browse files
authored
[Grammar] Fix include protection and paths in docstring (#2515)
Following #2464, this PR fixes the include guards in the header files and the file paths in the docstrings of the header files. It also fixes the tests that were broken by that refactor.
1 parent 868334d commit 206db55

31 files changed

+148
-192
lines changed

cpp/grammar/grammar.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* \brief The header for the support of grammar-guided generation.
55
*/
66

7-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_H_
8-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_H_
7+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_H_
8+
#define MLC_LLM_GRAMMAR_GRAMMAR_H_
99

1010
#include <tvm/runtime/object.h>
1111
#include <tvm/runtime/registry.h>
@@ -191,10 +191,10 @@ class BNFGrammar : public ObjectRef {
191191
* format of the schema of a JSON file. We will parse the schema and generate a BNF grammar.
192192
* \param schema The schema string.
193193
* \param indent The number of spaces for indentation. If set to std::nullopt, the output will be
194-
* in one line. Default: std::nullopt.
194+
* in one line. Default: 2.
195195
* \param separators Two separators used in the schema: comma and colon. Examples: {",", ":"},
196196
* {", ", ": "}. If std::nullopt, the default separators will be used: {",", ": "} when the
197-
* indent is not -1, and {", ", ": "} otherwise. This follows the convention in python
197+
* indent is not nullopt, and {", ", ": "} otherwise. This follows the convention in python
198198
* json.dumps(). Default: std::nullopt.
199199
* \param strict_mode Whether to use strict mode. In strict mode, the generated grammar will not
200200
* allow properties and items that are not specified in the schema. This is equivalent to
@@ -223,4 +223,4 @@ class BNFGrammar : public ObjectRef {
223223
} // namespace llm
224224
} // namespace mlc
225225

226-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_H_
226+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_H_

cpp/grammar/grammar_builder.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* \brief The header for building the BNF AST.
55
*/
66

7-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_BUILDER_H_
8-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_BUILDER_H_
7+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_
8+
#define MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_
99
#include <tvm/runtime/object.h>
1010

1111
#include <cstdint>
@@ -38,7 +38,7 @@ class BNFGrammarBuilder {
3838
*/
3939
BNFGrammar Get(const std::string& main_rule = "main") {
4040
int32_t main_rule_id = GetRuleId(main_rule);
41-
CHECK(main_rule_id != -1) << "The in rule with name \"" << main_rule << "\" is not found.";
41+
CHECK(main_rule_id != -1) << "The main rule with name \"" << main_rule << "\" is not found.";
4242
grammar_->main_rule_id_ = main_rule_id;
4343

4444
return BNFGrammar(grammar_);
@@ -251,4 +251,4 @@ class BNFGrammarBuilder {
251251
} // namespace llm
252252
} // namespace mlc
253253

254-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_BUILDER_H_
254+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_BUILDER_H_

cpp/grammar/grammar_functor.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* \brief The header for the simplification of the BNF AST.
55
*/
66

7-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_
8-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_
7+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_
8+
#define MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_
99

1010
#include <queue>
1111
#include <string>
@@ -216,4 +216,4 @@ class BNFGrammarNormalizer : public BNFGrammarMutator {
216216
} // namespace llm
217217
} // namespace mlc
218218

219-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_
219+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_FUNCTOR_H_

cpp/grammar/grammar_parser.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* \brief The header for the parser of BNF/EBNF grammar into BNF AST.
55
*/
66

7-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_PARSER_H_
8-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_PARSER_H_
7+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_
8+
#define MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_
99

1010
#include <tvm/runtime/container/string.h>
1111
#include <tvm/runtime/logging.h>
@@ -65,4 +65,4 @@ class BNFJSONParser {
6565
} // namespace llm
6666
} // namespace mlc
6767

68-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_PARSER_H_
68+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_PARSER_H_

cpp/grammar/grammar_serializer.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include "grammar_serializer.h"
77

88
#include <picojson.h>
9-
#include <tvm/runtime/memory.h>
109
#include <tvm/runtime/registry.h>
1110

1211
#include "../support/encoding.h"

cpp/grammar/grammar_serializer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
* \brief The header for printing the AST of a BNF grammar.
55
*/
66

7-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SERIALIZER_H_
8-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SERIALIZER_H_
7+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_
8+
#define MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_
99

1010
#include <string>
1111

@@ -114,4 +114,4 @@ class BNFGrammarJSONSerializer : public BNFGrammarSerializer {
114114
} // namespace llm
115115
} // namespace mlc
116116

117-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SERIALIZER_H_
117+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_SERIALIZER_H_

cpp/grammar/grammar_state_matcher.cc

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,6 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
246246
// {-1} means the universal set, i.e. all tokens initially
247247
tmp_rejected_indices_.assign({-1});
248248

249-
// std::chrono::microseconds time_unc(0);
250-
// std::chrono::microseconds time_idx(0);
251249
int check_cnt = 0;
252250

253251
for (auto top : latest_stack_tops) {
@@ -258,8 +256,6 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
258256

259257
const auto& catagorized_tokens = catagorized_tokens_for_grammar.at(cur_rule_position);
260258

261-
// auto start = std::chrono::high_resolution_clock::now();
262-
263259
// For each stack, we will check every uncertain token and put them into the accepted or
264260
// rejected list.
265261

@@ -277,35 +273,6 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
277273
const std::string* prev_token = nullptr;
278274
int prev_matched_size = 0;
279275

280-
// std::cout << tree_.PrintNode(top) << std::endl;
281-
282-
// std::cout << "Accepted count: " << catagorized_tokens.accepted_indices.size()
283-
// << ", rejected count: " << catagorized_tokens.rejected_indices.size()
284-
// << ", uncertain count: " << catagorized_tokens.uncertain_indices.size()
285-
// << ", save type: " << static_cast<int>(catagorized_tokens.save_type) << std::endl;
286-
287-
// if (catagorized_tokens.accepted_indices.size() < 200) {
288-
// std::cout << "Accpeted: ";
289-
// for (int i = 0; i < catagorized_tokens.accepted_indices.size(); ++i) {
290-
// std::cout << "<"
291-
// << PrintAsEscaped(
292-
// sorted_token_table[catagorized_tokens.accepted_indices[i]].second)
293-
// << "> ";
294-
// }
295-
// std::cout << "\n";
296-
// }
297-
298-
// if (catagorized_tokens.uncertain_indices.size() > 100) {
299-
// std::cout << "Uncertain: ";
300-
// for (int i = 0; i < catagorized_tokens.uncertain_indices.size(); ++i) {
301-
// std::cout << "<"
302-
// << PrintAsEscaped(
303-
// sorted_token_table[catagorized_tokens.uncertain_indices[i]].second)
304-
// << "> ";
305-
// }
306-
// std::cout << "\n";
307-
// }
308-
309276
for (auto cur_token_idx : catagorized_tokens.uncertain_indices) {
310277
const auto& cur_token = sorted_token_table[cur_token_idx].second;
311278
bool accepted = true;
@@ -354,13 +321,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
354321

355322
RollbackChars(prev_matched_size + 1);
356323

357-
// auto end = std::chrono::high_resolution_clock::now();
358-
359-
// time_unc += std::chrono::duration_cast<std::chrono::microseconds>(end - start);
360-
361-
// start = std::chrono::high_resolution_clock::now();
362-
363-
// Step 3. Update the accepted_indices and rejected_indices
324+
// Step 3. Update the accepted_indices or rejected_indices
364325
if (catagorized_tokens.save_type == SaveType::kAcceptedBitset) {
365326
tmp_accepted_bitset_ |= catagorized_tokens.accepted_bitset;
366327
} else if (catagorized_tokens.save_type == SaveType::kAccepted) {
@@ -374,19 +335,11 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm
374335
IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices);
375336
IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_);
376337
}
377-
// end = std::chrono::high_resolution_clock::now();
378-
// time_idx += std::chrono::duration_cast<std::chrono::microseconds>(end - start);
379338
}
380339

381340
// Finally update the rejected_ids bitset
382-
// auto start = std::chrono::high_resolution_clock::now();
383341
bool can_reach_end = CanReachEnd();
384342
SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, tmp_rejected_indices_, can_reach_end);
385-
// auto end = std::chrono::high_resolution_clock::now();
386-
// time_idx += std::chrono::duration_cast<std::chrono::microseconds>(end - start);
387-
// std::cout << "Time for uncertain: " << time_unc.count()
388-
// << "us, time for index: " << time_idx.count() << "us" << std::endl;
389-
// std::cout << "Check cnt " << check_cnt << std::endl;
390343
}
391344

392345
void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) {

cpp/grammar/grammar_state_matcher.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
* logic of the grammar-guided generation.
66
*/
77

8-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_
9-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_
8+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_
9+
#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_
1010

1111
#include <tvm/runtime/object.h>
1212
#include <tvm/runtime/registry.h>
@@ -172,4 +172,4 @@ class GrammarInitContextCache : public ObjectRef {
172172
} // namespace llm
173173
} // namespace mlc
174174

175-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_H_
175+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_H_

cpp/grammar/grammar_state_matcher_base.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
* \file grammar/grammar_state_matcher_base.h
44
* \brief The base class of GrammarStateMatcher. It implements a character-based matching automata.
55
*/
6-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
7-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
6+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
7+
#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
88

99
#include <vector>
1010

@@ -109,7 +109,8 @@ class GrammarStateMatcherBase {
109109
// We store the stack tops in different steps in the history to support rollback.
110110
StackTopsHistory stack_tops_history_;
111111

112-
// Temporary data for AcceptChar.
112+
// Temporary data for AcceptChar, PushInitialState, etc to store new stacks.
113+
// They are stored here to avoid repeated allocation.
113114
std::vector<int32_t> tmp_new_stack_tops_;
114115
};
115116

@@ -267,21 +268,21 @@ inline void GrammarStateMatcherBase::PushInitialState(RulePosition init_rule_pos
267268
// Initialize the stack with the main rule.
268269
auto main_rule = grammar_->GetMainRule();
269270
auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id);
270-
std::vector<int32_t> stack_tops;
271+
tmp_new_stack_tops_.clear();
271272
for (auto i : main_rule_body) {
272273
auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent);
273274
if (expand_init_rule_position) {
274-
ExpandRulePosition(init_rule_position, &stack_tops, true);
275+
ExpandRulePosition(init_rule_position, &tmp_new_stack_tops_, true);
275276
} else {
276-
stack_tops.push_back(tree_.NewNode(init_rule_position));
277+
tmp_new_stack_tops_.push_back(tree_.NewNode(init_rule_position));
277278
}
278279
}
279-
stack_tops_history_.PushHistory(stack_tops);
280+
stack_tops_history_.PushHistory(tmp_new_stack_tops_);
280281
} else {
281282
if (expand_init_rule_position) {
282-
std::vector<int32_t> stack_tops;
283-
ExpandRulePosition(init_rule_position, &stack_tops, true);
284-
stack_tops_history_.PushHistory(stack_tops);
283+
tmp_new_stack_tops_.clear();
284+
ExpandRulePosition(init_rule_position, &tmp_new_stack_tops_, true);
285+
stack_tops_history_.PushHistory(tmp_new_stack_tops_);
285286
} else {
286287
stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)});
287288
}
@@ -397,4 +398,4 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po
397398
} // namespace llm
398399
} // namespace mlc
399400

400-
#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_
401+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_BASE_H_

cpp/grammar/grammar_state_matcher_preproc.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
* \file grammar/grammar_state_matcher_preproc.h
44
* \brief The header for the preprocessing of the grammar state matcher.
55
*/
6-
#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
7-
#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
6+
#ifndef MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
7+
#define MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_
88

99
#include <vector>
1010

@@ -309,6 +309,8 @@ inline std::shared_ptr<GrammarStateInitContext> GrammarStateMatcher::CreateInitC
309309

310310
for (int i = 0; i < token_table.size(); ++i) {
311311
const auto& token = token_table[i];
312+
// TODO(yixin): Now we detect stop tokens from the token string. We should be able to pass
313+
// the stop token set in.
312314
// LLaMA2: </s>
313315
// LLaMA3: <|end_of_text|>, <|eot_id|>
314316
// Phi-2: <|endoftext|>
@@ -432,4 +434,4 @@ GrammarInitContextCache::GrammarInitContextCache(const std::vector<std::string>&
432434
} // namespace llm
433435
} // namespace mlc
434436

435-
#endif // TVM_LLVM_COMPILE_ENGINE_CPP_SERVE_GRAMMAR_STATE_MATCHER_PREPROC_H_
437+
#endif // MLC_LLM_GRAMMAR_GRAMMAR_STATE_MATCHER_PREPROC_H_

0 commit comments

Comments
 (0)