Skip to content

Commit eb6b5dd

Browse files
authored
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn" — this reverts commit d06042b.
* Revert "upgrade reward math verification" — this reverts commit a6085ff.
* Revert "fix bug" — this reverts commit 01640eb.
* Revert "reuse comm-group" — this reverts commit bd61918.
* Revert "Support evaluation during training" — this reverts commit 57a8839.
1 parent 17928ad commit eb6b5dd

File tree

9 files changed

+82
-307
lines changed

9 files changed

+82
-307
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,3 @@ applications/ColossalChat/logs
165165
applications/ColossalChat/tests/logs
166166
applications/ColossalChat/wandb
167167
applications/ColossalChat/model
168-
applications/ColossalChat/eval

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def __init__(
3636
minibatch_size: int = 1,
3737
save_interval: int = 100,
3838
save_dir: str = "./model",
39-
eval_interval: int = -1,
4039
):
4140
self.num_producers = num_producers
4241
self.num_episodes = num_episodes
@@ -52,7 +51,6 @@ def __init__(
5251
self.save_dir = save_dir
5352
assert batch_size % minibatch_size == 0, "batch_size should be divisible by microbatch_size"
5453
self.num_microbatches = batch_size // minibatch_size
55-
self.eval_interval = eval_interval
5654

5755
self.model_config = model_config
5856
self.plugin_config = plugin_config
@@ -95,6 +93,7 @@ def setup(self) -> None:
9593
cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
9694

9795
self.buffer = []
96+
9897
self.recv_cnt = 0
9998

10099
def state_dict(self) -> Dict[str, torch.Tensor]:
@@ -111,27 +110,6 @@ def loop(self) -> None:
111110
with tqdm(range(self.num_update_per_episode), desc=f"Episode {episode}", disable=self.rank != 0) as pbar:
112111
for step in pbar:
113112
i = 0
114-
if self.eval_interval > 0 and step % self.eval_interval == 0:
115-
eval_statistics = None
116-
eval_global_step = None
117-
for r in range(self.num_producers):
118-
print(f"[T{dist.get_rank()}] Recv eval result episode {episode} step {step} from {r}")
119-
local_eval_result = ray_broadcast_tensor_dict(
120-
None, src=0, device=self.device, group_name=f"sync_data_{r}"
121-
)
122-
assert "consumer_global_step" in local_eval_result
123-
eval_global_step = local_eval_result.pop("consumer_global_step").item()
124-
if eval_statistics is None:
125-
eval_statistics = local_eval_result
126-
else:
127-
eval_statistics = {
128-
k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
129-
}
130-
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
131-
if dist.get_rank() == 0:
132-
if hasattr(self, "wandb_run"):
133-
self.wandb_run.log(eval_statistics, step=eval_global_step)
134-
print(f"Eval statistics: {eval_statistics}")
135113
for _ in range(self.num_recv_per_update):
136114
# receive data from producers
137115
for r in range(self.num_producers):
@@ -217,7 +195,6 @@ def __init__(
217195
minibatch_size=1,
218196
save_interval: int = 100,
219197
save_dir="./model",
220-
eval_interval: int = -1,
221198
):
222199
super().__init__(
223200
num_producers,
@@ -232,9 +209,6 @@ def __init__(
232209
model_config,
233210
plugin_config,
234211
minibatch_size,
235-
save_interval,
236-
save_dir,
237-
eval_interval,
238212
)
239213
path = model_config.pop("path")
240214
self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040
project_name=None,
4141
save_interval: int = 100,
4242
save_dir="./model",
43-
eval_interval: int = -1,
4443
):
4544
print(f"Using GRPO config: {grpo_config}")
4645
if grpo_config.get("loss_variation", "sample_level") == "token_level":
@@ -73,7 +72,6 @@ def __init__(
7372
minibatch_size,
7473
save_interval=save_interval,
7574
save_dir=save_dir,
76-
eval_interval=eval_interval,
7775
)
7876
path = model_config.pop("path")
7977
self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
@@ -530,5 +528,4 @@ def state_dict(self):
530528
self.policy_model._force_wait_all_gather()
531529
model = self.policy_model.unwrap()
532530
state_dict = model.state_dict()
533-
state_dict["consumer_global_step"] = torch.tensor([self.global_step], device=self.device)
534531
return state_dict

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ def __init__(
205205
generate_config = generate_config.copy()
206206
generate_config.update(self.FORCE_GENERATE_CONFIG)
207207
generate_config.update({"n": num_generations})
208-
self.generate_config = generate_config
209-
self.sample_params = SamplingParams(**generate_config)
208+
self.generate_config = SamplingParams(**generate_config)
210209
self.model_config = model_config
211210
self.tokenizer = tokenizer
212211
self.num_generations = num_generations
@@ -220,9 +219,8 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
220219
micro_batch_input_ids_no_padding = [
221220
micro_batch_input_ids[i][first_non_padding_token_idx[i] :] for i in range(micro_batch_size)
222221
]
223-
sample_params = kwargs.get("sample_params", self.sample_params)
224222
outputs = self.llm.generate(
225-
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=sample_params, use_tqdm=False
223+
prompt_token_ids=micro_batch_input_ids_no_padding, sampling_params=self.generate_config, use_tqdm=False
226224
)
227225
out_tokens = []
228226
out_len = []
@@ -268,11 +266,11 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
268266
"response_idx": response_idx,
269267
}
270268

271-
data = {k: v.view(micro_batch_size, -1, v.size(-1)) for k, v in data.items()}
269+
data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
272270

273271
if "gt_answer" in kwargs:
274272
# repeat gt_answer for each prompt.
275-
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(data["input_ids"].size(1), dim=1)
273+
data["gt_answer"] = kwargs["gt_answer"].repeat_interleave(self.num_generations, dim=1)
276274
data = {k: v.to(get_current_device()) for k, v in data.items()}
277275
return data
278276

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def launch_distributed(
3434
inference_microbatch_size: int,
3535
train_batch_size: int,
3636
train_minibatch_size: int,
37-
train_dataset_config: Dict[str, Any],
37+
dataset_config: Dict[str, Any],
3838
dataloaders_config: Dict[str, Any],
3939
inference_model_config: Dict[str, Any],
4040
generate_config: Dict[str, Any],
@@ -50,9 +50,6 @@ def launch_distributed(
5050
project_name: Optional[str] = None,
5151
save_interval: int = 100,
5252
save_dir: str = "./model",
53-
eval_dataset_config: Optional[Dict[str, Any]] = None,
54-
eval_interval: int = 100,
55-
eval_save_dir: Optional[str] = None,
5653
):
5754

5855
if core_algo not in ALGO_MAP:
@@ -63,9 +60,9 @@ def launch_distributed(
6360
train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
6461
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
6562

66-
dataset_path = train_dataset_config["path"]
63+
dataset_path = dataset_config["path"]
6764
num_samples = get_jsonl_size_fast(dataset_path)
68-
global_inference_batch_size = inference_batch_size * num_producers # TODO: this doesn't support TP on producer
65+
global_inference_batch_size = inference_batch_size * num_producers
6966
num_update_per_episode = num_samples // global_inference_batch_size
7067
num_recv_per_update = inference_batch_size // inference_microbatch_size
7168

@@ -77,7 +74,7 @@ def launch_distributed(
7774
num_consumer_procs=num_consumer_procs,
7875
num_episodes=num_episodes,
7976
batch_size=inference_batch_size,
80-
train_dataset_config=train_dataset_config,
77+
dataset_config=dataset_config,
8178
dataloaders_config=dataloaders_config,
8279
model_config=inference_model_config,
8380
generate_config=generate_config,
@@ -86,10 +83,6 @@ def launch_distributed(
8683
backend=inference_backend,
8784
num_generations=num_generations,
8885
consumer_plugin_config=plugin_config,
89-
eval_dataset_config=eval_dataset_config,
90-
eval_interval=eval_interval * num_recv_per_update,
91-
evaluation_function_type=grpo_config["reward_fn_type"],
92-
eval_save_dir=eval_save_dir,
9386
)
9487
procs.append(producer)
9588
generate_config_consumer = copy.deepcopy(generate_config)
@@ -118,7 +111,6 @@ def launch_distributed(
118111
project_name=project_name,
119112
save_interval=save_interval,
120113
save_dir=save_dir,
121-
eval_interval=eval_interval,
122114
)
123115
procs.append(consumer)
124116
ray.get([p.setup.remote() for p in procs])

0 commit comments

Comments (0)