You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[Fix] Fix the efficiency issue in repetition structure. (#467)
As reported in #465, this PR's target is to fix the efficiency issue in
the repetition structure.
The issue is caused by two reasons in the repetition construction
method:
- The threshold is too small.
- The construction of the other part, except for the repetition
expression, is not good enough.
This PR fixes the problem. The efficiency can be evaluated with the
following code:
``` python
import xgrammar as xgr
from transformers import AutoTokenizer
from xgrammar.testing import _get_matcher_from_grammar_and_tokenizer_info
import time
schema = {
"additionalProperties": False,
"properties": {
"description": {
"maxLength": 200,
"minLength": 10,
"title": "Some description",
"type": "string"
}
},
"required": [
"description"
],
"title": "Some object",
"type": "object"
}
tokenizer_path = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True, trust_remote_code=True)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
grammar = xgr.Grammar.from_json_schema(schema)
matcher = _get_matcher_from_grammar_and_tokenizer_info(grammar, tokenizer_info)
test_str = "{\"description\": \"1234567890123456789101234567891012345678910123456789101234567891" + \
"012345678910123456789101234567891012345678910123456789101234567891012345678910123456" + \
"78910123456789101234567897891012345678910123456789\"}"
token_bitmask = xgr.allocate_token_bitmask(1, tokenizer_info.vocab_size)
tpot = []
print(grammar)
for char in test_str:
start_time = time.time_ns()
matcher.fill_next_token_bitmask(token_bitmask)
assert matcher.accept_string(char)
end_time = time.time_ns()
print(f"Processed char '{char}' in {(end_time - start_time)/1e3} us")
tpot.append((end_time - start_time)/1e3)
tpot.sort()
print(f"avg: {sum(tpot) / len(tpot)} us")
print(f"max: {tpot[-1]} us")
print(f"p50: {tpot[int(0.5 * len(tpot))]} us")
print(f"p90: {tpot[int(0.9 * len(tpot))]} us")
```
And the results are listed below:
```
v0.1.21
avg: 544.5573364055299 us
max: 6720.648 us
p50: 332.969 us
p90: 620.647 us
main branch
avg: 4789.501852534562 us
max: 7862.335 us
p50: 5813.238 us
p90: 6398.427 us
This PR
avg: 353.5097096774194 us
max: 1064.672 us
p50: 449.098 us
p90: 664.901 us
```
---------
Signed-off-by: Yuchuan <[email protected]>
0 commit comments