Commit a6a0ba3

Browse files
[SW-238029] [1.22] Fix max_batch_size handling - Llama perf degradation fix (#1828)

Llama perf degradation was seen with Gemma3 support (#1616): max_batch_size was initialized incorrectly for the profile_run because the check relied on mm_registry being set rather than on an actual multimodal model being in use. The fix initializes it to 1 only when a multimodal (mrope or mm_optimized) model is in use.

## Test Result
Llama v3.1 70B 2048/128 BF16 2x card: perf dropped from 170 tps to 150 tps; with this fix, it is back to 170 tps.

Co-authored-by: Iryna Boiko <[email protected]>
Parent: 8fad535

1 file changed (+6 -7 lines)


vllm/worker/hpu_model_runner.py

Lines changed: 6 additions & 7 deletions
@@ -1491,11 +1491,10 @@ def move_to_device(self, tensor):
                              non_blocking=True)
 
     def add_vision_buckets_to_mrope_mm_optimized(self):
-        if self.mm_registry is not None:
-            model = self.get_model()
-            self.is_mm_optimized = is_mm_optimized(model)
-            if self.model_is_mrope or self.is_mm_optimized:
-                model.vision_buckets = VisionBuckets(self.is_mm_optimized)
+        model = self.get_model()
+        self.is_mm_optimized = is_mm_optimized(model)
+        if self.model_is_mrope or self.is_mm_optimized:
+            model.vision_buckets = VisionBuckets(self.is_mm_optimized)
 
     def _prepare_prompt(
         self,
@@ -2804,10 +2803,10 @@ def profile_run(self) -> None:
         max_seq_len = self.bucketing_manager.get_max_prompt_shape()
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)
-        # Using batch_size 1 is profile multimodal models
-        max_batch_size = max_batch_size if self.mm_registry is None else 1
 
         if self.model_is_mrope or self.is_mm_optimized:
+            # Using batch_size 1 is profile multimodal models
+            max_batch_size = 1
             model = self.get_model()
             self.multimodal_buckets = model.vision_buckets.multimodal_buckets
             logger_msg = "Multimodal bucket : " + str(self.multimodal_buckets)
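The behavioral change in profile_run can be sketched as a pair of standalone functions. This is a hedged, simplified illustration, not the actual hpu_model_runner.py code: the names mirror the diff (mm_registry, model_is_mrope, is_mm_optimized), but the class context is dropped and the numeric values are made-up stand-ins.

```python
# Sketch of the profile_run batch-size logic before and after this commit.
# Simplified stand-in, assuming the attribute names from the diff above.

def old_max_batch_size(max_num_seqs, max_num_batched_tokens, max_seq_len,
                       mm_registry):
    max_batch_size = min(max_num_seqs, max_num_batched_tokens // max_seq_len)
    # Bug: mm_registry is populated even for text-only models such as Llama,
    # so profiling was always forced down to batch_size 1.
    return max_batch_size if mm_registry is None else 1

def new_max_batch_size(max_num_seqs, max_num_batched_tokens, max_seq_len,
                       model_is_mrope, is_mm_optimized):
    max_batch_size = min(max_num_seqs, max_num_batched_tokens // max_seq_len)
    # Fix: clamp to 1 only when an actual multimodal path is in use.
    if model_is_mrope or is_mm_optimized:
        max_batch_size = 1
    return max_batch_size

# Text-only Llama run: a registry object exists, but the model is not
# multimodal, so the old check wrongly profiled with batch_size 1.
print(old_max_batch_size(256, 8192, 2048, mm_registry=object()))  # 1 (too small)
print(new_max_batch_size(256, 8192, 2048, False, False))          # 4 (correct)
```

Profiling with an artificially small batch size makes the warm-up shapes unrepresentative of real traffic, which is why the old check showed up as a throughput regression for text-only Llama.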
