Commit edb83cf

[ci] add the recipe/ directory to pre-commit hooks (#580)
1 parent 48ae2ee commit edb83cf

File tree: 3 files changed, +13 -17 lines changed

.pre-commit-config.yaml
recipe/AEnt/actor.py
recipe/AEnt/gsm8k_aent_grpo.py

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -43,11 +43,11 @@ repos:
         name: Run Linter Check (Ruff)
         types_or: [ python, pyi, jupyter ]
         args: [ --fix ]
-        files: ^(areal|examples|docs)/
+        files: ^(areal|examples|docs|recipe)/
       - id: ruff-format # Run the formatter.
         name: Run Formatter (Ruff)
         types_or: [ python, pyi, jupyter ]
-        files: ^(areal|examples|docs)/
+        files: ^(areal|examples|docs|recipe)/
 
   # Clean notebook outputs and metadata
   - repo: https://github.com/kynan/nbstripout
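
For context, pre-commit treats each hook's `files` value as a Python regular expression matched against repository-relative paths (search semantics), so adding `recipe` to the alternation is what brings the new directory under the Ruff hooks. A minimal sketch of the effect; the file paths below are illustrative, not taken from the repository:

import re

# pre-commit matches the `files` regex against repo-relative paths with
# re.search; with `recipe` in the alternation, paths under recipe/ now
# trigger the Ruff lint and format hooks as well.
pattern = re.compile(r"^(areal|examples|docs|recipe)/")

for path in ("recipe/AEnt/actor.py", "areal/utils/data.py", "tests/test_x.py"):
    print(path, "->", "checked" if pattern.search(path) else "skipped")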

recipe/AEnt/actor.py

Lines changed: 7 additions & 11 deletions
@@ -1,28 +1,25 @@
 import functools
-from typing import Dict, List, Any
+from typing import Any
 
 import torch
 
+from recipe.AEnt.aent_args import AEntPPOActorConfig
+from recipe.AEnt.functional import gather_logprobs_clamped_entropy
 
-from areal.api.cli_args import MicroBatchSpec, PPOActorConfig
+from areal.api.cli_args import MicroBatchSpec
 from areal.api.engine_api import TrainEngine
 from areal.engine.fsdp_engine import FSDPEngine
 from areal.engine.ppo.actor import PPOActor
 from areal.utils import stats_tracker
 from areal.utils.data import split_padded_tensor_dict_into_mb_list
 from areal.utils.functional import (
     dynamic_sampling,
-    gather_logprobs,
     gather_logprobs_entropy,
     ppo_actor_loss_fn,
-    reward_overlong_penalty,
 )
-from recipe.AEnt.aent_args import AEntPPOActorConfig
-from recipe.AEnt.functional import gather_logprobs_clamped_entropy
 
 
 class AEntPPOActor(PPOActor):
-
     def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine):
         super().__init__(config, engine)
         self.entropy_coeff = config.aent.entropy_coeff
@@ -39,7 +36,7 @@ def __init__(self, config: AEntPPOActorConfig, engine: TrainEngine):
     @stats_tracker.scope_func_wrapper("aent_ppo_actor")
     def aent_ppo_update(
         self, data: dict[str, Any], global_step: int
-    ) -> List[Dict[str, float]]:
+    ) -> list[dict[str, float]]:
         with stats_tracker.scope("dynamic_sampling"):
             if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
                 data, sampling_stat = dynamic_sampling(data, self.group_size)
@@ -156,7 +153,6 @@ def aent_ppo_update(
 
 
 class FSDPAEntPPOActor(FSDPEngine):
-
     def __init__(self, config: AEntPPOActorConfig):
         super().__init__(config)
         self.actor = AEntPPOActor(config, self)
@@ -169,14 +165,14 @@ def compute_logp(self, *args, **kwargs) -> torch.Tensor | None:
     def compute_advantages(self, *args, **kwargs) -> None:
         self.actor.compute_advantages(*args, **kwargs)
 
-    def aent_ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
+    def aent_ppo_update(self, *args, **kwargs) -> list[dict[str, float]]:
         return self.actor.aent_ppo_update(*args, **kwargs)
 
 
 # AEnt regularized grpo loss
 def aent_grpo_loss_fn(
     logits: torch.Tensor,
-    input_data: Dict,
+    input_data: dict,
     temperature: float,
     eps_clip: float,
     entropy_coeff: float,
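
Most of the churn above is mechanical: once recipe/ falls under the linter's `files` pattern, Ruff's autofixes drop the now-unused imports and rewrite typing.List / typing.Dict annotations to PEP 585 builtin generics. A minimal before/after sketch; the function name `update` is an illustrative stand-in, and rule codes such as UP006 and F401 are the usual sources of these fixes rather than something read from the project's Ruff configuration:

from typing import Any

# Before (what Ruff would flag once recipe/ is linted):
#   from typing import Dict, List
#   def update(data: Dict[str, Any]) -> List[Dict[str, float]]: ...

# After the autofix: PEP 585 builtin generics, no typing.List / typing.Dict.
def update(data: dict[str, Any]) -> list[dict[str, float]]:
    return [{"n_keys": float(len(data))}]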

recipe/AEnt/gsm8k_aent_grpo.py

Lines changed: 4 additions & 4 deletions
@@ -5,8 +5,11 @@
 import torch.distributed as dist
 from torchdata.stateful_dataloader import StatefulDataLoader
 
+from recipe.AEnt.actor import FSDPAEntPPOActor
+from recipe.AEnt.aent_args import AEntGRPOConfig
+
 from areal.api.alloc_mode import AllocationMode
-from areal.api.cli_args import GRPOConfig, load_expr_config
+from areal.api.cli_args import load_expr_config
 from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
 from areal.dataset import get_custom_dataset
 from areal.engine.ppo.actor import FSDPPPOActor
@@ -15,7 +18,6 @@
 from areal.utils import seeding, stats_tracker
 from areal.utils.data import (
     broadcast_tensor_container,
-    cycle_dataloader,
     tensor_container_to,
 )
 from areal.utils.device import log_gpu_stats
@@ -25,8 +27,6 @@
 from areal.utils.saver import Saver
 from areal.utils.stats_logger import StatsLogger
 from areal.workflow.rlvr import RLVRWorkflow
-from recipe.AEnt.actor import FSDPAEntPPOActor
-from recipe.AEnt.aent_args import AEntGRPOConfig
 
 
 def gsm8k_reward_fn(prompt, completions, prompt_ids, completion_ids, answer, **kwargs):
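
Dropping the GRPOConfig import is consistent with the recipe's entry point parsing its recipe-specific config class instead. A minimal sketch, assuming the usual AReaL pattern of handing the config dataclass to load_expr_config; the entry-point body is outside this diff, so the two-value return and the re-annotation shown here are assumptions based on that pattern, not lines from this commit:

import sys

from recipe.AEnt.aent_args import AEntGRPOConfig

from areal.api.cli_args import load_expr_config


def main(args):
    # Parse the recipe-specific config; the generic GRPOConfig is no longer
    # referenced anywhere in the script, so Ruff removes its import.
    config, _ = load_expr_config(args, AEntGRPOConfig)
    config: AEntGRPOConfig
    ...


if __name__ == "__main__":
    main(sys.argv[1:])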
