Skip to content

Commit 6cc9e6f

Browse files
committed
Merge remote-tracking branch 'origin/main' into sh/update_dlblas
2 parents c8ce820 + cf75374 commit 6cc9e6f

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

lmdeploy/pytorch/engine/model_agent.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def to_tensor(self):
5858
if isinstance(self.vals, torch.Tensor):
5959
vals = self.vals
6060
else:
61-
vals = torch.from_numpy(vals)
61+
vals = torch.from_numpy(self.vals)
6262
return BatchedLogProbs(vals=vals, indices=torch.from_numpy(self.indices))
6363

6464

@@ -729,7 +729,7 @@ def __update_inputs(next_token_ids, model_metas, extra_inputs):
729729
inputs,
730730
return_logits=return_logits,
731731
sync_long_context=sync_long_context,
732-
return_routed_experts=return_routed_experts and need_output,
732+
return_routed_experts=return_routed_experts and self.need_output,
733733
)
734734
logits = output['logits'][0] # [bs, seq, prob] -> [seq, prob]
735735
seq_length = output.get('seq_length', inputs.seq_length)
@@ -959,8 +959,7 @@ def _build_model(self):
959959
update_custom_module_map(custom_module_map)
960960
logger.debug(msg_with_rank(rank, 'build model.'))
961961
# for router replay
962-
need_output = self.dist_ctx.dp > 1 or self.dist_ctx.rank % self.dist_ctx.tp == 0
963-
enable_return_routed_experts = self.misc_config.enable_return_routed_experts and need_output
962+
enable_return_routed_experts = self.misc_config.enable_return_routed_experts and self.need_output
964963

965964
build_model_ctx = BuildModelContext(
966965
disable_vision_encoder=self.misc_config.disable_vision_encoder,

lmdeploy/serve/proxy/proxy.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
587587
if request.stream is True:
588588
response = node_manager.stream_generate(request_dict, node_url, '/v1/chat/completions')
589589
background_task = node_manager.create_background_tasks(node_url, start)
590-
return StreamingResponse(response, background=background_task)
590+
return StreamingResponse(response, background=background_task, media_type='text/event-stream')
591591
else:
592592
response = await node_manager.generate(request_dict, node_url, '/v1/chat/completions')
593593
node_manager.post_call(node_url, start)
@@ -649,7 +649,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
649649
if request.stream is True:
650650
response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions')
651651
background_task = node_manager.create_background_tasks(d_url, start)
652-
resp = StreamingResponse(response, background=background_task)
652+
resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')
653653
else:
654654
response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')
655655
node_manager.post_call(d_url, start)
@@ -717,7 +717,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
717717
if request.stream is True:
718718
response = node_manager.stream_generate(request_dict, node_url, '/v1/completions')
719719
background_task = node_manager.create_background_tasks(node_url, start)
720-
return StreamingResponse(response, background=background_task)
720+
return StreamingResponse(response, background=background_task, media_type='text/event-stream')
721721
else:
722722
response = await node_manager.generate(request_dict, node_url, '/v1/completions')
723723
node_manager.post_call(node_url, start)
@@ -793,7 +793,7 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
793793
if request.stream is True:
794794
response = node_manager.stream_generate(request_dict, d_url, '/v1/completions')
795795
background_task = node_manager.create_background_tasks(d_url, start)
796-
resp = StreamingResponse(response, background=background_task)
796+
resp = StreamingResponse(response, background=background_task, media_type='text/event-stream')
797797
else:
798798
response = await node_manager.generate(request_dict, d_url, '/v1/completions')
799799
node_manager.post_call(d_url, start)

0 commit comments

Comments (0)