Skip to content

Commit 94b6e7a

Browse files
authored
[MTP][RL] Support RL reshard, wenxin-tools-145 (#4173)
* Support MTP reshard in RL mode * Fix function
1 parent 389c5dd commit 94b6e7a

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

fastdeploy/model_executor/model_loader/default_loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
7171
# register rl model
7272
import fastdeploy.rl # noqa
7373

74+
if fd_config.speculative_config.model_type != "mtp":
75+
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
76+
else:
77+
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
78+
7479
architectures = architectures + "RL"
7580
context = paddle.LazyGuard()
7681
else:

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ def load_model(self, fd_config: FDConfig) -> nn.Layer:
5959
# register rl model
6060
import fastdeploy.rl # noqa
6161

62+
if fd_config.speculative_config.model_type != "mtp":
63+
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MoeForCausalLM")
64+
else:
65+
architectures = architectures.replace("Ernie5ForCausalLM", "Ernie5MTPForCausalLM")
66+
6267
architectures = architectures + "RL"
6368

6469
with context:

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
import os
1818
import time
1919
from multiprocessing.shared_memory import SharedMemory
20-
from typing import Any, Dict
20+
from typing import Any, Dict, List
2121

2222
import numpy as np
2323
import paddle
24-
from paddle import nn
2524
from paddleformers.utils.log import logger
2625

2726
from fastdeploy.config import FDConfig
@@ -30,7 +29,7 @@
3029
class DynamicWeightManager:
3130
"""Manages model weights loading, updating and shared state across processes."""
3231

33-
def __init__(self, fd_config: FDConfig, model: nn.Layer):
32+
def __init__(self, fd_config: FDConfig, models):
3433
"""Initialize with config and model instances."""
3534
self.fd_config = fd_config
3635
self.load_config = fd_config.load_config
@@ -41,7 +40,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
4140
self.meta_src_id = self._get_gpu_id()
4241
self.first_load = True
4342
self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
44-
self.model: nn.Layer = model
43+
if not isinstance(models, List):
44+
self.model_list = [models]
45+
else:
46+
self.model_list = models
4547
self._capture_model_state()
4648
self.update_parameters()
4749
self.finalize_update()
@@ -54,9 +56,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
5456
@paddle.no_grad()
5557
def _capture_model_state(self):
5658
"""Capture and store initial model parameters state."""
57-
for name, param in self.model.state_dict().items():
58-
logger.debug(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
59-
self.state_dict[name] = param
59+
for model in self.model_list:
60+
for name, param in model.state_dict().items():
61+
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
62+
self.state_dict[name] = param
6063

6164
def update_parameters(self, pid: int = 0) -> None:
6265
"""Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ def clear_parameters(self, pid: int = 0) -> None:
133136

134137
paddle.device.cuda.empty_cache()
135138
# step2: release model weight
136-
for param in self.model.state_dict().values():
137-
param._clear_data()
139+
for model in self.model_list:
140+
for param in model.state_dict().values():
141+
param._clear_data()
138142

139143
self._verify_parameters("clearance")
140144

0 commit comments

Comments (0)