Commit b920af4

update pad seq (#6303)
Authored by: Tong Li (TongLi3701)
Co-authored-by: Tong Li <[email protected]>
1 parent: eb6b5dd

File tree: 4 files changed (+3 / -28 lines)


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 4 deletions
@@ -16,7 +16,7 @@
 from colossalai.utils import get_current_device
 
 from .comm import ray_broadcast_tensor_dict
-from .utils import bind_batch, pad_batch, post_recv, unbind_batch
+from .utils import bind_batch, post_recv, unbind_batch
 
 
 class BaseConsumer:
@@ -125,9 +125,6 @@ def loop(self) -> None:
                 batches = self.buffer[
                     self.dp_rank * self.minibatch_size : (self.dp_rank + 1) * self.minibatch_size
                 ]
-                batch = pad_batch(
-                    batches
-                )  # when `imbs` is smaller than `tMbs`, samples may have differ in size, need to pad before stacking
                 batch = bind_batch(batches)
                 batch = post_recv(batch)
                 loss, num_excessive_prompts = self.step(i, pbar, **batch)
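With pad_batch removed, the consumer hands the buffered samples straight to bind_batch. That only works if every sample already shares the same sequence length, which the companion change in inference_backend.py below guarantees by padding all generations to generate_config.max_tokens. A minimal sketch of the assumption (this is not the repository's bind_batch, just an illustration):

import torch
from typing import Dict, List

# Sketch: combining per-sample dicts into one batch assumes every tensor
# already shares its trailing (sequence) dimension.
def combine(samples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    return {k: torch.cat([s[k] for s in samples], dim=0) for k in samples[0]}

a = {"input_ids": torch.zeros(1, 8, dtype=torch.long)}
b = {"input_ids": torch.ones(1, 8, dtype=torch.long)}
assert combine([a, b])["input_ids"].shape == (2, 8)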

applications/ColossalChat/coati/distributed/inference_backend.py

Lines changed: 1 addition & 1 deletion
@@ -236,7 +236,7 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             log_probs.append(p)
 
         # pad them
-        max_len = max(out_len)
+        max_len = self.generate_config.max_tokens
         action_mask = torch.ones(len(out_tokens), max_len, dtype=attention_mask.dtype)
 
         for i, new_token_ids in enumerate(out_tokens):
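Padding to generate_config.max_tokens, rather than to the longest sequence in the current batch, gives every rollout the same fixed length, so minibatches coming from different inference calls can be stacked on the consumer side without re-padding. A minimal sketch of the idea, assuming a pad_token_id and a boolean action mask (illustration only, not the exact code in generate()):

import torch

# Sketch: right-pad variable-length generations to a fixed max_tokens so every
# sample leaves the backend with the same shape; mask out the padded tail.
def pad_to_fixed_len(out_tokens, max_tokens, pad_token_id=0):
    padded = []
    action_mask = torch.ones(len(out_tokens), max_tokens, dtype=torch.bool)
    for i, ids in enumerate(out_tokens):
        ids = torch.as_tensor(ids, dtype=torch.long)
        pad = max_tokens - ids.size(0)
        padded.append(torch.nn.functional.pad(ids, (0, pad), value=pad_token_id))
        action_mask[i, ids.size(0):] = False  # padded positions are not real actions
    return torch.stack(padded), action_mask

tokens, mask = pad_to_fixed_len([[5, 6, 7], [8, 9]], max_tokens=4)
assert tokens.shape == (2, 4) and mask.sum().item() == 5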

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def __init__(
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
-        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+        self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size
 
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
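Using .get with a default makes the producer tolerant of consumer plugin configs that do not set pp_size, falling back to a pipeline size of 1. A tiny sketch of the pattern, with a hypothetical config dict:

# Hypothetical plugin config that omits "pp_size"; .get falls back to 1
# instead of raising KeyError as the old direct indexing did.
consumer_plugin_config = {"tp_size": 2}
consumer_pp_size = consumer_plugin_config.get("pp_size", 1)
assert consumer_pp_size == 1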

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 0 additions & 22 deletions
@@ -1,4 +1,3 @@
-from collections import defaultdict
 from typing import Any, Dict, List
 
 import torch
@@ -27,27 +26,6 @@ def bind_batch(batches: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor
     return batch
 
 
-def pad_batch(batches: List[Dict[str, torch.Tensor]], tokenizer: Any = None) -> List[Dict[str, torch.Tensor]]:
-    max_len = defaultdict(int)
-    for sample in batches:
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                max_len[k] = max(max_len[k], sample[k].size(-1))
-    for idx, sample in enumerate(batches):
-        for k in sample:
-            if k in ["input_ids", "attention_mask", "action_log_probs", "action_mask"]:
-                # right pad with 0s
-                if k in ["attention_mask", "action_mask"]:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", False
-                    )
-                else:
-                    batches[idx][k] = torch.nn.functional.pad(
-                        batches[idx][k], (0, max_len[k] - batches[idx][k].size(-1)), "constant", 0
-                    )
-    return batches
-
-
 def pre_send(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
     # compress mask to save bandwidth
     if "attention_mask" in batch:
