
Commit bc0171d

fix transformers backend
1 parent 57b49da commit bc0171d

3 files changed: 34 additions & 10 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -163,3 +163,4 @@ coverage.xml
 # log, test files - ColossalChat
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
+applications/ColossalChat/wandb

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 22 additions & 1 deletion
@@ -61,12 +61,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
+        self.num_generations = 8

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        gt_answer = None
+        if "gt_answer" in kwargs:
+            gt_answer = kwargs.pop("gt_answer")
+        if self.num_generations > 1:
+            input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        out = self.model.generate(
+            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+        )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
         # get log probs
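
The changes above duplicate each prompt num_generations times with repeat_interleave before a single generate call, so several completions per prompt are sampled in one batch, and the tokenizer is now passed to model.generate. A minimal sketch (not part of the diff, toy values) of how repeat_interleave expands the batch:

import torch

num_generations = 3
# two prompts of length 4 (placeholder token ids)
input_ids = torch.tensor([[1, 2, 3, 4],
                          [5, 6, 7, 8]])

# each row is repeated num_generations times back to back:
# rows 0..2 are copies of prompt 0, rows 3..5 of prompt 1
expanded = input_ids.repeat_interleave(num_generations, dim=0)
print(expanded.shape)  # torch.Size([6, 4])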
@@ -76,10 +86,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
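
response_idx records, for each sampled sequence, the start index and the inclusive end index of the response inside the full prompt+completion sequence; the end is derived from action_mask, which the loop above zeroes out after the first EOS token. A toy sketch of just that arithmetic (values made up, action_mask taken as given):

import torch

input_len = 4  # prompt length
# 1 for response tokens up to and including EOS, 0 afterwards
action_mask = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1]])

response_idx = torch.zeros((action_mask.size(0), 2), dtype=torch.int)
response_idx[:, 0] = input_len                               # response starts right after the prompt
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1  # index of the last valid response token
print(response_idx)  # tensor([[4, 6], [4, 8]], dtype=torch.int32)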
@@ -91,7 +104,15 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+        if gt_answer is not None:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
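
After generation, every tensor in data is regrouped from the flat (micro_batch_size * num_generations, seq_len) layout into (micro_batch_size, num_generations, seq_len), so the generations for one prompt sit next to each other, and gt_answer is repeated along the new generation dimension. A toy shape check, with illustrative sizes and an assumed (batch, 1, answer_len) layout for gt_answer:

import torch

micro_batch_size, num_generations, seq_len = 2, 3, 5

# flat batch as produced by generate: one row per (prompt, generation) pair
action_mask = torch.ones(micro_batch_size * num_generations, seq_len)

# group the generations of each prompt together
grouped = action_mask.view(micro_batch_size, num_generations, seq_len)
print(grouped.shape)  # torch.Size([2, 3, 5])

# gt_answer is assumed here to carry one answer per prompt
gt_answer = torch.zeros(micro_batch_size, 1, 7, dtype=torch.long)
gt_answer = gt_answer.repeat_interleave(num_generations, dim=1)
print(gt_answer.shape)  # torch.Size([2, 3, 7])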

applications/ColossalChat/rl_example.py

Lines changed: 11 additions & 9 deletions
@@ -10,9 +10,9 @@
     parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
     parser.add_argument("-t", "--num-trainers", type=int, default=2)
     parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=32)
-    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=16)
-    parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
+    parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
+    parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
+    parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
     parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=1)
     parser.add_argument("-b", "--backend", type=str, default="transformers")
     parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GRPO"])
@@ -24,29 +24,31 @@
     train_model_config = dict(path=args.model)
     generate_config = dict(
         top_k=50,
-        top_p=0.8,
+        top_p=0.9,
+        temperature=1.0,
     )

     if args.backend == "transformers":
         inference_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
             )
         )
         train_model_config.update(
             dict(
-                attn_implementation="flash_attention_2",
+                use_flash_attention_2=True,
                 torch_dtype=torch.bfloat16,
                 use_cache=False,
             )
         )
         generate_config.update(
             dict(
-                max_length=512,
+                max_length=1024 + 512,
                 do_sample=True,
                 max_new_tokens=None,
                 early_stopping=False,
+                stop_strings=["</answer>"],
             )
         )
     elif args.backend == "vllm":
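
The transformers generate_config now halts generation on the literal string "</answer>". In recent Hugging Face transformers releases, stop_strings requires the tokenizer to be passed to generate() as well, which is presumably why the inference_backend.py change above adds tokenizer=self.tokenizer. A hedged usage sketch outside this repo (the checkpoint name is a placeholder):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# placeholder checkpoint; any causal LM works the same way
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("1 + 1 = <answer>", return_tensors="pt")
out = model.generate(
    **inputs,
    max_length=64,
    do_sample=True,
    top_k=50,
    top_p=0.9,
    temperature=1.0,
    stop_strings=["</answer>"],  # stop as soon as this string is produced
    tokenizer=tokenizer,         # required by transformers when stop_strings is set
)
print(tokenizer.decode(out[0], skip_special_tokens=True))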
@@ -82,12 +84,12 @@
         num_producers=args.num_inferencer,
         num_proc_per_producer=1,
         num_consumer_procs=args.num_trainers,
-        num_episodes=1,
+        num_episodes=10,
         inference_batch_size=args.inference_batch_size,
         inference_microbatch_size=args.inference_microbatch_size,
         train_batch_size=args.train_batch_size,
         train_microbatch_size=args.train_microbatch_size,
-        dataset_config={"path": args.dataset, "max_length": 256},
+        dataset_config={"path": args.dataset, "max_length": 300},
         dataloaders_config={},
        inference_model_config=inference_model_config,
        generate_config=generate_config,
