Skip to content

Commit 37569cc

Browse files
authored
[feat] add fast_weights_iterator (#3258)
* add fast_weights_iterator
* update
* update
1 parent 5f0b30f commit 37569cc

File tree

3 files changed

18 additions and 13 deletions

3 files changed

18 additions and 13 deletions

fastdeploy/model_executor/layers/linear.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,6 @@ def __init__(
162162
self.linear_shift = None
163163
self.linear_smooth = None
164164

165-
if fd_config.model_config.is_quantized:
166-
self.weight_key = f"{prefix}.quant_weight"
167-
self.weight_scale_key = f"{prefix}.weight_scale"
168-
self.act_scale_key = f"{prefix}.activation_scale"
169-
170165
def load_prequant_weight(self, state_dict: dict):
171166
"""
172167
Load the prequantized weight from the state dictionary.

fastdeploy/model_executor/load_weight_utils.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from paddleformers.transformers import PretrainedModel
2525
from paddleformers.transformers.model_utils import load_tp_checkpoint
2626
from paddleformers.utils.log import logger
27+
from paddleformers.utils.safetensors import fast_safe_open
2728
from safetensors import safe_open
2829
from tqdm import tqdm
2930

@@ -155,18 +156,28 @@ def get_expert_ranges(fd_config):
155156
return state_dict
156157

157158

158-
def safetensors_weights_iterator(
159-
safe_tensor_list: list[str],
160-
):
159+
def safetensors_weights_iterator(safe_tensor_list: list[str]):
161160
"""
162161
safetensors_weights_iterator
163162
"""
164163
for st_file in tqdm(
165164
safe_tensor_list,
166165
desc="Loading safetensors checkpoint shards",
167166
):
168-
from paddleformers.utils.safetensors import fast_safe_open
167+
with safe_open(st_file, framework="np") as f:
168+
for name in f.keys():
169+
param = f.get_tensor(name)
170+
yield name, param
169171

172+
173+
def fast_weights_iterator(safe_tensor_list: list[str]):
174+
"""
175+
paddleformers' iterator for safetensors
176+
"""
177+
for st_file in tqdm(
178+
safe_tensor_list,
179+
desc="Loading safetensors checkpoint shards",
180+
):
170181
with fast_safe_open(st_file, framework="np") as f:
171182
for name in f.keys():
172183
param = f.get_slice(name)
@@ -215,13 +226,12 @@ def load_pre_sharded_checkpoint(model_path: str, local_rank: int, use_fastsafete
215226
"""
216227
load_pre_sharded_checkpoint
217228
"""
218-
from fastdeploy.model_executor.layers.utils import get_tensor
219229

220230
state_dict = {}
221231
_, safetensor_files = get_all_safetensors(os.path.join(model_path, f"rank{local_rank}"))
222232
weights_iterator = safetensors_weights_iterator(safetensor_files)
223233
for name, weight in weights_iterator:
224-
state_dict[name] = get_tensor(weight)
234+
state_dict[name] = weight
225235
return state_dict
226236

227237

fastdeploy/model_executor/model_loader/default_loader_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222

2323
from fastdeploy.config import FDConfig, LoadConfig, ModelConfig
2424
from fastdeploy.model_executor.load_weight_utils import (
25+
fast_weights_iterator,
2526
get_all_safetensors,
2627
measure_time,
27-
safetensors_weights_iterator,
2828
)
2929
from fastdeploy.model_executor.model_loader.base_loader import BaseModelLoader
3030
from fastdeploy.model_executor.models.model_base import ModelRegistry
@@ -49,7 +49,7 @@ def clean_memory_fragments(self) -> None:
4949
@measure_time
5050
def load_weights(self, model, fd_config: FDConfig) -> None:
5151
_, safetensor_files = get_all_safetensors(fd_config.model_config.model)
52-
weights_iterator = safetensors_weights_iterator(safetensor_files)
52+
weights_iterator = fast_weights_iterator(safetensor_files)
5353
model.load_weights(weights_iterator)
5454
self.clean_memory_fragments()
5555

0 commit comments

Comments
 (0)