
Commit 2ad0b2b

Fix p2p pushing in rpc_inference (by @miaoqijun), support transformers 4.38.2 (#563)

This pull request solves #560 using a solution proposed by @miaoqijun. It also bumps transformers to the latest version to test with the latest code.

Co-authored-by: Yingtong Dou <[email protected]>

1 parent efee5d1 · commit 2ad0b2b

File tree

6 files changed: +23 -11 lines

setup.cfg
src/petals/__init__.py
src/petals/models/llama/block.py
src/petals/models/llama/model.py
src/petals/server/block_functions.py
src/petals/server/handler.py

setup.cfg

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ install_requires =
     accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers==4.37.1 # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.38.2 # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

src/petals/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -22,8 +22,8 @@

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
+        version.parse("4.38.2") <= version.parse(transformers.__version__) < version.parse("4.39.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.39.0"


 def _override_bfloat16_mode_default():
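For reference, the bumped guard can be reproduced outside Petals. A minimal sketch of the same check, assuming `transformers` and `packaging` are importable; the `PETALS_IGNORE_DEPENDENCY_VERSION` environment variable skips it, exactly as in the surrounding `if not os.getenv(...)` condition:

```python
import os

import transformers
from packaging import version

# Skippable version gate mirroring the assert added in this commit.
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
    installed = version.parse(transformers.__version__)
    assert version.parse("4.38.2") <= installed < version.parse("4.39.0"), (
        f"transformers {transformers.__version__} is outside the supported range; "
        "run `pip install transformers>=4.38.2,<4.39.0`"
    )
```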

src/petals/models/llama/block.py

Lines changed: 13 additions & 4 deletions

@@ -50,9 +50,15 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         assert not output_attentions
-        assert position_ids is None
+        if position_ids is None:
+            past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
+            position_ids = torch.arange(
+                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
+            ).unsqueeze(0)
+
         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
@@ -84,9 +90,8 @@ def forward(
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[kv_seq_len - q_len :]
-        sin = sin[kv_seq_len - q_len :]
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -160,6 +165,8 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -190,6 +197,8 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
         )

         hidden_states = residual + hidden_states
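The key behavioral change in this file is the `position_ids` fallback: instead of asserting that `position_ids is None`, the attention layer now derives positions from the number of tokens already held in the KV cache, and the rotary embedding is called with those `position_ids`, with the returned per-position `cos`/`sin` unsqueezed to broadcast over attention heads. A minimal standalone sketch of the derivation, assuming cached keys have shape `(batch, num_heads, past_seq_len, head_dim)` as in the diff above:

```python
import torch


def default_position_ids(hidden_states: torch.Tensor, past_key_value=None) -> torch.LongTensor:
    # Number of tokens already cached for this inference session (0 on the first step).
    past_seen_tokens = past_key_value[0].shape[2] if past_key_value is not None else 0
    # New tokens continue right after the cached prefix: [past, past + q_len).
    return torch.arange(
        past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
    ).unsqueeze(0)


# Example: 5 tokens already cached, 3 new tokens arrive -> positions [[5, 6, 7]]
hidden = torch.zeros(1, 3, 16)
past = (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8))
print(default_position_ids(hidden, past))  # tensor([[5, 6, 7]])
```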

src/petals/models/llama/model.py

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -62,6 +63,8 @@ def forward(
         assert (
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
+        if cache_position is not None:
+            assert position_ids is not None and torch.all(torch.eq(cache_position, position_ids)).item()
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"

src/petals/server/block_functions.py

Lines changed: 2 additions & 2 deletions

@@ -153,7 +153,7 @@ async def iterate_rpc_inference(
     points: int,
     quant_type: QuantType,
     args_structure: Any = None,
-) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool]]:
+) -> AsyncIterator[Tuple[Sequence[runtime_pb2.Tensor], bool, Dict]]:
     assert len(cache_handles) == len(requested_backends)

     prefix_length = 0
@@ -224,7 +224,7 @@ async def iterate_rpc_inference(
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
            can_push = not has_prompts
-           yield output_tensors, can_push
+           yield output_tensors, can_push, step_metadata

            # prepare for next step
            prefix_length += length_increment

src/petals/server/handler.py

Lines changed: 2 additions & 2 deletions

@@ -171,7 +171,7 @@ async def rpc_inference(
                 requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
             ) as cache_handles:
                 background_tasks = set()
-                async for output_tensors, can_push in iterate_rpc_inference(
+                async for output_tensors, can_push, step_metadata in iterate_rpc_inference(
                     requested_uids=requested_uids,
                     requested_backends=requested_backends,
                     active_adapter=self._get_active_adapter(metadata),
@@ -186,7 +186,7 @@
                     args_structure=args_structure,
                 ):
                     if can_push:
-                        task = asyncio.create_task(self._push_outputs(request, output_tensors[0], metadata))
+                        task = asyncio.create_task(self._push_outputs(request, output_tensors[0], step_metadata))
                         background_tasks.add(task)  # Keep reference until it is done to save it from GC
                         task.add_done_callback(background_tasks.discard)
                     yield runtime_pb2.ExpertResponse(tensors=output_tensors)
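Taken together, the last two files are the actual p2p-push fix: `iterate_rpc_inference` now yields the metadata of the current step alongside the output tensors, and `rpc_inference` pushes outputs with that per-step metadata instead of the metadata of the request that opened the session. A toy sketch of this control-flow shape, with hypothetical names standing in for the Petals internals:

```python
import asyncio
from typing import Any, AsyncIterator, Dict, Sequence, Tuple


# Hypothetical stand-ins (not Petals APIs) showing why per-step metadata has to be
# yielded next to the tensors rather than reusing the session-opening metadata.
async def iterate_steps() -> AsyncIterator[Tuple[Sequence[str], bool, Dict[str, Any]]]:
    for step in range(3):
        output_tensors = [f"hidden_states_step_{step}"]
        step_metadata = {"step": step}  # changes on every inference step
        can_push = True
        yield output_tensors, can_push, step_metadata


async def push_outputs(tensors: Sequence[str], metadata: Dict[str, Any]) -> None:
    print(f"pushing {tensors} with {metadata}")


async def main() -> None:
    background_tasks = set()
    async for output_tensors, can_push, step_metadata in iterate_steps():
        if can_push:
            # The fix: push with the metadata of *this* step, not the handshake metadata.
            task = asyncio.create_task(push_outputs(output_tensors, step_metadata))
            background_tasks.add(task)  # keep a reference so the task is not garbage-collected early
            task.add_done_callback(background_tasks.discard)
    await asyncio.gather(*background_tasks)


asyncio.run(main())
```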
