1 parent 4a23972 commit c036e77
megatron/core/utils.py
@@ -2082,8 +2082,8 @@ def get_thd_batch_on_this_cp_rank(
         max_seqlen_kv=int(max_seqlen[0].item()),
     )
 
-    cp_size = get_context_parallel_world_size() if cp_size is None else cp_size
-    cp_rank = get_context_parallel_rank() if cp_rank is None else cp_rank
+    cp_size = parallel_state.get_context_parallel_world_size() if cp_size is None else cp_size
+    cp_rank = parallel_state.get_context_parallel_rank() if cp_rank is None else cp_rank
     if cp_size > 1:  # slice batch along sequence dimension for context parallelism
         assert tex is not None and is_te_min_version("1.10.0"), (
             "Please update Transformer Engine to >= 1.10 to use "