```diff
@@ -13,7 +13,6 @@
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.optimizer import ChainedOptimizer
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank
 from megatron.training import get_args, get_wandb_writer
 from packaging import version
 
@@ -86,17 +85,19 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams:
                            qkv_format='thd')
 
 
-def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int):
-    # TODO: compat bshd
+def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: Optional[torch.Tensor], dim: int):
     if dim < 0:
         dim = (dim + inputs.ndim) % inputs.ndim
     new_inputs = []
     cp_size = mpu.get_context_parallel_world_size()
     cp_rank = mpu.get_context_parallel_rank()
-    for i in range(cu_seqlens.shape[0] - 1):
-        slices = [slice(None)] * inputs.ndim
-        slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
-        val = inputs[tuple(slices)]
+    for i in range(1 if cu_seqlens is None else (cu_seqlens.shape[0] - 1)):
+        if cu_seqlens is None:
+            val = inputs
+        else:
+            slices = [slice(None)] * inputs.ndim
+            slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
+            val = inputs[tuple(slices)]
         view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:])
         val = val.view(view_shape)
         index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
```
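The `view`/`index_select` pair above is the usual zigzag layout for causal attention under context parallelism: each sequence is cut into `2 * cp_size` chunks along the sequence dimension, and rank `r` keeps chunk `r` plus its mirror `2 * cp_size - 1 - r`, so cheap (early) and expensive (late) attention positions are balanced across ranks. A minimal, framework-free sketch of that per-sequence split (`zigzag_split` and the example values are illustrative, not part of `megatron.core`):

```python
import torch

def zigzag_split(seq: torch.Tensor, cp_size: int, cp_rank: int, dim: int = -1) -> torch.Tensor:
    # Cut the sequence into 2 * cp_size equal chunks along `dim` and keep the
    # pair (cp_rank, 2 * cp_size - 1 - cp_rank): one chunk from the front,
    # plus its mirror from the back.
    dim = dim % seq.ndim
    chunks = seq.chunk(2 * cp_size, dim=dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=dim)

# A length-8 sequence split across cp_size=2 ranks:
seq = torch.arange(8)
for rank in range(2):
    print(rank, zigzag_split(seq, cp_size=2, cp_rank=rank).tolist())
# 0 [0, 1, 6, 7]
# 1 [2, 3, 4, 5]
```

With `cu_seqlens` provided, `split_cp_inputs` applies this per packed sub-sequence and concatenates the results; with `cu_seqlens=None` (the new branch), the loop runs once and treats the whole tensor as a single sequence.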
```diff
@@ -127,15 +128,13 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
         keys.append('input_ids')
 
     packed_seq_params = batch.get('packed_seq_params')
-    if packed_seq_params is None:
-        return mcore_get_batch_on_this_cp_rank(batch)
     for key, val in batch.items():
         if key not in keys:
             continue
         if args.task_type == 'seq_cls' and key == 'labels':
             continue
         if val is not None:
-            batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)
+            batch[key] = split_cp_inputs(val, getattr(packed_seq_params, 'cu_seqlens_q', None), -1)
 
     return batch
```
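With the Megatron fallback gone, unpacked batches now take the same path: `batch.get('packed_seq_params')` returns `None`, `getattr(None, 'cu_seqlens_q', None)` stays `None`, and `split_cp_inputs` handles the whole tensor as one sequence. This relies on the `view` in `split_cp_inputs` succeeding, which requires the length along the split dimension to be divisible by `2 * cp_size`, so inputs presumably get padded to that multiple upstream. A hypothetical helper showing the constraint (`pad_to_multiple` is illustrative, not from this repo):

```python
import torch

def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int = -1, value: int = 0) -> torch.Tensor:
    # Right-pad `x` along `dim` to a multiple of `multiple`; the view() in
    # split_cp_inputs needs length % (2 * cp_size) == 0 on every (sub)sequence.
    rem = x.shape[dim] % multiple
    if rem == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = multiple - rem
    pad = torch.full(pad_shape, value, dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=dim)

ids = torch.arange(10).unsqueeze(0)   # shape (1, 10); cp_size=2 needs a multiple of 4
print(pad_to_multiple(ids, 4).shape)  # torch.Size([1, 12])
```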