Skip to content

Commit c036e77

Browse files
authored
Missing import fix (#3241)
1 parent 4a23972 commit c036e77

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

megatron/core/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2082,8 +2082,8 @@ def get_thd_batch_on_this_cp_rank(
20822082
max_seqlen_kv=int(max_seqlen[0].item()),
20832083
)
20842084

2085-
cp_size = get_context_parallel_world_size() if cp_size is None else cp_size
2086-
cp_rank = get_context_parallel_rank() if cp_rank is None else cp_rank
2085+
cp_size = parallel_state.get_context_parallel_world_size() if cp_size is None else cp_size
2086+
cp_rank = parallel_state.get_context_parallel_rank() if cp_rank is None else cp_rank
20872087
if cp_size > 1: # slice batch along sequence dimension for context parallelism
20882088
assert tex is not None and is_te_min_version("1.10.0"), (
20892089
"Please update Transformer Engine to >= 1.10 to use "

0 commit comments

Comments
 (0)