Skip to content

Commit 9fd759d

Browse files
committed
add generation cfg
1 parent b2ae632 commit 9fd759d

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

lightllm/server/api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,4 +392,5 @@ async def startup_event():
392392
pd_master_start(g_objs)
393393
else:
394394
init_tokenizer(args) # for openai api
395+
SamplingParams.load_generation_cfg(args.model_dir)
395396
normal_or_p_d_start(g_objs)

lightllm/server/sampling_params.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Sampling parameters for text generation."""
22
import os
33
from typing import List, Optional, Union, Tuple
4-
4+
from transformers import GenerationConfig
55
from .req_id_generator import MAX_BEST_OF
66

77
_SAMPLING_EPS = 1e-5
@@ -10,18 +10,27 @@
1010

1111

1212
class SamplingParams:
13+
14+
_do_sample: bool = (False,)
15+
_presence_penalty: float = (0.0,)
16+
_frequency_penalty: float = (0.0,)
17+
_repetition_penalty: float = (1.0,)
18+
_temperature: float = (1.0,)
19+
_top_p: float = (1.0,)
20+
_top_k: int = (-1,) # -1 is for all
21+
1322
def __init__(
1423
self,
1524
best_of: int = 1,
1625
n: int = None, # number of results
17-
do_sample: bool = False,
18-
presence_penalty: float = 0.0,
19-
frequency_penalty: float = 0.0,
20-
repetition_penalty: float = 1.0,
26+
do_sample: bool = None,
27+
presence_penalty: float = None,
28+
frequency_penalty: float = None,
29+
repetition_penalty: float = None,
2130
exponential_decay_length_penalty: Tuple[int, float] = (1, 1.0),
22-
temperature: float = 1.0,
23-
top_p: float = 1.0,
24-
top_k: int = -1, # -1 is for all
31+
temperature: float = None,
32+
top_p: float = None,
33+
top_k: int = None, # -1 is for all
2534
ignore_eos: bool = False,
2635
max_new_tokens: int = 16,
2736
min_new_tokens: int = 1,
@@ -46,14 +55,18 @@ def __init__(
4655
) -> None:
4756
self.best_of = best_of
4857
self.n = n
49-
self.do_sample = do_sample
50-
self.presence_penalty = presence_penalty
51-
self.frequency_penalty = frequency_penalty
52-
self.repetition_penalty = repetition_penalty
58+
self.do_sample = do_sample if do_sample is not None else SamplingParams._do_sample
59+
self.presence_penalty = presence_penalty if presence_penalty is not None else SamplingParams._presence_penalty
60+
self.frequency_penalty = (
61+
frequency_penalty if frequency_penalty is not None else SamplingParams._frequency_penalty
62+
)
63+
self.repetition_penalty = (
64+
repetition_penalty if repetition_penalty is not None else SamplingParams._repetition_penalty
65+
)
5366
self.exponential_decay_length_penalty = exponential_decay_length_penalty
54-
self.temperature = temperature
55-
self.top_p = top_p
56-
self.top_k = top_k
67+
self.temperature = temperature if temperature is not None else SamplingParams._temperature
68+
self.top_p = top_p if top_p is not None else SamplingParams._top_p
69+
self.top_k = top_k if top_k is not None else SamplingParams._top_k
5770
self.ignore_eos = ignore_eos
5871
self.max_new_tokens = max_new_tokens
5972
self.min_new_tokens = min_new_tokens
@@ -81,6 +94,20 @@ def __init__(
8194
self.n = self.best_of
8295
return
8396

97+
@classmethod
98+
def load_generation_cfg(cls, weight_dir):
99+
try:
100+
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
101+
cls._do_sample = generation_cfg.get("do_sample", False)
102+
cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
103+
cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
104+
cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
105+
cls._temperature = generation_cfg.get("temperature", 1.0)
106+
cls._top_p = generation_cfg.get("top_p", 1.0)
107+
cls._top_k = generation_cfg.get("top_k", -1)
108+
except:
109+
pass
110+
84111
def verify(self):
85112
if self.best_of <= 0 or self.best_of > MAX_BEST_OF:
86113
raise ValueError(f"need 0 < best_of <= {MAX_BEST_OF}, but get {self.best_of}")

0 commit comments

Comments
 (0)