Skip to content

Commit 87395aa

Browse files
committed
More lint
1 parent d3b2a12 commit 87395aa

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212

1313
import torch
1414
import torch.nn.functional as F
15-
from torchtitan.config.job_config import Model as TitanJobModelConfig
1615
from datasets import load_dataset
1716
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
18-
# from forge.actors.reference_model import ReferenceModel
17+
from forge.actors.reference_model import ReferenceModel # noqa: F401
1918
from forge.actors.replay_buffer import ReplayBuffer
2019
from forge.controller.actor import ForgeActor
2120
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
2221
from forge.data.rewards import MathReward, ThinkingReward
2322
from forge.util.metric_logging import get_metric_logger
2423
from monarch.actor import endpoint
2524
from torch import nn
25+
from torchtitan.config.job_config import Model as TitanJobModelConfig
2626
from transformers import AutoModelForCausalLM
2727
from vllm.transformers_utils.tokenizer import get_tokenizer
2828

src/forge/actors/reference_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from monarch.actor import current_rank, current_size, endpoint
1717

18-
from torchtitan.config.job_config import Compile, Checkpoint, Model, Parallelism
18+
from torchtitan.config.job_config import Checkpoint, Compile, Model, Parallelism
1919
from torchtitan.distributed import utils as dist_utils
2020
from torchtitan.experiments.forge.engine import ForgeEngine
2121
from torchtitan.experiments.forge.job_config import ForgeJobConfig
@@ -84,7 +84,7 @@ async def setup(self):
8484
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
8585

8686
@endpoint
87-
async def forward(self, episode: 'Episode') -> torch.Tensor:
87+
async def forward(self, episode: "Episode") -> torch.Tensor:
8888
"""
8989
Given an episode, return the log_probability of the
9090
token_ids, shape (completion_len, )

0 commit comments

Comments
 (0)