Add basic API docs #336
Conversation
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@           Coverage Diff            @@
##             main     #336   +/-   ##
=========================================
  Coverage        ?   64.72%
=========================================
  Files           ?       79
  Lines           ?     7707
  Branches        ?        0
=========================================
  Hits            ?     4988
  Misses          ?     2719
  Partials        ?        0
looks good! I just had some nits and minor suggestions for filling in docstrings that we didn't include before.
src/forge/actors/reference_model.py
@dataclass
class ReferenceModel(ForgeActor):
    """
    Reference model implementation for the TorchForge service.
suggested docstring:
"""A reference model actor for reinforcement learning (RL) training.
Based on TorchTitan's engine architecture, this actor provides a frozen model that only
runs forward passes without gradient computation. It is typically used to maintain
algorithmic consistency in policy optimization methods such as GRPO (Group Relative
Policy Optimization) or PPO (Proximal Policy Optimization), where it serves as a
fixed reference point to compute KL divergence penalties against the training policy.
The reference model is loaded from a checkpoint and runs in evaluation mode with
inference_mode enabled to optimize memory and compute efficiency.
Attributes:
model (Model): Model configuration (architecture, vocab size, etc.)
parallelism (Parallelism): Parallelism strategy configuration (TP, PP, CP, DP)
checkpoint (Checkpoint): Checkpoint loading configuration
compile (Compile): Torch compilation settings
comm (Comm): Communication backend configuration
training (Training): Training-related settings (dtype, garbage collection, etc.)
"""
src/forge/actors/trainer.py
@dataclass
class RLTrainer(ForgeActor):
    """
    RL Trainer implementation for the TorchForge service.
suggested docstring:
"""A reinforcement learning trainer actor for policy optimization training.
Built on top of TorchTitan's training engine, this actor provides a complete training
loop for reinforcement learning. It performs forward and backward passes with
gradient computation, optimization steps, and checkpoint management.
Unlike the ReferenceModel actor which only runs forward passes, RLTrainer actively
updates the policy model parameters through gradient descent.
The trainer supports the same distributed distributed training strategies that TorchTitan
does, including but not limited to, tensor parallelism, data parallelism, and
FSDP (Fully Sharded Data Parallel).
It is typically used in conjunction with ReferenceModel for policy optimization algorithms
like GRPO (Group Relative Policy Optimization), where it optimizes the policy against a
loss that includes KL divergence penalties from the reference model.
The trainer handles:
- Forward and backward propagation with automatic mixed precision (AMP)
- Optimizer steps with learning rate scheduling
- Distributed checkpoint saving and loading
- Weight synchronization via torchstore for distributed inference
- Memory management with garbage collection
Attributes:
job (Job): Job configuration (name, dump path, etc.)
model (Model): Model configuration (architecture, vocab size, etc.)
optimizer (Optimizer): Optimizer configuration (type, learning rate, etc.)
lr_scheduler (LRScheduler): Learning rate scheduler configuration
training (Training): Training settings (steps, batch size, dtype, etc.)
parallelism (Parallelism): Parallelism strategy configuration (TP, PP, CP, DP)
checkpoint (Checkpoint): Checkpoint loading and saving configuration
activation_checkpoint (ActivationCheckpoint): Activation checkpointing settings
compile (Compile): Torch compilation settings
quantize (Quantize): Quantization settings
comm (Comm): Communication backend configuration
memory_estimation (MemoryEstimation): Memory profiling configuration
loss (Callable): Loss function to compute training loss from logits and targets
state_dict_key (str): Key for state dict storage in torchstore
use_dcp (bool): Whether to use distributed checkpoint (DCP) format
dcp_path (str): Path for DCP storage
"""
Co-authored-by: Allen Wang <[email protected]>
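Purely for illustration (plain PyTorch with hypothetical batch keys and a made-up kl_coeff, not the RLTrainer API), this is the general shape of a policy-optimization step whose loss includes a KL penalty from the reference model:

```python
import torch
import torch.nn.functional as F

def training_step(policy, optimizer, batch, ref_logprobs, kl_coeff=0.1):
    # Forward pass with gradients on the trainable policy.
    logits = policy(batch["input_ids"])
    policy_logprobs = F.log_softmax(logits, dim=-1)
    # Advantage-weighted log-prob objective on the sampled tokens.
    token_logprobs = policy_logprobs.gather(-1, batch["actions"].unsqueeze(-1)).squeeze(-1)
    pg_loss = -(batch["advantages"] * token_logprobs).mean()
    # KL penalty keeps the policy close to the frozen reference model.
    kl = (policy_logprobs.exp() * (policy_logprobs - ref_logprobs)).sum(-1).mean()
    loss = pg_loss + kl_coeff * kl
    # Backward pass and optimizer step.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```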
- Installs `torchmonarch` instead of a custom monarch wheel, and sets up library and CUDA paths to ensure all dependencies (especially native ones) are available.
- The `ForgeActor` class now includes detailed docstrings for its resource attributes (`procs`, `hosts`, `with_gpus`, `num_replicas`, `mesh_name`), and the `options` method provides example usage in its docstring; see the illustrative sketch after this list. This makes it easier for users to understand resource configuration for distributed training.
- Documents the `Service.stop` method.
- Docstrings from `Service` methods are copied to the corresponding `ServiceActor` methods, ensuring complete documentation for Sphinx autodoc.
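As a rough sketch of what that resource configuration might look like: the keyword names below come from the documented `ForgeActor` attributes, but the exact signature, defaults, and return value of `options` are assumptions here, not the verified API.

```python
# Hypothetical usage sketch -- import path inferred from src/forge/actors/trainer.py;
# the signature and return value of `options` are assumptions, not verified API.
from forge.actors.trainer import RLTrainer

configured_trainer = RLTrainer.options(
    procs=8,              # processes per replica
    hosts=1,              # number of hosts to schedule on
    with_gpus=True,       # request GPU-backed processes
    num_replicas=1,       # number of replicas of the actor/service
    mesh_name="trainer",  # label for the underlying proc mesh
)
```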