File tree Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change 2727 Checkpoint ,
2828 Comm ,
2929 Compile ,
30- Float8Linear ,
30+ Job ,
3131 LRScheduler ,
32+ MemoryEstimation ,
3233 Model ,
3334 Optimizer ,
3435 Parallelism ,
36+ Quantize ,
3537 Training ,
3638)
3739from torchtitan .experiments .forge .engine import ForgeEngine
@@ -93,6 +95,7 @@ def cleanup_old_weight_versions(
9395
9496@dataclass
9597class RLTrainer (ForgeActor ):
98+ job : Job = field (default_factory = Job )
9699 model : Model = field (default_factory = Model )
97100 optimizer : Optimizer = field (default_factory = Optimizer )
98101 lr_scheduler : LRScheduler = field (default_factory = LRScheduler )
@@ -102,15 +105,17 @@ class RLTrainer(ForgeActor):
102105 activation_checkpoint : ActivationCheckpoint = field (
103106 default_factory = ActivationCheckpoint
104107 )
105- use_vllm_builtin_load : bool = True
106108 compile : Compile = field (default_factory = Compile )
107- float8 : Float8Linear = field (default_factory = Float8Linear )
109+ quantize : Quantize = field (default_factory = Quantize )
108110 comm : Comm = field (default_factory = Comm )
111+ memory_estimation : MemoryEstimation = field (default_factory = MemoryEstimation )
112+ # Non JobConfig-related fields
109113 loss : Callable = lambda logits , ** targets : logits
110114 state_dict_key : str = "model_state_dict"
111115 use_dcp : bool = True
112116 dcp_path : str = "forge_dcp_tmp"
113117 vllm_tp_DEPRECATED : int = 1 # noqa: N815
118+ use_vllm_builtin_load : bool = True
114119
115120 def __post_init__ (self ):
116121 """Initializes config types and env variables.
You can’t perform that action at this time.
0 commit comments