
Conversation

@Jack-Khuu Jack-Khuu commented Aug 29, 2025

Replaces the existing HF-based RefModel with a torchtitan-backed TitanRefModel.

Note: This PR keeps the old implementation using AutoModelForCausalLM until GRPO is ready for full migration to torchtitan.


python apps/grpo/main.py

Example Output

Generated 60 rollouts w/ average reward 0.0
Generated 70 rollouts w/ average reward 0.1
Generated 80 rollouts w/ average reward 0.1
Completed 20 training steps
Latest loss: 75.34173269197345
Generated 90 rollouts w/ average reward 0.1
Generated 100 rollouts w/ average reward 0.0
Generated 110 rollouts w/ average reward 0.1
Generated 120 rollouts w/ average reward 0.1
Completed 30 training steps
Latest loss: 32.164904564619064
Generated 130 rollouts w/ average reward 0.0
Generated 140 rollouts w/ average reward 0.1
Generated 150 rollouts w/ average reward 0.1
Generated 160 rollouts w/ average reward 0.1
Completed 40 training steps
Latest loss: 18.617477774620056
Generated 170 rollouts w/ average reward 0.1
Generated 180 rollouts w/ average reward 0.1
Generated 190 rollouts w/ average reward 0.1
Generated 200 rollouts w/ average reward 0.1
Completed 50 training steps
Latest loss: 9.065430700778961

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 29, 2025
@Jack-Khuu Jack-Khuu requested a review from pbontrager August 29, 2025 00:44
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
Contributor:
We can keep things bs=1 for now, but we're going to need to add our own batching solution where we take n requests off the queue at a time and process them as a batch.

Contributor Author:

The batching itself is an easy add using the queue in ReferenceActor (the one in this PR that's unused), or we can surface it as a common API in Replica (just processes_single_request at the moment).
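The batching approach described above (take n requests off the queue at a time and process them as one batch) can be sketched with a plain asyncio queue. This is an illustrative sketch only; `batch_worker`, `process_batch`, and `max_batch` are hypothetical names, not APIs from the PR.

```python
import asyncio


async def batch_worker(queue: asyncio.Queue, process_batch, max_batch: int = 8):
    # Illustrative batching loop: block for the first request, then
    # opportunistically drain up to max_batch - 1 more without waiting.
    while True:
        token_ids, fut = await queue.get()
        batch = [(token_ids, fut)]
        while len(batch) < max_batch:
            try:
                batch.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        inputs = [req for req, _ in batch]
        # In the real actor this would be one padded forward pass.
        outputs = process_batch(inputs)
        for (_, fut), out in zip(batch, outputs):
            fut.set_result(out)
```

The single-request path (bs=1) falls out as the degenerate case where the queue is drained empty after the first item.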

@Jack-Khuu Jack-Khuu changed the title [WIP] Creating ReferenceActor [WIP] Adds TitanRefModel in place of HF based Reference Model Sep 1, 2025
if request_output.finished:
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)
fut.set_result(request_output)
Contributor Author:

Adopted from #97

Member:
Why this instead of raw outputs?

Contributor Author:

Pragmatically: fewer merge conflicts with Philip's PR.

I don't have a strong preference, but it does make the output self-contained, which is nice when we need to pass the results around.
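The trade-off discussed here (resolving the future with the whole `request_output` instead of just `request_output.outputs`) can be sketched as follows. The `RequestOutput` dataclass below is a hypothetical stand-in for the engine's per-request result object, not the actual class.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class RequestOutput:
    # Hypothetical stand-in for the engine's per-request result object.
    request_id: int
    outputs: list
    finished: bool = True


async def main():
    loop = asyncio.get_running_loop()
    requests = {}  # request_id -> (prompt, future)
    fut = loop.create_future()
    requests[0] = ("prompt", fut)

    request_output = RequestOutput(request_id=0, outputs=[1, 2, 3])
    if request_output.finished:
        _, f = requests.pop(request_output.request_id)
        # Resolving with the whole object keeps request_id and any other
        # metadata attached, rather than handing back the bare outputs list.
        f.set_result(request_output)

    result = await fut
    return result.request_id, result.outputs
```

With the raw-outputs variant, the caller would get only the list and lose the request metadata, which is the "self-contained" point made above.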

@Jack-Khuu Jack-Khuu changed the title [WIP] Adds TitanRefModel in place of HF based Reference Model Adds TitanRefModel in place of HF based Reference Model Sep 1, 2025
@Jack-Khuu Jack-Khuu marked this pull request as ready for review September 1, 2025 20:26
from .replay_buffer import ReplayBuffer

return ReplayBuffer
elif name == "TitanRefModel":
Member:
🙃

if request_output.finished:
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)
fut.set_result(request_output)
Member:
Why this instead of raw outputs?

@Jack-Khuu Jack-Khuu merged commit 46deb59 into main Sep 1, 2025
4 checks passed
@pbontrager pbontrager left a comment


Post-land review



@dataclass
class TitanRefModel(ForgeActor):
Contributor:

On naming: I would just call this ReferenceModel; we don't use Titan anywhere else and we haven't generalized it yet.


# Refer to titan JobConfig for enabling more ForgeEngine configuration
model: Model = field(default_factory=Model)
parallelism: Parallelism = field(default_factory=Parallelism)
Contributor:

I feel like we would need Checkpoint and Compile too. Or is checkpoint not needed for loading?

Contributor Author:

Not needed by default: checkpointing is used when we want to write out or load from a checkpoint.

The former doesn't happen with the reference model; the latter does unlock custom ref models (or using checkpoints as ref models).

if parallel_dims.pp_enabled:
raise NotImplementedError("PP not implemented yet")
else:
    # (jackkhuu) Not sure if either context is needed for inference here
Contributor:

Can you separate out the request call from forward? Also look at the updated RefActor in main to make sure they're in sync

return sequence_log_probs


"""
Contributor:

Can we comment out everything below this?

) # Remove batch dimension for single response


def compute_sequence_logprobs(
Contributor:

You should update this to match the one from main. This should probably go in the trainer.py file

Contributor Author:

This was copy-pasted from main (at the time, Joe's).

compute_logprobs in the file is the implementation from #97.
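For readers without the diff at hand, the core of a sequence log-prob computation like the `compute_sequence_logprobs` / `compute_logprobs` functions discussed here is: log-softmax over the vocabulary at each position, gather the log-prob of the target token, and sum over the sequence. A minimal pure-Python sketch (the real implementations operate on torch tensors; this function name and signature are illustrative):

```python
import math


def sequence_logprob(logits, target_ids):
    # logits: list of per-position vocab score lists; target_ids: one token
    # id per position. Returns the summed log-probability of the targets
    # under the softmax of the logits, using a numerically stable log-softmax.
    total = 0.0
    for step_logits, tgt in zip(logits, target_ids):
        m = max(step_logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in step_logits))
        total += step_logits[tgt] - log_z
    return total
```

In torch this collapses to `log_softmax` plus a `gather` along the vocab dimension, which is presumably why both the trainer and the reference model want to share one implementation.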

@Jack-Khuu (Contributor Author):

Follow up changes in #118

