 
 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm.inputs import default_multimodal_input_loader
-from tensorrt_llm.llmapi import KvCacheConfig
+from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig
 from tensorrt_llm.llmapi.llm import LLM, SamplingParams
 
 test_data_root = Path(
@@ -41,7 +41,8 @@ def multimodal_model_config():
 @pytest.mark.parametrize("model_key", [
     "llava-v1.6-mistral-7b-hf",
 ])
-def test_single_image_chat(model_key, multimodal_model_config):
+@pytest.mark.parametrize("pd_disagg", [False, True])
+def test_single_image_chat(model_key, pd_disagg, multimodal_model_config):
     """Test processing single image using encoder (pass mm_embeddings) + LLM API.
 
     This test verifies that encoder (pass mm_embeddings) + LLM API produces identical
@@ -59,7 +60,7 @@ def test_single_image_chat(model_key, multimodal_model_config):
 
     # Test configuration
     max_tokens = 64
-    free_gpu_memory_fraction = 0.6
+    free_gpu_memory_fraction = 0.6 if not pd_disagg else 0.2
     max_batch_size = 1
 
     # Test data - OpenAI chat completion format
@@ -76,10 +77,26 @@ def test_single_image_chat(model_key, multimodal_model_config):
     # Process multimodal data using encoder (pass mm_embeddings)
     encoder = MultimodalEncoder(model=encoder_model_dir,
                                 max_batch_size=max_batch_size)
+
+    cache_transceiver_cfg = CacheTransceiverConfig(
+        backend="DEFAULT") if pd_disagg else None
+
+    disable_overlap_scheduler = pd_disagg
+
     llm = LLM(model=encoder_model_dir,
               backend='pytorch',
               kv_cache_config=kv_cache_config,
-              trust_remote_code=True)
+              trust_remote_code=True,
+              cache_transceiver_config=cache_transceiver_cfg,
+              disable_overlap_scheduler=disable_overlap_scheduler)
+
+    llm_decode = None
+    if pd_disagg:
+        llm_decode = LLM(model=encoder_model_dir,
+                         backend='pytorch',
+                         kv_cache_config=kv_cache_config,
+                         trust_remote_code=True,
+                         cache_transceiver_config=cache_transceiver_cfg)
 
     # Load model configuration
     config_path = os.path.join(llm._hf_model_dir, 'config.json')
@@ -122,10 +139,23 @@ def test_single_image_chat(model_key, multimodal_model_config):
     ep_disaggregated_params = encoder_outputs[0].disaggregated_params
 
     assert ep_disaggregated_params is not None, "Encoder output disaggregated params is None"
-    ep_disaggregated_params.request_type = "context_and_generation"
+    ep_disaggregated_params.request_type = "context_and_generation" if not pd_disagg else "context_only"
     outputs = llm.generate(inputs,
                            sampling_params=sampling_params,
                            disaggregated_params=ep_disaggregated_params)
+
+    if pd_disagg:
+        # Generation using llm_decode
+        assert len(outputs) == 1
+        pd_disaggregated_params = outputs[0].disaggregated_params
+        pd_disaggregated_params.request_type = "generation_only"
+        sampling_params = SamplingParams(max_tokens=max_tokens)
+
+        outputs = llm_decode.generate(
+            inputs,
+            sampling_params=sampling_params,
+            disaggregated_params=pd_disaggregated_params)
+
     # Validate outputs
     assert len(outputs) == len(
         prompts), f"Expected {len(prompts)} outputs, got {len(outputs)}"
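
For reference, a condensed sketch of the prefill/decode disaggregation flow that the new pd_disagg=True path exercises. It reuses the names from the test above; `inputs`, `sampling_params`, and `kv_cache_config` are assumed to be prepared as in the test, and the `encoder.generate(...)` call is an assumption, since the line producing `encoder_outputs` is outside this diff.

# Sketch only (not part of the diff): prefill/decode disaggregation with a multimodal encoder.
# `encoder_model_dir`, `inputs`, `sampling_params`, and `kv_cache_config` are assumed to be
# set up as in the test above; encoder.generate() is also an assumption.
from tensorrt_llm import MultimodalEncoder
from tensorrt_llm.llmapi import CacheTransceiverConfig
from tensorrt_llm.llmapi.llm import LLM, SamplingParams

cache_cfg = CacheTransceiverConfig(backend="DEFAULT")
encoder = MultimodalEncoder(model=encoder_model_dir, max_batch_size=1)

# Two LLM instances share one GPU in this test, hence the reduced
# free_gpu_memory_fraction (0.2) when pd_disagg is enabled.
prefill_llm = LLM(model=encoder_model_dir, backend='pytorch',
                  kv_cache_config=kv_cache_config, trust_remote_code=True,
                  cache_transceiver_config=cache_cfg,
                  disable_overlap_scheduler=True)
decode_llm = LLM(model=encoder_model_dir, backend='pytorch',
                 kv_cache_config=kv_cache_config, trust_remote_code=True,
                 cache_transceiver_config=cache_cfg)

# Encoder produces the multimodal embeddings plus the handoff metadata.
ep_params = encoder.generate(inputs)[0].disaggregated_params

# Context-only pass: prefill the KV cache on the first instance.
ep_params.request_type = "context_only"
ctx_outputs = prefill_llm.generate(inputs, sampling_params=sampling_params,
                                   disaggregated_params=ep_params)

# Generation-only pass: the second instance picks up the KV cache and decodes.
gen_params = ctx_outputs[0].disaggregated_params
gen_params.request_type = "generation_only"
final_outputs = decode_llm.generate(inputs, sampling_params=sampling_params,
                                    disaggregated_params=gen_params)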