
Commit 4285ddb

fix
1 parent 6858586 commit 4285ddb

File tree

5 files changed: 59 additions & 3 deletions

src/petals/client/inference_session.py

Lines changed: 18 additions & 3 deletions
@@ -84,7 +84,8 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
                 break  # this message means "done sending"
 
     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *,
+        step_id: str, last_validated_position: int
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +95,12 @@ def step(
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        if last_validated_position is not None:
+            assert last_validated_position <= self._position
+            self._position = last_validated_position
+            if self.history is not None and self.history.shape[1] >= last_validated_position:
+                self.history = self.history[:, :last_validated_position, :] if last_validated_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +122,8 @@ def step(
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+            if last_validated_position is not None:
+                request_metadata["last_validated_position"] = last_validated_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
@@ -257,8 +266,13 @@ def __enter__(self) -> "InferenceSession":
         return self
 
     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None, last_validated_position: Optional[int] = None
     ) -> torch.Tensor:
+
+        if last_validated_position is not None:
+            self._position = last_validated_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +317,8 @@ def step(
 
                 server_session = self._server_sessions[server_idx]
                 inputs = server_session.step(
-                    inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+                    inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids,
+                    step_id=step_id, last_validated_position=last_validated_position
                 )
 
                 server_idx += 1

src/petals/server/.handler.py.swp

20 KB
Binary file not shown.

src/petals/server/block_functions.py

Lines changed: 5 additions & 0 deletions
@@ -160,6 +160,11 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
+        if "last_validated_position" in step_metadata:
+            last_validated_position = min(step_metadata["last_validated_position"], prefix_length)
+            assert prefix_length >= last_validated_position, f"prefix_length={prefix_length}, last_validated_position={last_validated_position}"
+            prefix_length = last_validated_position
+
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
             # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
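A hedged sketch of the clamping rule introduced above, pulled out as a standalone helper: the function wrapper is an assumption for illustration, while the metadata key and the `prefix_length` semantics follow the diff. The requested rollback position is capped at the current cached prefix, then becomes the new prefix so the next chunk overwrites the invalidated cache slots.

def apply_rollback(step_metadata: dict, prefix_length: int) -> int:
    """Return the prefix length to use for this step after a client-requested rollback."""
    if "last_validated_position" in step_metadata:
        # Never trust a position beyond what the server has actually cached.
        last_validated_position = min(step_metadata["last_validated_position"], prefix_length)
        assert prefix_length >= last_validated_position
        prefix_length = last_validated_position
    return prefix_length


assert apply_rollback({"last_validated_position": 2}, prefix_length=8) == 2    # roll back to 2
assert apply_rollback({}, prefix_length=8) == 8                                # no rollback requested
assert apply_rollback({"last_validated_position": 99}, prefix_length=8) == 8   # clamped to cache size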

src/petals/server/handler.py

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@ async def rpc_inference(
             max_length = metadata.get("max_length")
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
+            last_validated_position = metadata.get("last_validated_position", None)
             alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
             args_structure = metadata.get("args_structure")
             if not requested_uids:
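For context, a small sketch of how the new field travels in metadata: the client attaches it to the request metadata (see the client diff above), and the handler reads it back with a `None` default, so requests from older clients keep working. The plain dicts below are illustrative stand-ins for the serialized request metadata.

last_validated_position = 2

request_metadata = dict(session_id="example-session", step_id="example-step")
if last_validated_position is not None:
    request_metadata["last_validated_position"] = last_validated_position

# Server side: an absent key falls back to None.
assert request_metadata.get("last_validated_position", None) == 2
assert dict(session_id="s", step_id="t").get("last_validated_position", None) is None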
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
+
+
+@pytest.mark.forked
+def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
+
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
+
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs[:, :2, :] = inputs[:, :2, :]
+
+    initial_outputs_inference = None
+    secondary_outputs_inference = None
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            initial_outputs_inference = sess.step(inputs)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], last_validated_position=2)
+        result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(short_inputs)
+
+    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
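As a usage note, the client-side pattern this test exercises looks roughly like the sketch below, which reuses `remote_block` and `config` from the test above; the tensor sizes and the chosen rollback point are illustrative, not part of the commit.

with torch.inference_mode():
    with remote_block.inference_session(max_length=6) as sess:
        draft = torch.randn(1, 6, config.hidden_size)
        _ = sess.step(draft)  # fills cache positions 0..5
        # Suppose only the first 2 positions turn out to be valid:
        # roll back to position 2 and overwrite the rest of the cache.
        fixed_suffix = torch.randn(1, 4, config.hidden_size)
        _ = sess.step(fixed_suffix, last_validated_position=2)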
