Skip to content

Commit daa6fc2

Browse files
committed
loadworker > 1
1 parent 4c572ca commit daa6fc2

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
from .base_weight import BaseWeight
2-
from .mm_weight import MMWeight, ROWMMWeight, COLMMWeight, MultiROWMMWeight, CustomMMWeight, CustomBMMWeight
2+
from .mm_weight import (
3+
MMWeight,
4+
MultiMMWeight,
5+
ROWMMWeight,
6+
COLMMWeight,
7+
MultiROWMMWeight,
8+
CustomMMWeight,
9+
CustomBMMWeight,
10+
)
311
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
412
from .fused_moe_weight import FusedMoeWeight

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
# from lightllm.common.layers.mm import MM
44
from .base_layer_weight import BaseLayerWeight
5-
from .meta_weights import MMWeight, ROWMMWeight, FusedMoeWeight
5+
from .meta_weights import BaseWeight, MultiMMWeight, MMWeight, FusedMoeWeight
66
from lightllm.utils.log_utils import init_logger
77

88
logger = init_logger(__name__)
@@ -34,6 +34,18 @@ def _init_weight_names(self):
3434
def _init_weight(self):
3535
pass
3636

37+
def load_hf_weights(self, weights):
38+
"""
39+
load weights
40+
"""
41+
for attr_name in dir(self):
42+
attr = getattr(self, attr_name, None)
43+
if isinstance(attr, MultiMMWeight):
44+
with self.lock:
45+
attr.load_hf_weights(weights)
46+
elif isinstance(attr, BaseWeight):
47+
attr.load_hf_weights(weights)
48+
3749
def set_quantization(self):
3850
if self.quant_cfg.quant_type is None:
3951
return

0 commit comments

Comments
 (0)