Skip to content

Commit a887420

Browse files
committed
add additional arguments for mm data
Signed-off-by: John Calderon <jcalderon@nvidia.com>
1 parent 2c4f551 commit a887420

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

tensorrt_llm/_torch/models/modeling_qwen2vl.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -234,18 +234,17 @@ def get_dummy_text(self, input_seq_len: int):
234234
for _ in range(input_seq_len)
235235
])
236236

237-
def get_dummy_images(self,
238-
max_width: int,
239-
max_height: int,
240-
num_images: int = 1):
237+
def get_dummy_images(self, max_width: int, max_height: int,
238+
num_images: int):
241239
image = Image.new("RGB", (max_width, max_height), color=255)
242240
return [image] * num_images
243241

244-
def get_dummy_prompt(self, input_seq_len: int):
242+
def get_dummy_prompt(self, input_seq_len: int, mm_data: dict):
243+
num_images = mm_data.get("image", 0)
245244
text = self.get_dummy_text(input_seq_len)
246245
images = self.get_dummy_images(
247-
max_width=3584,
248-
max_height=3584) #sqrt of max_pixels value (12845056)
246+
max_width=3584, max_height=3584,
247+
num_images=num_images) #w, h is sqrt of max_pixels value (12845056)
249248
return default_multimodal_input_loader(
250249
tokenizer=self.tokenizer,
251250
model_dir=self.model_path,

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def _create_dummy_mm_context_request(
163163
"Profiling with the default input dummy context request. This may not take into account the memory consumption of " \
164164
"ViT's encoder")
165165
return requests
166-
text_prompt = input_processor.get_dummy_prompt(input_seq_len)
166+
text_prompt = input_processor.get_dummy_prompt(input_seq_len,
167+
{'image': 1})
167168
max_beam_width = self._max_beam_width
168169
input_processor_with_hash = create_input_processor_with_hash(
169170
input_processor)

tensorrt_llm/inputs/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class BaseDummyInputsBuilder:
4747
Base class for generating dummy inputs. Specially for profiling
4848
"""
4949

50-
def get_dummy_prompt(self):
50+
def get_dummy_prompt(self, input_seq_len: int, mm_data: dict):
5151
raise NotImplementedError(
5252
"Please ensure this method is implemented in your inherited class")
5353

0 commit comments

Comments
 (0)