diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index f08aab228..9296d3b95 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -8,13 +8,14 @@
 import logging
 import uuid
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable

 import torch
-import torch.nn.functional as F
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_model import HFReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import compute_logprobs, Episode
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
@@ -28,21 +29,6 @@
 logger.setLevel(logging.DEBUG)


-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-
-    # Truncate request logits and drop last
-    logits = logits[:, context_length - 1 : -1]
-
-    # Compute logprobs
-    logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-
-    return logprobs
-
-
 class SimpleGRPOLoss(nn.Module):
     """Simplified GRPO Loss for simplified single step updates
     Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
@@ -68,41 +54,6 @@ def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
         return loss


-@dataclass
-class Episode:
-    # TODO: add adtional layer for multi-turn
-    episode_id: str
-    request: str
-    policy_version: int
-    pad_id: int
-    request_len: int
-    response_len: int
-    target: Optional[Any] = None
-    # processed data
-    response: Optional[str] = None
-    request_tokens: Optional[list[int]] = None
-    response_tokens: Optional[list[int]] = None
-    ref_logprobs: Optional[torch.Tensor] = None
-    reward: Optional[float] = None
-    advantage: Optional[float] = None
-
-    @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
-        if tensor.shape[0] < self.request_len:  # left pad
-            diff = self.request_len - tensor.shape[0]
-            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
-        return tensor
-
-    @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
-        if tensor.shape[0] < self.response_len:  # right pad
-            diff = self.response_len - tensor.shape[0]
-            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
-        return tensor
-
-
 @dataclass
 class Group:
     group_id: str
@@ -250,38 +201,6 @@ async def compute(self, group: Group) -> list[float]:
         return x


-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, episode: Episode) -> torch.Tensor:
-        req, res = episode.request_tensor, episode.response_tensor
-        input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
-        mask = input_ids != episode.pad_id
-
-        with torch.inference_mode():
-            logits = self.model(input_ids=input_ids, attention_mask=mask).logits
-
-        input_ids = input_ids[:, len(req) :]
-        return compute_logprobs(logits, input_ids)
-
-
 @dataclass
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
@@ -387,7 +306,7 @@ async def main():
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-            RefModel,
+            HFReferenceModel,
             model_name=model,
         ),
         spawn_service(
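
The app-side change above is purely re-wiring: Episode and compute_logprobs are now imported from forge.actors.trainer, and the reference model service is spawned from forge.actors.reference_model.HFReferenceModel. A minimal sketch of the new wiring, assuming spawn_service is awaited the same way main() does above (the model name and service sizes are illustrative, not prescribed by this patch):

    from forge.actors.reference_model import HFReferenceModel
    from forge.controller.service import ServiceConfig, spawn_service

    async def spawn_ref_service(model: str):
        # One single-GPU replica serving reference logprobs, matching the hunk above.
        return await spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
            HFReferenceModel,
            model_name=model,
        )
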
diff --git a/src/forge/actors/__init__.py b/src/forge/actors/__init__.py
index 70198120b..54e450cd7 100644
--- a/src/forge/actors/__init__.py
+++ b/src/forge/actors/__init__.py
@@ -24,9 +24,9 @@ def __getattr__(name):
         from .replay_buffer import ReplayBuffer

         return ReplayBuffer
-    elif name == "TitanRefModel":
-        from .reference_actor import TitanRefModel
+    elif name == "ReferenceModel":
+        from .reference_model import ReferenceModel

-        return TitanRefModel
+        return ReferenceModel
     else:
         raise AttributeError(f"module {__name__} has no attribute {name}")
diff --git a/src/forge/actors/reference_actor.py b/src/forge/actors/reference_actor.py
deleted file mode 100644
index 28e0f9814..000000000
--- a/src/forge/actors/reference_actor.py
+++ /dev/null
@@ -1,377 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import asyncio
-import logging
-import math
-import os
-
-from collections import deque
-from collections.abc import Mapping
-from dataclasses import dataclass, field, fields
-
-from typing import Any
-
-import torch
-from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.config.job_config import Comm, Model, Parallelism
-from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.forge.engine import ForgeEngine
-from torchtitan.experiments.forge.job_config import ForgeJobConfig
-from transformers import AutoModelForCausalLM
-
-from forge.controller import ForgeActor
-
-
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-
-@dataclass
-class TitanRefModel(ForgeActor):
-    """
-    Represents a reference actor leveraging a torchtitan model for execution
-
-    Intended for generating reference_logprobs - for example in KL Divergence
-    """
-
-    # Refer to titan JobConfig for enabling more ForgeEngine configuration
-    model: Model = field(default_factory=Model)
-    parallelism: Parallelism = field(default_factory=Parallelism)
-
-    # Populated in setup
-    # TODO: Commented out since engine_config parsing extracts from class members
-    # engine: ForgeEngine | None = None
-
-    def __post_init__(self):
-        """Initializes config types and env variables."""
-        # Instantiate dict fields
-        for f in fields(self):
-            attr = getattr(self, f.name)
-            if isinstance(attr, Mapping):
-                setattr(self, f.name, f.type(**attr))
-            elif not isinstance(attr, f.type):
-                raise TypeError(
-                    f"{f.name} should be a {f.type} type or a dict like object"
-                )
-
-        """
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
- """ - self.rank = current_rank().rank - self.size = math.prod(current_size().values()) - - env = { - "RANK": str(self.rank), - "LOCAL_RANK": str(self.rank), - "LOCAL_WORLD_SIZE": str(self.size), - "GROUP_RANK": str(self.size), - "GROUP_WORLD_SIZE": str(self.size), - "ROLE_RANK": str(self.rank), - "ROLE_WORLD_SIZE": str(self.size), - "ROLE_NAME": "rank", - "WORLD_SIZE": str(self.size), - "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", - } - os.environ.update(env) - - @endpoint - async def setup(self): - engine_config = {f.name: getattr(self, f.name) for f in fields(self)} - self.engine = ForgeEngine(ForgeJobConfig(**engine_config)) - - @endpoint - async def forward(self, request: list[int], response: list[int]) -> torch.Tensor: - """ - Given a request and response tokens, return the log_probability of the - token_ids, shape (completion_len, ) - - """ - model_parts = self.engine.model_parts - parallel_dims = self.engine.parallel_dims - - # Use provided token_ids directly - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - input_ids = torch.tensor( - request + response, dtype=torch.long, device=device - ).unsqueeze(0) - - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None - ) - - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet") - else: - # (jackkhuu) Not sure if either context are needed for inference here - with self.engine.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 - with self.engine.maybe_enable_amp: - # Titan Tranformer - logits = model_parts[0](input_ids) - - # Compute logprobs - input_ids = input_ids[:, len(request) :] - # (bsz=1, completion_len) - logprobs = compute_logprobs(logits, input_ids) - # (completion_len, ) - return logprobs.squeeze(0) - - return pred - - -# Based on torchtune's grpo -def compute_logprobs( - logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 -) -> torch.Tensor: - """ - Compute log probs of the completion input_ids given the logits of the whole sequence. - Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts. - - Args: - logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model. - input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion. - - Returns: - torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens. - - Raises: - ValueError: If the inferred context length is less than or equal to 0. - """ - context_len = logits.shape[1] - input_ids.shape[1] - completion_len = input_ids.shape[1] - if context_len <= 0: - raise ValueError( - "Context length must be greater than 0. Otherwise the probability of the first token is undefined." 
-        )
-
-    # (bsz, completion_len, vocab_size)
-    logits = logits[:, context_len - 1 : -1, :]
-    assert logits.shape == (
-        input_ids.shape[0],
-        completion_len,
-        logits.shape[-1],
-    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
-    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
-    # (bsz, completion_len, 1)
-    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
-    # (bsz, completion_len)
-    logprobs = logprobs.squeeze(-1)
-
-    return logprobs
-
-
-# Maintained to keep Old GRPO app prior to full migration off of HF
-class HuggingFaceRefModel(ForgeActor):
-    """
-    Represents a reference actor leveraging HuggingFace for execution
-    """
-
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-    return sequence_log_probs
-
-
-"""
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-Experimental: DO NOT USE (YET)
-
-ReferenceActor: Coordinate requests to reference models
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-"""
-
-
-@dataclass
-class ReferenceActor(ForgeActor):
-    """
-    DO NOT USE (YET)
-
-    Not updated/used; Original plan was to use this for coordination, but
-    it might be overkil if we can rely on the Service Replicas to handle
-    the queue.
-    We MAY need to still do this for DP and batching support
-
-    For now if you think you need this: directly spin up services of the
-    reference models
-    """
-
-    model: Model = field(default_factory=Model)
-    # parallelism: Parallelism = field(default_factory=Parallelism)
-    # comm: Comm = field(default_factory=Comm)
-
-    # For RefModel
-    ref_model: ForgeActor | None = None
-    device: torch.device | None = None
-
-    # For processing
-    running: bool = False
-    queue: deque | None = None
-
-    def __post_init__(self):
-        """Initializes config types and env variables.
-
-        torchrun normally hands env variables, but we need to do it ourselves
-        in monarch for now.
-
-        """
-        # Instantiate dict fields
-        for f in fields(self):
-            attr = getattr(self, f.name)
-            if isinstance(attr, Mapping):
-                setattr(self, f.name, f.type(**attr))
-            elif not isinstance(attr, f.type):
-                raise TypeError(
-                    f"{f.name} should be a {f.type} type or a dict like object"
-                )
-
-        # This might need to be changed to a distributed friendly container
-        # We also don't have a traditional scheduler?
-        self.queue = deque()
-
-        self.rank = current_rank().rank
-        self.size = math.prod(current_size().values())
-
-        env = {
-            "RANK": str(self.rank),
-            "LOCAL_RANK": str(self.rank),
-            "LOCAL_WORLD_SIZE": str(self.size),
-            "GROUP_RANK": str(self.size),
-            "GROUP_WORLD_SIZE": str(self.size),
-            "ROLE_RANK": str(self.rank),
-            "ROLE_WORLD_SIZE": str(self.size),
-            "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self.size),
-            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
-        }
-        os.environ.update(env)
-
-    @endpoint
-    async def setup(self):
-        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
-        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
-
-        # Spawn the RefModel
-        self.ref_model = await spawn_service(
-            default_service_cfg,
-            HuggingFaceRefModel,
-            model_name=self.model.name,
-            device=self.device,
-        )
-
-        # Kick off background processing
-        self.start_processing()
-
-    def start_processing(self):
-        """Start the replica's processing loop if not already running."""
-        if self._run_task is None or self._run_task.done():
-            self._run_task = asyncio.create_task(self.run())
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        """
-        Enque the tokens and await response
-        """
-        fut = asyncio.Future()
-        self.queue.append((token_ids, fut))
-        return await fut
-
-    async def run(self):
-        """
-        Simple loop to pass things along to the ref model
-        """
-
-        # TODO: Consider creating a unified base class for this pattern
-        self.running = True
-
-        while self.running:
-            request, fut = self.queue.popleft()
-            model_output = await self.ref_model.forward(request)
-            fut.set_result(model_output)
-
-    @endpoint
-    async def stop(self) -> None:
-        self.running = False
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
new file mode 100644
index 000000000..adaa02520
--- /dev/null
+++ b/src/forge/actors/reference_model.py
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import asyncio
+import logging
+import math
+import os
+
+from collections import deque
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields
+
+from typing import Any
+
+import torch
+
+from forge.actors.trainer import compute_logprobs, Episode
+from forge.controller import ForgeActor
+from monarch.actor import current_rank, current_size, endpoint
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+
+from torchtitan.components.lr_scheduler import LRSchedulersContainer
+from torchtitan.config.job_config import Checkpoint, Comm, Compile, Model, Parallelism
+from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.experiments.forge.engine import ForgeEngine
+from torchtitan.experiments.forge.job_config import ForgeJobConfig
+from transformers import AutoModelForCausalLM
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+@dataclass
+class ReferenceModel(ForgeActor):
+    """
+    Represents a reference actor leveraging a torchtitan model for execution
+
+    Intended for generating reference_logprobs - for example in KL Divergence
+    """
+
+    # Refer to titan JobConfig for enabling more ForgeEngine configuration
+    model: Model = field(default_factory=Model)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+    compile: Compile = field(default_factory=Compile)
+
+    # Populated in setup
+    # TODO: Commented out since engine_config parsing extracts from class members
+    # engine: ForgeEngine | None = None
+
+    def __post_init__(self):
+        """Initializes config types and env variables."""
+        # Instantiate dict fields
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        """
+        torchrun normally hands env variables, but we need to do it ourselves
+        in monarch for now.
+ """ + self.rank = current_rank().rank + self.size = math.prod(current_size().values()) + + env = { + "RANK": str(self.rank), + "LOCAL_RANK": str(self.rank), + "LOCAL_WORLD_SIZE": str(self.size), + "GROUP_RANK": str(self.size), + "GROUP_WORLD_SIZE": str(self.size), + "ROLE_RANK": str(self.rank), + "ROLE_WORLD_SIZE": str(self.size), + "ROLE_NAME": "rank", + "WORLD_SIZE": str(self.size), + "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", + } + os.environ.update(env) + + @endpoint + async def setup(self): + engine_config = {f.name: getattr(self, f.name) for f in fields(self)} + self.engine = ForgeEngine(ForgeJobConfig(**engine_config)) + + @endpoint + async def forward(self, request: list[int], response: list[int]) -> torch.Tensor: + """ + Given a request and response tokens, return the log_probability of the + token_ids, shape (completion_len, ) + + """ + model_parts = self.engine.model_parts + parallel_dims = self.engine.parallel_dims + + # Use provided token_ids directly + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + input_ids = torch.tensor( + request + response, dtype=torch.long, device=device + ).unsqueeze(0) + + optional_context_parallel_ctx = ( + dist_utils.create_context_parallel_ctx( + cp_mesh=parallel_dims.world_mesh["cp"], + cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], + cp_seq_dims=[1, 1] + [0 for _ in model_parts], + cp_no_restore_buffers={inputs, labels}, + cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + ) + if parallel_dims.cp_enabled + else None + ) + + if parallel_dims.pp_enabled: + raise NotImplementedError("PP not implemented yet") + else: + # (jackkhuu) Not sure if either context are needed for inference here + with self.engine.train_context(optional_context_parallel_ctx): + assert len(model_parts) == 1 + with self.engine.maybe_enable_amp: + logits = model_parts[0](input_ids) + + # Compute logprobs + input_ids = input_ids[:, len(request) :] + # (bsz=1, completion_len) + logprobs = compute_logprobs(logits, input_ids) + # (completion_len, ) + return logprobs.squeeze(0) + + return pred + + +class HFReferenceModel(ForgeActor): + def __init__(self, model_name, device: torch.device | None = None): + super().__init__() + self.model_name = model_name + + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = device + + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + dtype=torch.bfloat16, + trust_remote_code=True, + ).to(self.device) + self.model.eval() + + self.logger.info(f"Model initialized on {self.device}") + + @endpoint + async def forward(self, episode: Episode) -> torch.Tensor: + req, res = episode.request_tensor, episode.response_tensor + input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0) + mask = input_ids != episode.pad_id + + with torch.inference_mode(): + logits = self.model(input_ids=input_ids, attention_mask=mask).logits + + input_ids = input_ids[:, len(req) :] + return compute_logprobs(logits, input_ids) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4232ca5ca..e416bdead 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -10,8 +10,12 @@ import os from collections.abc import Mapping from dataclasses import dataclass, field, fields +from typing import Any, Optional import torch +import torch.nn.functional as F + +from forge.controller import ForgeActor from monarch.actor import current_rank, current_size, endpoint from 
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 4232ca5ca..e416bdead 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -10,8 +10,12 @@
 import os
 from collections.abc import Mapping
 from dataclasses import dataclass, field, fields
+from typing import Any, Optional

 import torch
+import torch.nn.functional as F
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
@@ -30,12 +34,85 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig

-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)


+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    """
+    Compute log probs of the completion input_ids given the logits of the whole sequence.
+    Warning: only works if all prompts in the batch have the same length. TODO: support variable length prompts.
+
+    Args:
+        logits (torch.Tensor): (batch_size, seq_len, vocab_size), the logits output from the model.
+        input_ids (torch.Tensor): (batch_size, completion_len), the token ids for the completion.
+
+    Returns:
+        torch.Tensor: (batch_size, completion_len), the log probabilities of the completion tokens.
+
+    Raises:
+        ValueError: If the inferred context length is less than or equal to 0.
+    """
+    context_len = logits.shape[1] - input_ids.shape[1]
+    completion_len = input_ids.shape[1]
+    if context_len <= 0:
+        raise ValueError(
+            "Context length must be greater than 0. Otherwise the probability of the first token is undefined."
+        )
+
+    # (bsz, completion_len, vocab_size)
+    logits = logits[:, context_len - 1 : -1, :]
+    assert logits.shape == (
+        input_ids.shape[0],
+        completion_len,
+        logits.shape[-1],
+    ), f"logits shape incorrect, {logits.shape=}, {input_ids.shape=}, {logits.shape[-1]=}"
+    token_logprobs = torch.log_softmax(logits / temperature, dim=-1)
+    # (bsz, completion_len, 1)
+    logprobs = torch.gather(token_logprobs, 2, input_ids.unsqueeze(-1))
+    # (bsz, completion_len)
+    logprobs = logprobs.squeeze(-1)
+
+    return logprobs
+
+
+@dataclass
+class Episode:
+    # TODO: add adtional layer for multi-turn
+    episode_id: str
+    request: str
+    policy_version: int
+    pad_id: int
+    request_len: int
+    response_len: int
+    target: Optional[Any] = None
+    # processed data
+    response: Optional[str] = None
+    request_tokens: Optional[list[int]] = None
+    response_tokens: Optional[list[int]] = None
+    ref_logprobs: Optional[torch.Tensor] = None
+    reward: Optional[float] = None
+    advantage: Optional[float] = None
+
+    @property
+    def request_tensor(self):
+        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+        if tensor.shape[0] < self.request_len:  # left pad
+            diff = self.request_len - tensor.shape[0]
+            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+        return tensor
+
+    @property
+    def response_tensor(self):
+        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+        if tensor.shape[0] < self.response_len:  # right pad
+            diff = self.response_len - tensor.shape[0]
+            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+        return tensor
+
+
 @dataclass
 class RLTrainer(ForgeActor):
     model: Model = field(default_factory=Model)
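
Episode carries the fixed-length view of a rollout: request_tensor is left-padded to request_len and response_tensor is right-padded to response_len, so episodes in a group line up into rectangular batches. A quick sketch with made-up token ids and pad_id=0 (real values would come from the tokenizer and the app config):

    import torch

    from forge.actors.trainer import Episode

    ep = Episode(
        episode_id="demo-0",
        request="What is 2 + 2?",
        policy_version=0,
        pad_id=0,
        request_len=6,
        response_len=4,
        request_tokens=[11, 12, 13, 14],  # shorter than request_len -> left-padded
        response_tokens=[21, 22],         # shorter than response_len -> right-padded
    )

    print(ep.request_tensor)   # tensor([ 0,  0, 11, 12, 13, 14])
    print(ep.response_tensor)  # tensor([21, 22,  0,  0])

    # HFReferenceModel.forward builds its model inputs the same way:
    input_ids = torch.cat([ep.request_tensor, ep.response_tensor]).unsqueeze(0)
    attention_mask = input_ids != ep.pad_id
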