
Commit 0d51f65

Fix micro_batch_size (#19)
1 parent 5d0389d commit 0d51f65

3 files changed, 10 additions & 2 deletions

.github/workflows/unittest.yaml

Lines changed: 5 additions & 1 deletion
@@ -24,19 +24,23 @@ jobs:
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
           export UID
-          export GID
+          export GID=$(id -g)
           docker compose up -d
           sleep 15s

       - name: Check ray status
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
+          export UID
+          export GID=$(id -g)
           docker compose exec trinity-node-1 ray status
           docker compose exec trinity-node-2 ray status

       - name: Run unittest
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
+          export UID
+          export GID=$(id -g)
           docker compose exec trinity-node-1 pytest tests --ignore=tests/data --ctrf report.json
         continue-on-error: true
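
The change above replaces a bare export GID with export GID=$(id -g). A brief sketch of the likely reasoning (an assumption, since the commit message does not spell it out): bash defines UID automatically but has no built-in GID variable, so export GID on its own hands docker compose an empty value, whereas $(id -g) supplies the runner's actual group id. For illustration only, the Python equivalent of the two values the workflow now exports:

import os

# Hypothetical illustration, not part of the repository: these mirror what
# `export UID` and `export GID=$(id -g)` expose to docker compose on a Linux runner.
print(f"UID={os.getuid()}")  # bash already defines $UID with this value
print(f"GID={os.getgid()}")  # bash defines no $GID; `id -g` normally prints this value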

trinity/buffer/reader/sql_reader.py

Lines changed: 1 addition & 1 deletion
@@ -74,6 +74,6 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
         exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
         logger.info(f"get {len(exp_list)} experiences:")
         logger.info(f"reward = {[exp.reward for exp in exp_list]}")
-        logger.info(f"fisrt prompt_text = {exp_list[0].prompt_text}")
+        logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
         logger.info(f"first response_text = {exp_list[0].response_text}")
         return exp_list
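
For context, read() returns the experiences as a list of objects built via table_model_cls.to_experience, so the fields named in the corrected log lines can be inspected directly. A hypothetical usage sketch (the reader variable and its construction are assumed, not shown in this diff):

# Assumed usage of an already-constructed SQL reader; only the attribute names
# (reward, prompt_text, response_text) come from the diff above.
exps = sql_reader.read()                  # list of experience objects
rewards = [exp.reward for exp in exps]    # same values the "reward = ..." log line prints
print(exps[0].prompt_text)                # the field whose log label the commit fixes
print(exps[0].response_text)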

trinity/common/verl_config.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,7 @@ class Checkpoint:
 class Actor:
     strategy: str = "fsdp"
     ppo_mini_batch_size: int = 256
+    ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: bool = False
     ppo_max_token_len_per_gpu: int = (
@@ -94,6 +95,7 @@ class Actor:
 @dataclass
 class Ref:
     fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
+    log_prob_micro_batch_size: Optional[int] = None
     log_prob_micro_batch_size_per_gpu: int = 1
     log_prob_use_dynamic_bsz: bool = False
     log_prob_max_token_len_per_gpu: int = 0
@@ -119,6 +121,7 @@ class Rollout:
     max_num_batched_tokens: int = 8192
     max_model_len: Optional[int] = None
     max_num_seqs: int = 1024
+    log_prob_micro_batch_size: Optional[int] = None
     log_prob_micro_batch_size_per_gpu: int = 1
     log_prob_use_dynamic_bsz: bool = False
     log_prob_max_token_len_per_gpu: int = 0
@@ -155,6 +158,7 @@ class Critic:
     optim: Optim = field(default_factory=Optim)
     model: CriticModel = field(default_factory=CriticModel)
     ppo_mini_batch_size: int = 0
+    ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
     forward_micro_batch_size: Optional[int] = None
     forward_micro_batch_size_per_gpu: Optional[int] = None
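
Taken together, the four hunks apply one pattern: each config section gains an Optional total micro-batch-size field next to its existing per-GPU field, defaulting to None. A condensed sketch of that pattern (field names and defaults are taken from the diff; the decorator and imports are assumed from context):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Actor:
    strategy: str = "fsdp"
    ppo_mini_batch_size: int = 256
    ppo_micro_batch_size: Optional[int] = None  # new: total micro batch size, unset by default
    ppo_micro_batch_size_per_gpu: int = 1       # existing per-GPU setting stays the default path
    use_dynamic_bsz: bool = False

Defaulting the new fields to None rather than an integer presumably lets downstream code distinguish "not configured" from an explicit value and keep relying on the per-GPU settings unless a total micro batch size is given.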
