
Commit 227b5fc

[Performance] Avoid computing log-probs when retrieving dist (#3081)
1 parent 130ed3f commit 227b5fc
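
The change rests on a simple observation: building the distribution returned by `get_dist` only needs the logits, while the per-token log-probabilities produced by the log-prob code path can be derived lazily from the distribution itself. The sketch below is a minimal illustration in plain PyTorch, independent of torchrl's wrapper classes (all shapes and names are made up for the example):

```python
import torch
from torch import distributions as D

# Toy "model output": logits over a vocabulary for a batch of token positions.
batch, seq_len, vocab = 2, 8, 1000
logits = torch.randn(batch, seq_len, vocab)

# Constructing the distribution requires only the logits...
dist = D.Categorical(logits=logits)

# ...and per-token log-probs can still be obtained on demand when needed:
tokens = torch.randint(vocab, (batch, seq_len))
log_probs = dist.log_prob(tokens)

# Precomputing log_softmax + gather during the forward pass does the same
# work a second time, which is exactly what the new logits_only path skips.
precomputed = logits.log_softmax(-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(log_probs, precomputed, atol=1e-6)
```

With this commit, `get_dist` asks the wrappers for exactly that reduced amount of work by calling `forward(..., logits_only=True)`.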

File tree (3 files changed, +124 additions, -57 deletions):

- torchrl/modules/llm/policies/common.py
- torchrl/modules/llm/policies/transformers_wrapper.py
- torchrl/modules/llm/policies/vllm_wrapper.py

torchrl/modules/llm/policies/common.py

Lines changed: 24 additions & 8 deletions
@@ -10,7 +10,7 @@
 
 import torch
 from tensordict import lazy_stack, NestedKey, TensorDictBase
-from tensordict.nn import TensorDictModuleBase, TensorDictSequential
+from tensordict.nn import TensorDictModuleBase
 from tensordict.tensorclass import TensorClass
 from tensordict.utils import _zip_strict
 from torch import distributions as D
@@ -488,7 +488,7 @@ def get_dist(
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
 
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
 
         # Get logits/log-probs
         if as_padded_tensor is None:
@@ -563,7 +563,7 @@ def _get_dist_with_prompt_mask(
                 "get_dist_with_prompt_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
 
         # Try to get prompt tokens first
         if self.pad_output:
@@ -674,7 +674,7 @@ def _get_dist_with_assistant_mask(
                 "get_dist_with_assistant_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         # Update the tokens key to reflect the tokenized history when querying the log-probs
         tensordict.update(
             td_out,
@@ -743,7 +743,7 @@ def _get_dist_with_attention_mask(
                 "get_dist_with_attention_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         if self.pad_output:
             logits = td_out.get(logits_key)
             attention_mask = td_out.get(attention_mask_key)
@@ -800,7 +800,7 @@ def _get_dist_with_custom_mask(
                 "get_dist_with_custom_mask is not implemented for generate=True. "
                 "You can create a new version of this wrapper using the `get_new_version` method."
             )
-        td_out = self(tensordict.copy())
+        td_out = self.forward(tensordict.copy(), logits_only=True)
         if self.pad_output:
             logits = td_out.get(logits_key)
         else:
@@ -847,8 +847,24 @@ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribut
         """
         return self._get_dist_with_attention_mask(tensordict, **kwargs)
 
-    # Sampling is taken care of by the sub-modules
-    forward = TensorDictSequential.forward
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        *,
+        tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
+        **kwargs,
+    ) -> TensorDictBase:  # noqa: D417
+        """Forward pass for the LLM policy.
+
+        Args:
+            tensordict (TensorDictBase): The input tensordict.
+
+        Keyword Args:
+            tensordict_out (TensorDictBase | None): The output tensordict.
+            logits_only (bool): Whether to return only the logits. Only effective if generate=False. Defaults to `False`.
+        """
+        raise NotImplementedError
 
     def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
         """Check that a value is a padded tensor."""

torchrl/modules/llm/policies/transformers_wrapper.py

Lines changed: 84 additions & 47 deletions
@@ -469,15 +469,29 @@ def get_new_version(self, **kwargs):
     def forward(
         self,
         tensordict: TensorDictBase,
+        *,
         tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
         **kwargs,
     ) -> TensorDictBase:
         tensordict_orig = tensordict
         if not tensordict.ndim:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                    "please submit an issue on github."
+                )
             # unsqueeze - squeeze the input
-            return self(lazy_stack([tensordict]))[0]
+            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
         elif tensordict.ndim > 1:
-            return self(tensordict.reshape(-1)).view(tensordict.shape)
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                    "please submit an issue on github."
+                )
+            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                tensordict.shape
+            )
 
         if not isinstance(tensordict, LazyStackedTensorDict):
             tensordict = tensordict.to_lazystack(0)
@@ -517,17 +531,23 @@ def forward(
             if self.generate:
                 out = self._from_transformers_generate_history(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_history(tensordict, cfg, out)
+                out = self._from_transformers_logprobs_history(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
         elif self.input_mode == "text":
             if self.generate:
                 out = self._from_transformers_generate_text(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_text(tensordict, cfg, out)
+                out = self._from_transformers_logprobs_text(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
         elif self.input_mode == "tokens":
             if self.generate:
                 out = self._from_transformers_generate_tokens(tensordict, cfg, out)
             else:
-                out = self._from_transformers_logprobs_tokens(tensordict, cfg, out)
+                out = self._from_transformers_logprobs_tokens(
+                    tensordict, cfg, out, logits_only=logits_only
+                )
 
         if _source_device:
             out = out.to(_source_device)
@@ -690,7 +710,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
         result.set(self.history_key, history_chat)
         return result
 
-    def _from_transformers_logprobs_history(self, td, cfg, out):
+    def _from_transformers_logprobs_history(self, td, cfg, out, logits_only=False):
         """Compute log-probs from history input."""
         from torchrl.data.llm import History
 
@@ -731,7 +751,9 @@ def _from_transformers_logprobs_history(self, td, cfg, out):
             raise ValueError(
                 f"Expected TensorDictBase for history input, got {type(response_tokens)}"
             )
-        result = self._logprobs_from_history_tokens(response_tokens, cfg, out)
+        result = self._logprobs_from_history_tokens(
+            response_tokens, cfg, out, logits_only=logits_only
+        )
         text_result = Text._from_tensordict(result.empty())
         result.set(self.text_key, text_result)
         result[self.text_key, "full"] = text_full
@@ -952,7 +974,9 @@ def _cat_tensors(
             result = result.to(cast)
         return result
 
-    def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
+    def _logprobs_from_history_tokens(
+        self, response_tokens, cfg, out, logits_only=False
+    ):
         """Compute log-probs from history tokens."""
         pad_val = self.tokenizer.pad_token_id
 
@@ -996,6 +1020,7 @@ def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
             tokens_full_padded,
             attention_mask_full_padded,
             pad_val,
+            logits_only=logits_only,
         )
 
         # Build output TensorClass objects
@@ -1051,19 +1076,20 @@ def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
         tokens_obj.padded = MetaData(self.pad_output)
         out.set(self.tokens_key, tokens_obj)
 
-        log_probs_obj = LogProbs._from_tensordict(
-            TensorDict(batch_size=out.batch_size).to_lazystack(0)
-        )
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
             )
-            log_probs_obj.full = log_probs_full_unpadded
-        log_probs_obj.response = None
-        log_probs_obj.padded = MetaData(self.pad_output)
-        out.set(self.log_probs_key, log_probs_obj)
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
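
The `LogProbs` construction above is now gated behind `if not logits_only:`, so a `get_dist` call no longer pays for the log-softmax, gather, and unpadding over the full `[batch, seq, vocab]` logits. A rough back-of-the-envelope of what that skips, using hypothetical sizes:

```python
import torch  # only used to make the dtype size explicit

# Hypothetical workload: 8 sequences of 2048 tokens, 128k-token vocabulary, float32.
batch, seq_len, vocab = 8, 2048, 128_000
numel = batch * seq_len * vocab
gib = numel * torch.finfo(torch.float32).bits // 8 / 2**30
print(f"log_softmax input/output: ~{gib:.1f} GiB each")  # ~7.8 GiB
# Skipping the log-prob branch avoids allocating and traversing an extra
# vocabulary-sized tensor when only the logits are needed for the distribution.
```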
@@ -1095,7 +1121,7 @@ def _from_transformers_generate_text(self, td, cfg, out) -> TensorDictBase:
             raise ValueError(f"Expected list of text for text input, got {type(text)}")
         return self._generate_from_text(text, cfg, out)
 
-    def _from_transformers_logprobs_text(self, td, cfg, out):
+    def _from_transformers_logprobs_text(self, td, cfg, out, logits_only=False):
         """Compute log-probs from text input."""
         # Validate input
         if self.input_key not in td:
@@ -1168,6 +1194,7 @@ def _from_transformers_logprobs_text(self, td, cfg, out):
             input_ids_full_padded,
             attention_mask_full_padded,
             self.tokenizer.pad_token_id,
+            logits_only=logits_only,
         )
 
         # Build output TensorClass objects
@@ -1212,19 +1239,20 @@ def _from_transformers_logprobs_text(self, td, cfg, out):
         masks_obj.padded = MetaData(self.pad_output)
         out.set(self.masks_key, masks_obj)
 
-        log_probs_obj = LogProbs._from_tensordict(
-            TensorDict(batch_size=out.batch_size).to_lazystack(0)
-        )
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
             )
-            log_probs_obj.full = log_probs_full_unpadded
-        log_probs_obj.response = None
-        log_probs_obj.padded = MetaData(self.pad_output)
-        out.set(self.log_probs_key, log_probs_obj)
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
@@ -1416,7 +1444,11 @@ def _generate_from_tokens(
         return out
 
     def _from_transformers_logprobs_tokens(
-        self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
+        self,
+        td: TensorDictBase,
+        cfg: dict | None,
+        out: TensorDictBase,
+        logits_only=False,
     ) -> TensorDictBase:
         """Compute log-probs from tokens input."""
         # Validate input
@@ -1470,6 +1502,7 @@ def _from_transformers_logprobs_tokens(
             input_ids_full_padded,
             attention_mask_full_padded,
             self.tokenizer.pad_token_id,
+            logits_only=logits_only,
        )
 
         # Build output TensorClass objects
@@ -1514,19 +1547,20 @@ def _from_transformers_logprobs_tokens(
         masks_obj.padded = MetaData(self.pad_output)
         out.set(self.masks_key, masks_obj)
 
-        log_probs_obj = LogProbs._from_tensordict(
-            TensorDict(batch_size=out.batch_size).to_lazystack(0)
-        )
-        if self.pad_output:
-            log_probs_obj.full = log_probs_full_padded
-        else:
-            log_probs_full_unpadded = _unpad_tensors(
-                log_probs_full_padded, attention_mask_full_padded, as_nested=False
+        if not logits_only:
+            log_probs_obj = LogProbs._from_tensordict(
+                TensorDict(batch_size=out.batch_size).to_lazystack(0)
            )
-            log_probs_obj.full = log_probs_full_unpadded
-        log_probs_obj.response = None
-        log_probs_obj.padded = MetaData(self.pad_output)
-        out.set(self.log_probs_key, log_probs_obj)
+            if self.pad_output:
+                log_probs_obj.full = log_probs_full_padded
+            else:
+                log_probs_full_unpadded = _unpad_tensors(
+                    log_probs_full_padded, attention_mask_full_padded, as_nested=False
+                )
+                log_probs_obj.full = log_probs_full_unpadded
+            log_probs_obj.response = None
+            log_probs_obj.padded = MetaData(self.pad_output)
+            out.set(self.log_probs_key, log_probs_obj)
 
         # Add logits to output if we're in a get_dist call
         if self._in_get_dist_call:
@@ -1567,7 +1601,7 @@ def _log_probs_generate(cls, tokens, logits, pad_val=-100, pad: bool = True):
         return log_probs, logits
 
     def _compute_log_probs_from_model_output(
-        self, model_output, input_ids, attention_mask, pad_val
+        self, model_output, input_ids, attention_mask, pad_val, logits_only=False
     ):
         """Compute log-probs from model output without modifying original tensors.
 
@@ -1576,6 +1610,7 @@ def _compute_log_probs_from_model_output(
             input_ids: Original input token ids
             attention_mask: Original attention mask
             pad_val: Padding token value to ignore in loss computation
+            logits_only: Whether to return only the logits.
 
         Returns:
             tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
@@ -1600,6 +1635,8 @@ def _compute_log_probs_from_model_output(
             raise ValueError(
                 f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
             )
+        if logits_only:
+            return None, shifted_logits
 
         # Compute log-probs
         td = TensorDict(
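
The final hunks add the actual short-circuit: `_compute_log_probs_from_model_output` accepts `logits_only` and returns `(None, shifted_logits)` before any log-prob computation. Below is an illustrative, self-contained sketch of that kind of next-token log-prob routine with the same early exit; it is not torchrl's helper, and the function name and masking convention are assumptions made for the example:

```python
import torch

def compute_token_log_probs(logits, input_ids, pad_val, logits_only=False):
    """Sketch of a causal-LM log-prob computation with an early logits-only exit."""
    # Logits at position t predict the token at position t + 1.
    shifted_logits = logits[:, :-1, :]
    shifted_input_ids = input_ids[:, 1:]
    if logits_only:
        # Early return: the caller only needs logits to build a distribution.
        return None, shifted_logits
    log_probs = (
        shifted_logits.log_softmax(-1)
        .gather(-1, shifted_input_ids.unsqueeze(-1))
        .squeeze(-1)
    )
    # Zero out padding positions (illustrative convention).
    log_probs = log_probs.masked_fill(shifted_input_ids == pad_val, 0.0)
    return log_probs, shifted_logits

logits = torch.randn(2, 6, 50)
ids = torch.randint(1, 50, (2, 6))
log_probs, _ = compute_token_log_probs(logits, ids, pad_val=0)
print(log_probs.shape)  # torch.Size([2, 5])
```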

torchrl/modules/llm/policies/vllm_wrapper.py

Lines changed: 16 additions & 2 deletions
@@ -501,15 +501,29 @@ def get_new_version(self, **kwargs):
     def forward(
         self,
         tensordict: TensorDictBase,
+        *,
         tensordict_out: TensorDictBase | None = None,
+        logits_only: bool = False,
         **kwargs,
     ) -> TensorDictBase:
         tensordict_orig = tensordict
         if not tensordict.ndim:
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
+                    "please submit an issue on github."
+                )
             # unsqueeze - squeeze the input
-            return self(lazy_stack([tensordict]))[0]
+            return self.forward(lazy_stack([tensordict]), logits_only=logits_only)[0]
         elif tensordict.ndim > 1:
-            return self(tensordict.reshape(-1)).view(tensordict.shape)
+            if tensordict_out is not None:
+                raise ValueError(
+                    "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
+                    "please submit an issue on github."
+                )
+            return self.forward(tensordict.reshape(-1), logits_only=logits_only).view(
+                tensordict.shape
+            )
 
         if not isinstance(tensordict, LazyStackedTensorDict):
             tensordict = tensordict.to_lazystack(0)
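
The vLLM wrapper gets the same `forward` signature change as the transformers wrapper: `tensordict_out` and the new `logits_only` flag become keyword-only, `tensordict_out` is rejected in the reshaping branches, and `logits_only` is threaded through the recursive calls that normalize the batch shape. A toy sketch of that normalization pattern with plain tensors (`process_flat` and the dimension analogy are illustrative, not torchrl code):

```python
import torch

def process_flat(batch: torch.Tensor, *, logits_only: bool = False) -> torch.Tensor:
    # Hypothetical per-batch work standing in for the generation / log-prob paths;
    # it only supports a single batch dimension.
    return batch * 2 if logits_only else batch * 2 + 1

def forward_any_ndim(x: torch.Tensor, *, logits_only: bool = False) -> torch.Tensor:
    """Batch-shape handling in the spirit of the wrappers' forward():
    unbatched inputs are stacked and unstacked, multi-dim batches are flattened
    and reshaped back, and keyword options ride along through the recursion."""
    if x.ndim == 1:  # analogous to tensordict.ndim == 0 (a single, unbatched item)
        return forward_any_ndim(x.unsqueeze(0), logits_only=logits_only)[0]
    if x.ndim > 2:  # analogous to tensordict.ndim > 1
        flat = forward_any_ndim(x.reshape(-1, x.shape[-1]), logits_only=logits_only)
        return flat.view(x.shape)
    return process_flat(x, logits_only=logits_only)

print(forward_any_ndim(torch.ones(2, 3, 5), logits_only=True).shape)  # torch.Size([2, 3, 5])
```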
