
Commit f57dc01

[https://nvbugs/5625380][chore] Remove multimodal related fields from decoder llm input (#8846)
1 parent 0f42a24 commit f57dc01

File tree

2 files changed, +8 -6 lines

tensorrt_llm/llmapi/llm.py

Lines changed: 1 addition & 3 deletions
@@ -426,10 +426,8 @@ def generate_async(
         prompt = inputs.get("prompt", None)
         query_token_ids = inputs.get("query_token_ids", None)
         if is_gen_only:
-            # TODO: support generation-only mode for multimodal disaggregated inference
-            # Need to set multimodal_params = None; but not tested yet
             raise ValueError(
-                "Multimodal disaggregated inference is not supported for generation-only mode"
+                "Generation-only mode should not need multimodal parameters"
             )
         else:
             mm_hashes = disaggregated_params.multimodal_hashes
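For context, a minimal sketch (not the actual llm.py code) of the intent behind the updated guard: a generation-only (decoder) request is now expected to arrive without multimodal fields, so such fields are rejected up front rather than deferred as unsupported. The helper name and flags below are hypothetical.

# Illustrative only; names are hypothetical and not part of the LLM API.
def _reject_multimodal_in_gen_only(is_gen_only: bool,
                                   has_multimodal_fields: bool) -> None:
    """Mirror the updated check: gen-only requests must carry no multimodal data."""
    if is_gen_only and has_multimodal_fields:
        raise ValueError(
            "Generation-only mode should not need multimodal parameters")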

tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py

Lines changed: 7 additions & 3 deletions
@@ -150,6 +150,10 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
     pd_disaggregated_params = outputs[0].disaggregated_params
     pd_disaggregated_params.request_type = "generation_only"
     sampling_params = SamplingParams(max_tokens=max_tokens)
+    inputs[0][
+        'multi_modal_data'] = None  # remove multimodal data from input as decoder worker doesn't need it
+    inputs[0]['prompt_token_ids'] = outputs[
+        0].prompt_token_ids  # use prompt token ids from encoder output
 
     outputs = llm_decode.generate(
         inputs,
@@ -169,9 +173,9 @@ def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
     ), f"Number of outputs don't match: {len(outputs_ref)} vs {len(outputs)}"
 
     for i, (ref_output, test_output) in enumerate(zip(outputs_ref, outputs)):
-        # Compare prompts
-        assert ref_output.prompt == test_output.prompt, \
-            f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
+        # Cannot compare prompts: the decoder worker no longer receives the original prompt string
+        # assert ref_output.prompt == test_output.prompt, \
+        #     f"Prompts don't match for output {i}:\nReference: {ref_output.prompt!r}\nTest: {test_output.prompt!r}"
 
         # Compare number of generated outputs
         assert len(ref_output.outputs) == len(test_output.outputs), \
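As a usage note, a minimal sketch of how a caller might prepare decoder-side inputs after this change, following the pattern in the test above. The helper name is illustrative; `inputs`, `outputs` (the encoder-phase results), `llm_decode`, and `max_tokens` are assumed to be set up as in the test, and the SamplingParams import path is assumed.

from tensorrt_llm import SamplingParams  # assumed import path


def prepare_decoder_request(inputs, encoder_outputs, max_tokens):
    """Illustrative helper: turn encoder-phase outputs into a generation-only request."""
    pd_params = encoder_outputs[0].disaggregated_params
    pd_params.request_type = "generation_only"
    # The decoder worker no longer needs the raw multimodal payload.
    inputs[0]['multi_modal_data'] = None
    # Reuse the token ids produced during the encoder (context) phase.
    inputs[0]['prompt_token_ids'] = encoder_outputs[0].prompt_token_ids
    return inputs, pd_params, SamplingParams(max_tokens=max_tokens)


# Usage, mirroring the test:
# inputs, pd_params, sampling_params = prepare_decoder_request(inputs, outputs, max_tokens)
# outputs = llm_decode.generate(inputs, ...)  # pass sampling/disaggregated params as the test does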
