@@ -469,15 +469,29 @@ def get_new_version(self, **kwargs):
469469 def forward (
470470 self ,
471471 tensordict : TensorDictBase ,
472+ * ,
472473 tensordict_out : TensorDictBase | None = None ,
474+ logits_only : bool = False ,
473475 ** kwargs ,
474476 ) -> TensorDictBase :
475477 tensordict_orig = tensordict
476478 if not tensordict .ndim :
479+ if tensordict_out is not None :
480+ raise ValueError (
481+ "tensordict_out must not be provided when tensordict.ndim == 0. If this is needed, "
482+ "please submit an issue on github."
483+ )
477484 # unsqueeze - squeeze the input
478- return self (lazy_stack ([tensordict ]))[0 ]
485+ return self . forward (lazy_stack ([tensordict ]), logits_only = logits_only )[0 ]
479486 elif tensordict .ndim > 1 :
480- return self (tensordict .reshape (- 1 )).view (tensordict .shape )
487+ if tensordict_out is not None :
488+ raise ValueError (
489+ "tensordict_out must not be provided when tensordict.ndim > 1. If this is needed, "
490+ "please submit an issue on github."
491+ )
492+ return self .forward (tensordict .reshape (- 1 ), logits_only = logits_only ).view (
493+ tensordict .shape
494+ )
481495
482496 if not isinstance (tensordict , LazyStackedTensorDict ):
483497 tensordict = tensordict .to_lazystack (0 )
@@ -517,17 +531,23 @@ def forward(
517531 if self .generate :
518532 out = self ._from_transformers_generate_history (tensordict , cfg , out )
519533 else :
520- out = self ._from_transformers_logprobs_history (tensordict , cfg , out )
534+ out = self ._from_transformers_logprobs_history (
535+ tensordict , cfg , out , logits_only = logits_only
536+ )
521537 elif self .input_mode == "text" :
522538 if self .generate :
523539 out = self ._from_transformers_generate_text (tensordict , cfg , out )
524540 else :
525- out = self ._from_transformers_logprobs_text (tensordict , cfg , out )
541+ out = self ._from_transformers_logprobs_text (
542+ tensordict , cfg , out , logits_only = logits_only
543+ )
526544 elif self .input_mode == "tokens" :
527545 if self .generate :
528546 out = self ._from_transformers_generate_tokens (tensordict , cfg , out )
529547 else :
530- out = self ._from_transformers_logprobs_tokens (tensordict , cfg , out )
548+ out = self ._from_transformers_logprobs_tokens (
549+ tensordict , cfg , out , logits_only = logits_only
550+ )
531551
532552 if _source_device :
533553 out = out .to (_source_device )
@@ -690,7 +710,7 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
690710 result .set (self .history_key , history_chat )
691711 return result
692712
693- def _from_transformers_logprobs_history (self , td , cfg , out ):
713+ def _from_transformers_logprobs_history (self , td , cfg , out , logits_only = False ):
694714 """Compute log-probs from history input."""
695715 from torchrl .data .llm import History
696716
@@ -731,7 +751,9 @@ def _from_transformers_logprobs_history(self, td, cfg, out):
731751 raise ValueError (
732752 f"Expected TensorDictBase for history input, got { type (response_tokens )} "
733753 )
734- result = self ._logprobs_from_history_tokens (response_tokens , cfg , out )
754+ result = self ._logprobs_from_history_tokens (
755+ response_tokens , cfg , out , logits_only = logits_only
756+ )
735757 text_result = Text ._from_tensordict (result .empty ())
736758 result .set (self .text_key , text_result )
737759 result [self .text_key , "full" ] = text_full
@@ -952,7 +974,9 @@ def _cat_tensors(
952974 result = result .to (cast )
953975 return result
954976
955- def _logprobs_from_history_tokens (self , response_tokens , cfg , out ):
977+ def _logprobs_from_history_tokens (
978+ self , response_tokens , cfg , out , logits_only = False
979+ ):
956980 """Compute log-probs from history tokens."""
957981 pad_val = self .tokenizer .pad_token_id
958982
@@ -996,6 +1020,7 @@ def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
9961020 tokens_full_padded ,
9971021 attention_mask_full_padded ,
9981022 pad_val ,
1023+ logits_only = logits_only ,
9991024 )
10001025
10011026 # Build output TensorClass objects
@@ -1051,19 +1076,20 @@ def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
10511076 tokens_obj .padded = MetaData (self .pad_output )
10521077 out .set (self .tokens_key , tokens_obj )
10531078
1054- log_probs_obj = LogProbs ._from_tensordict (
1055- TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
1056- )
1057- if self .pad_output :
1058- log_probs_obj .full = log_probs_full_padded
1059- else :
1060- log_probs_full_unpadded = _unpad_tensors (
1061- log_probs_full_padded , attention_mask_full_padded , as_nested = False
1079+ if not logits_only :
1080+ log_probs_obj = LogProbs ._from_tensordict (
1081+ TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
10621082 )
1063- log_probs_obj .full = log_probs_full_unpadded
1064- log_probs_obj .response = None
1065- log_probs_obj .padded = MetaData (self .pad_output )
1066- out .set (self .log_probs_key , log_probs_obj )
1083+ if self .pad_output :
1084+ log_probs_obj .full = log_probs_full_padded
1085+ else :
1086+ log_probs_full_unpadded = _unpad_tensors (
1087+ log_probs_full_padded , attention_mask_full_padded , as_nested = False
1088+ )
1089+ log_probs_obj .full = log_probs_full_unpadded
1090+ log_probs_obj .response = None
1091+ log_probs_obj .padded = MetaData (self .pad_output )
1092+ out .set (self .log_probs_key , log_probs_obj )
10671093
10681094 # Add logits to output if we're in a get_dist call
10691095 if self ._in_get_dist_call :
@@ -1095,7 +1121,7 @@ def _from_transformers_generate_text(self, td, cfg, out) -> TensorDictBase:
10951121 raise ValueError (f"Expected list of text for text input, got { type (text )} " )
10961122 return self ._generate_from_text (text , cfg , out )
10971123
1098- def _from_transformers_logprobs_text (self , td , cfg , out ):
1124+ def _from_transformers_logprobs_text (self , td , cfg , out , logits_only = False ):
10991125 """Compute log-probs from text input."""
11001126 # Validate input
11011127 if self .input_key not in td :
@@ -1168,6 +1194,7 @@ def _from_transformers_logprobs_text(self, td, cfg, out):
11681194 input_ids_full_padded ,
11691195 attention_mask_full_padded ,
11701196 self .tokenizer .pad_token_id ,
1197+ logits_only = logits_only ,
11711198 )
11721199
11731200 # Build output TensorClass objects
@@ -1212,19 +1239,20 @@ def _from_transformers_logprobs_text(self, td, cfg, out):
12121239 masks_obj .padded = MetaData (self .pad_output )
12131240 out .set (self .masks_key , masks_obj )
12141241
1215- log_probs_obj = LogProbs ._from_tensordict (
1216- TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
1217- )
1218- if self .pad_output :
1219- log_probs_obj .full = log_probs_full_padded
1220- else :
1221- log_probs_full_unpadded = _unpad_tensors (
1222- log_probs_full_padded , attention_mask_full_padded , as_nested = False
1242+ if not logits_only :
1243+ log_probs_obj = LogProbs ._from_tensordict (
1244+ TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
12231245 )
1224- log_probs_obj .full = log_probs_full_unpadded
1225- log_probs_obj .response = None
1226- log_probs_obj .padded = MetaData (self .pad_output )
1227- out .set (self .log_probs_key , log_probs_obj )
1246+ if self .pad_output :
1247+ log_probs_obj .full = log_probs_full_padded
1248+ else :
1249+ log_probs_full_unpadded = _unpad_tensors (
1250+ log_probs_full_padded , attention_mask_full_padded , as_nested = False
1251+ )
1252+ log_probs_obj .full = log_probs_full_unpadded
1253+ log_probs_obj .response = None
1254+ log_probs_obj .padded = MetaData (self .pad_output )
1255+ out .set (self .log_probs_key , log_probs_obj )
12281256
12291257 # Add logits to output if we're in a get_dist call
12301258 if self ._in_get_dist_call :
@@ -1416,7 +1444,11 @@ def _generate_from_tokens(
14161444 return out
14171445
14181446 def _from_transformers_logprobs_tokens (
1419- self , td : TensorDictBase , cfg : dict | None , out : TensorDictBase
1447+ self ,
1448+ td : TensorDictBase ,
1449+ cfg : dict | None ,
1450+ out : TensorDictBase ,
1451+ logits_only = False ,
14201452 ) -> TensorDictBase :
14211453 """Compute log-probs from tokens input."""
14221454 # Validate input
@@ -1470,6 +1502,7 @@ def _from_transformers_logprobs_tokens(
14701502 input_ids_full_padded ,
14711503 attention_mask_full_padded ,
14721504 self .tokenizer .pad_token_id ,
1505+ logits_only = logits_only ,
14731506 )
14741507
14751508 # Build output TensorClass objects
@@ -1514,19 +1547,20 @@ def _from_transformers_logprobs_tokens(
15141547 masks_obj .padded = MetaData (self .pad_output )
15151548 out .set (self .masks_key , masks_obj )
15161549
1517- log_probs_obj = LogProbs ._from_tensordict (
1518- TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
1519- )
1520- if self .pad_output :
1521- log_probs_obj .full = log_probs_full_padded
1522- else :
1523- log_probs_full_unpadded = _unpad_tensors (
1524- log_probs_full_padded , attention_mask_full_padded , as_nested = False
1550+ if not logits_only :
1551+ log_probs_obj = LogProbs ._from_tensordict (
1552+ TensorDict (batch_size = out .batch_size ).to_lazystack (0 )
15251553 )
1526- log_probs_obj .full = log_probs_full_unpadded
1527- log_probs_obj .response = None
1528- log_probs_obj .padded = MetaData (self .pad_output )
1529- out .set (self .log_probs_key , log_probs_obj )
1554+ if self .pad_output :
1555+ log_probs_obj .full = log_probs_full_padded
1556+ else :
1557+ log_probs_full_unpadded = _unpad_tensors (
1558+ log_probs_full_padded , attention_mask_full_padded , as_nested = False
1559+ )
1560+ log_probs_obj .full = log_probs_full_unpadded
1561+ log_probs_obj .response = None
1562+ log_probs_obj .padded = MetaData (self .pad_output )
1563+ out .set (self .log_probs_key , log_probs_obj )
15301564
15311565 # Add logits to output if we're in a get_dist call
15321566 if self ._in_get_dist_call :
@@ -1567,7 +1601,7 @@ def _log_probs_generate(cls, tokens, logits, pad_val=-100, pad: bool = True):
15671601 return log_probs , logits
15681602
15691603 def _compute_log_probs_from_model_output (
1570- self , model_output , input_ids , attention_mask , pad_val
1604+ self , model_output , input_ids , attention_mask , pad_val , logits_only = False
15711605 ):
15721606 """Compute log-probs from model output without modifying original tensors.
15731607
@@ -1576,6 +1610,7 @@ def _compute_log_probs_from_model_output(
15761610 input_ids: Original input token ids
15771611 attention_mask: Original attention mask
15781612 pad_val: Padding token value to ignore in loss computation
1613+ logits_only: Whether to return only the logits.
15791614
15801615 Returns:
15811616 tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
@@ -1600,6 +1635,8 @@ def _compute_log_probs_from_model_output(
16001635 raise ValueError (
16011636 f"The logits shape { shifted_logits .shape } does not match the input ids shape { shifted_input_ids .shape } "
16021637 )
1638+ if logits_only :
1639+ return None , shifted_logits
16031640
16041641 # Compute log-probs
16051642 td = TensorDict (
0 commit comments