
Commit 96d12d4

[Improvement] support base-port & flask-port (#7668)
* support base-port & flask-port
* update flask_port
Parent: e5eefa7

3 files changed: +19 −10 lines


llm/flask_server.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -53,7 +53,8 @@ def __free_port(port):
 @dataclass
 class ServerArgument:
     port: int = field(default=8011, metadata={"help": "The port of ui service"})
-    base_port: int = field(default=8010, metadata={"help": "The port of flask service"})
+    base_port: int = field(default=None, metadata={"help": "The port of flask service"})
+    flask_port: int = field(default=None, metadata={"help": "The port of flask service"})
     title: str = field(default="LLM", metadata={"help": "The title of gradio"})
     sub_title: str = field(default="LLM-subtitle", metadata={"help": "The sub-title of gradio"})
```

```diff
@@ -64,8 +65,8 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
         self.predictor = predictor
         self.args = args
         scan_l, scan_u = (
-            self.args.base_port + port_interval * predictor.tensor_parallel_rank,
-            self.args.base_port + port_interval * (predictor.tensor_parallel_rank + 1),
+            self.args.flask_port + port_interval * predictor.tensor_parallel_rank,
+            self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1),
         )

         if self.predictor.tensor_parallel_rank == 0:
```
```diff
@@ -174,6 +175,14 @@ def start_ui_service(self, args):

     parser = PdArgumentParser((PredictorArgument, ModelArgument, ServerArgument))
     predictor_args, model_args, server_args = parser.parse_args_into_dataclasses()
+    # check port
+    if server_args.base_port is not None:
+        logger.warning("`--base_port` is deprecated, please use `--flask_port` instead after 2023.12.30.")
+
+        if server_args.flask_port is None:
+            server_args.flask_port = server_args.base_port
+        else:
+            logger.warning("`--base_port` and `--flask_port` are both set, `--base_port` will be ignored.")

     log_dir = os.getenv("PADDLE_LOG_DIR", "./")
     PORT_FILE = os.path.join(log_dir, PORT_FILE)
```
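With this change `--flask_port` becomes the preferred flag and `--base_port` survives only as a deprecated alias. Below is a minimal standalone sketch of the resolution rule added above; the helper name `resolve_flask_port` and the example port numbers are illustrative, not part of the commit.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_flask_port(base_port=None, flask_port=None):
    """Mirror the precedence check added to flask_server.py:
    `--flask_port` wins, `--base_port` is only a deprecated fallback."""
    if base_port is not None:
        logger.warning("`--base_port` is deprecated, please use `--flask_port` instead.")
        if flask_port is None:
            flask_port = base_port
        else:
            logger.warning("`--base_port` and `--flask_port` are both set, `--base_port` will be ignored.")
    return flask_port


# Precedence examples (port numbers are placeholders):
assert resolve_flask_port(base_port=8010) == 8010                   # old flag still works, with a warning
assert resolve_flask_port(flask_port=8011) == 8011                  # new flag
assert resolve_flask_port(base_port=8010, flask_port=8011) == 8011  # new flag wins over the deprecated one
```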

llm/gradio_ui.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -95,7 +95,7 @@ def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_l
         "max_length": max_length,
         "min_length": 1,
     }
-    res = requests.post(f"http://0.0.0.0:{args.base_port}/api/chat", json=data, stream=True)
+    res = requests.post(f"http://0.0.0.0:{args.flask_port}/api/chat", json=data, stream=True)
     for line in res.iter_lines():
         result = json.loads(line)
         bot_response = result["result"]["response"]
```
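The UI change only swaps which argument is read; the streaming request pattern against the flask service is unchanged. For reference, a minimal client sketch of that pattern, assuming a server is already running on `flask_port`. Only `max_length`, `min_length`, and the `result["result"]["response"]` shape come from the code above; the rest of the payload the real UI sends is omitted here.

```python
import json

import requests

flask_port = 8010  # whatever --flask_port the flask service was started with

# Illustrative payload: the real gradio_ui.py request also carries the
# conversation fields, which are not shown in this diff.
data = {
    "max_length": 50,
    "min_length": 1,
}

res = requests.post(f"http://0.0.0.0:{flask_port}/api/chat", json=data, stream=True)
for line in res.iter_lines():
    if not line:  # skip empty keep-alive lines
        continue
    result = json.loads(line)
    print(result["result"]["response"])
```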

tests/llm/test_gradio.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -39,11 +39,11 @@ def is_port_in_use(port):
 class UITest(unittest.TestCase):
     def setUp(self):
         # start web ui
-        self.base_port = self.avaliable_free_port()
-        self.port = self.avaliable_free_port([self.base_port])
+        self.flask_port = self.avaliable_free_port()
+        self.port = self.avaliable_free_port([self.flask_port])
         self.model_path = "__internal_testing__/tiny-random-llama"
-        command = 'cd llm && python flask_server.py --model_name_or_path {model_path} --port {port} --base_port {base_port} --src_length 1024 --dtype "float16"'.format(
-            base_port=self.base_port, port=self.port, model_path=self.model_path
+        command = 'cd llm && python flask_server.py --model_name_or_path {model_path} --port {port} --flask_port {flask_port} --src_length 1024 --dtype "float16"'.format(
+            flask_port=self.flask_port, port=self.port, model_path=self.model_path
         )
         self.ui_process = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stderr)
         self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path)
```
```diff
@@ -66,7 +66,7 @@ def avaliable_free_port(self, exclude=None):

     def wait_until_server_is_ready(self):
         while True:
-            if is_port_in_use(self.base_port) and is_port_in_use(self.port):
+            if is_port_in_use(self.flask_port) and is_port_in_use(self.port):
                 break

             print("waiting for server ...")
```
```diff
@@ -84,7 +84,7 @@ def test_argument(self):
         self.wait_until_server_is_ready()

         def get_response(data):
-            res = requests.post(f"http://localhost:{self.base_port}/api/chat", json=data, stream=True)
+            res = requests.post(f"http://localhost:{self.flask_port}/api/chat", json=data, stream=True)
             result_ = ""
             for line in res.iter_lines():
                 print(line)
```
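The test now launches the server with the new flag. For comparison, here is a sketch of both invocation styles, built the same way the test builds its command string; the port numbers and the `old_style` variant are illustrative (the old flag is still accepted, but triggers the deprecation warning shown in flask_server.py above).

```python
model_path = "__internal_testing__/tiny-random-llama"
ui_port = 8011     # --port: the gradio UI service
flask_port = 8010  # --flask_port: the flask API service

# New-style command, as used by the updated test.
new_style = (
    f"cd llm && python flask_server.py --model_name_or_path {model_path} "
    f'--port {ui_port} --flask_port {flask_port} --src_length 1024 --dtype "float16"'
)

# Deprecated-style command: still accepted, flask_server.py copies --base_port
# into flask_port and logs a warning.
old_style = (
    f"cd llm && python flask_server.py --model_name_or_path {model_path} "
    f'--port {ui_port} --base_port {flask_port} --src_length 1024 --dtype "float16"'
)
```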
