
Commit 9ed5a23

fix max-token conflict w/ DS (#49)
1 parent 508f33b commit 9ed5a23

File tree

3 files changed: +17 −4

llmserve/backend/llm/predictor.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
     init_torch_dist_process_group_async,
     initialize_node,
     timeit,
+    get_max_token_size,
 )
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Args, LLMConfig, Prompt, Response
@@ -94,7 +95,7 @@ def init_model(
 
         if llm_config.warmup and warmup_inputs:
             prowarmup_inputs_max = Prompt(prompt=warmup_inputs * (
-                int(llm_config.max_input_words / (len(warmup_inputs.split()) + 1)) + 1
+                int(get_max_token_size(llm_config) / (len(warmup_inputs.split()) + 1))
             ), use_prompt_format=False)
 
             logger.info(
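
The change above swaps the warmup budget from max_input_words to whatever get_max_token_size returns. Below is a minimal sketch of the repetition math; the warmup string and the 512-token budget are stand-ins for illustration, not values from the repository.

# A sketch of the warmup prompt sizing, assuming a hypothetical warmup string
# and token budget; get_max_token_size(llm_config) would supply the budget.
warmup_inputs = "Write a short story about a robot."   # hypothetical warmup text
max_token_size = 512                                   # stand-in for get_max_token_size(llm_config)

# Repeat the warmup text until the prompt roughly fills the token budget,
# using whitespace word count (+1) as a cheap proxy for token count.
repeats = int(max_token_size / (len(warmup_inputs.split()) + 1))
warmup_prompt = warmup_inputs * repeats
print(repeats, len(warmup_prompt.split()))   # e.g. 64 repeats of a 7-word string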

llmserve/backend/llm/utils.py

Lines changed: 13 additions & 1 deletion
@@ -19,7 +19,7 @@
 from torch.hub import _get_torch_home
 
 from llmserve.backend.logger import get_logger
-from llmserve.backend.server.models import S3MirrorConfig
+from llmserve.backend.server.models import S3MirrorConfig, LLMConfig
 
 logger = get_logger(__name__)
 
@@ -279,6 +279,14 @@ async def init_torch_dist_process_group_async(
         node_id = node_and_gpu_ids[rank][0]
         local_rank = node_to_workers[node_id].index(rank)
         local_world_size = len(node_to_workers[node_id])
+        logger.info("++++++++++++++")
+        logger.info(rank)
+        logger.info(world_size)
+        logger.info(local_rank)
+        logger.info(local_world_size)
+        logger.info(master_addr)
+        logger.info(master_port)
+        logger.info(list(node_to_gpu_ids[node_id]))
         setup_futures.append(
             worker.execute.remote(
                 _init_torch_distributed,
@@ -301,3 +309,7 @@ async def init_torch_dist_process_group_async(
     await asyncio.gather(*setup_futures)
 
     return local_ranks
+
+# Get the max input token size for warmup. With DeepSpeed there is a "max_tokens" field at "initializer/max_tokens" that conflicts with "max_input_words"; prefer "max_tokens" if both exist.
+def get_max_token_size(llm_config: LLMConfig):
+    return llm_config.initialization.initializer.max_tokens if hasattr(llm_config.initialization.initializer, "max_tokens") else llm_config.max_input_words
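
For readers following the helper added at the end of utils.py, the sketch below shows its preference order with mock config objects; pick_max_token_size and the SimpleNamespace configs are illustrative stand-ins, not the repository's LLMConfig model.

# Illustrative stand-in for the selection logic in get_max_token_size; the
# configs here are mock namespaces, not the real LLMConfig.
from types import SimpleNamespace

def pick_max_token_size(llm_config):
    initializer = llm_config.initialization.initializer
    # Prefer the DeepSpeed-style "max_tokens" when the initializer defines it,
    # otherwise fall back to the top-level "max_input_words".
    if hasattr(initializer, "max_tokens"):
        return initializer.max_tokens
    return llm_config.max_input_words

# DeepSpeed-style config: the initializer carries max_tokens, which wins.
ds_cfg = SimpleNamespace(
    max_input_words=800,
    initialization=SimpleNamespace(initializer=SimpleNamespace(max_tokens=1024)),
)
# Plain config: no max_tokens on the initializer, so max_input_words is used.
plain_cfg = SimpleNamespace(
    max_input_words=800,
    initialization=SimpleNamespace(initializer=SimpleNamespace()),
)

print(pick_max_token_size(ds_cfg))     # 1024
print(pick_max_token_size(plain_cfg))  # 800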

models/text-generation--bigscience--bloom-3b.yaml

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ model_config:
   trust_remote_code: true
   pipeline: default
   generation:
-    max_batch_size: 2
-    batch_wait_timeout_s: 30
+    max_batch_size: 10
+    batch_wait_timeout_s: 0
     generate_kwargs:
       do_sample: false
       max_new_tokens: 512
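
The YAML change raises the batch size and removes the batch wait for the bloom-3b config. Assuming these keys map onto Ray Serve style dynamic batching, the sketch below shows roughly how such values would be applied; the deployment class and handler are illustrative, not the project's pipeline code.

# A rough sketch, assuming max_batch_size / batch_wait_timeout_s feed Ray Serve
# style dynamic batching; this is not the repository's pipeline code.
from ray import serve

@serve.deployment
class TextGenerator:
    @serve.batch(max_batch_size=10, batch_wait_timeout_s=0)
    async def handle_batch(self, prompts):
        # With batch_wait_timeout_s=0 the batcher does not hold requests to
        # fill a batch; it takes whatever is already queued, up to 10 at once.
        return [f"generated text for: {p}" for p in prompts]

    async def __call__(self, request):
        return await self.handle_batch(request.query_params["prompt"])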
