Commit f94cd5f

Don't save AutoBridge in args (#1181)
1 parent b749038 commit f94cd5f
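
In short, this commit stops caching the AutoBridge instance on the global args namespace: the actor init no longer sets args.bridge, and the two consumers of the bridge (HF checkpoint loading and the bridge-mode model provider) now construct a short-lived bridge on demand from args.hf_checkpoint. A minimal sketch of the pattern, using only calls that appear in the diffs below; the helper name is illustrative and not part of the repo:

from megatron.bridge import AutoBridge

def build_bridge(args):
    # Illustrative helper only: after this commit, each call site inlines this
    # construction instead of reading a pre-built args.bridge.
    return AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)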

4 files changed, +9 −6 lines changed

slime/backends/megatron_utils/actor.py

Lines changed: 0 additions & 3 deletions

@@ -8,7 +8,6 @@
 import ray
 import torch
 import torch.distributed as dist
-from megatron.bridge import AutoBridge
 from megatron.core import mpu
 from ray.actor import ActorHandle
 from torch_memory_saver import torch_memory_saver
@@ -67,8 +66,6 @@ def init(
             if i == dist.get_rank():
                 self.hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
                 self.tokenizer = AutoTokenizer.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
-                if args.megatron_to_hf_mode == "bridge":
-                    args.bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)

             dist.barrier(group=get_gloo_group())

slime/backends/megatron_utils/checkpoint.py

Lines changed: 5 additions & 1 deletion

@@ -7,6 +7,7 @@
 from megatron.training.checkpointing import load_checkpoint as _load_checkpoint_megatron
 from megatron.training.checkpointing import save_checkpoint
 from megatron.training.global_vars import get_args
+
 from slime.utils import megatron_bridge_utils

 logger = logging.getLogger(__name__)
@@ -48,12 +49,15 @@ def _is_megatron_checkpoint(path: str | Path) -> bool:

 def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str):
     assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint"
+    from megatron.bridge import AutoBridge
+
     import slime_plugins.megatron_bridge  # noqa: F401

     logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})")

     with megatron_bridge_utils.patch_megatron_model(ddp_model):
-        args.bridge.load_hf_weights(ddp_model)
+        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        bridge.load_hf_weights(ddp_model)

     # Copied from Megatron-core :: load_checkpoint (with simplifications)
     if (args.fp16 or args.bf16) and optimizer is not None:
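
With this change, _load_checkpoint_hf builds its own bridge inside the patch_megatron_model context instead of reading args.bridge. A minimal, self-contained sketch of that pattern, assuming the imports shown in the diff above; the wrapper name is hypothetical:

from megatron.bridge import AutoBridge
from slime.utils import megatron_bridge_utils

def load_hf_weights_into(ddp_model, hf_checkpoint: str):
    # Hypothetical wrapper mirroring the body of _load_checkpoint_hf after this
    # commit: build a short-lived bridge, load the HF weights, then drop it.
    with megatron_bridge_utils.patch_megatron_model(ddp_model):
        bridge = AutoBridge.from_hf_pretrained(hf_checkpoint, trust_remote_code=True)
        bridge.load_hf_weights(ddp_model)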

slime/backends/megatron_utils/data.py

Lines changed: 0 additions & 1 deletion

@@ -51,7 +51,6 @@ def get_batch(
     assert "tokens" in keys
     batch = data_iterator.get_next(keys)

-    packed_seq_params = None
     tokens = batch["tokens"]
     # use 0 as the pad token id should be fine?
     pad_token_id = 0

slime/backends/megatron_utils/model_provider.py

Lines changed: 4 additions & 1 deletion

@@ -54,7 +54,10 @@ def get_model_provider_func(
     role: Literal["actor", "critic"] = "actor",
 ):
     if args.megatron_to_hf_mode == "bridge":
-        provider = args.bridge.to_megatron_provider(load_weights=False)
+        from megatron.bridge import AutoBridge
+
+        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        provider = bridge.to_megatron_provider(load_weights=False)
         provider.finalize()
         return provider.provide

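The provider path follows the same pattern: the bridge is created locally, converted to a Megatron provider without loading weights, and only the provider function is kept. A condensed sketch of the bridge branch as it stands after this commit; the standalone function name is illustrative:

from megatron.bridge import AutoBridge

def bridge_model_provider(args):
    # Illustrative standalone version of the bridge branch in
    # get_model_provider_func; weights are not loaded here.
    bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
    provider = bridge.to_megatron_provider(load_weights=False)
    provider.finalize()
    return provider.provide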