17
17
import os
18
18
import time
19
19
from multiprocessing .shared_memory import SharedMemory
20
- from typing import Any , Dict
20
+ from typing import Any , Dict , List
21
21
22
22
import numpy as np
23
23
import paddle
24
- from paddle import nn
25
24
from paddleformers .utils .log import logger
26
25
27
26
from fastdeploy .config import FDConfig
30
29
class DynamicWeightManager :
31
30
"""Manages model weights loading, updating and shared state across processes."""
32
31
33
- def __init__ (self , fd_config : FDConfig , model : nn . Layer ):
32
+ def __init__ (self , fd_config : FDConfig , models ):
34
33
"""Initialize with config and model instances."""
35
34
self .fd_config = fd_config
36
35
self .load_config = fd_config .load_config
@@ -41,7 +40,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
41
40
self .meta_src_id = self ._get_gpu_id ()
42
41
self .first_load = True
43
42
self .ipc_path = f"/shared_ipc_meta/ipc_metas_{ self .meta_src_id } "
44
- self .model : nn .Layer = model
43
+ if not isinstance (models , List ):
44
+ self .model_list = [models ]
45
+ else :
46
+ self .model_list = models
45
47
self ._capture_model_state ()
46
48
self .update_parameters ()
47
49
self .finalize_update ()
@@ -54,9 +56,10 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
54
56
@paddle.no_grad()
def _capture_model_state(self):
    """Snapshot the parameters of every managed model into ``self.state_dict``.

    Walks ``self.model_list`` in order and, for each model, stores every
    entry of its ``state_dict()`` under the parameter's own name, logging
    the name, shape and dtype of each one as it is recorded.
    """
    for managed_model in self.model_list:
        model_state = managed_model.state_dict()
        for param_name in model_state:
            tensor = model_state[param_name]
            logger.info(f"Model param: {param_name}, shape={tensor.shape}, dtype={tensor.dtype}")
            # NOTE(review): entries are keyed by bare parameter name, so a
            # later model in the list replaces an earlier model's entry if
            # both expose the same name -- confirm names are unique.
            self.state_dict[param_name] = tensor
60
63
61
64
def update_parameters (self , pid : int = 0 ) -> None :
62
65
"""Core method to update model parameters based on strategy."""
@@ -133,8 +136,9 @@ def clear_parameters(self, pid: int = 0) -> None:
133
136
134
137
paddle .device .cuda .empty_cache ()
135
138
# step2: release model weight
136
- for param in self .model .state_dict ().values ():
137
- param ._clear_data ()
139
+ for model in self .model_list :
140
+ for param in model .state_dict ().values ():
141
+ param ._clear_data ()
138
142
139
143
self ._verify_parameters ("clearance" )
140
144
0 commit comments