Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/petals/server/block_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ async def iterate_rpc_inference(
points: int,
quant_type: QuantType,
args_structure: Any = None,
- ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+ ) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
assert len(cache_handles) == len(requested_backends)

prefix_length = 0
Expand Down Expand Up @@ -224,7 +224,7 @@ async def iterate_rpc_inference(
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
can_push = not has_prompts
- yield output_tensors, can_push
+ yield output_tensors, can_push, step_metadata

# prepare for next step
prefix_length += length_increment
4 changes: 2 additions & 2 deletions src/petals/server/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ async def rpc_inference(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
background_tasks = set()
- async for output_tensors, can_push in iterate_rpc_inference(
+ async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
Expand All @@ -186,7 +186,7 @@ async def rpc_inference(
args_structure=args_structure,
):
if can_push:
- task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+ task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
background_tasks.add(task) # Keep reference until it is done to save it from GC
task.add_done_callback(background_tasks.discard)
yield runtime_pb2.ExpertResponse(tensors=output_tensors)
Expand Down