-
Notifications
You must be signed in to change notification settings - Fork 38
Open
Description
对应代码片段如下:
angelslim/compressor/speculative/train/models/draft/base_model.py
def _load_from_safetensors(
    self, model_path, embed_weight_key="model.embed_tokens.weight"
):
    """Load the embedding weight tensor from an indexed safetensors checkpoint.

    Looks up ``embed_weight_key`` in ``model.safetensors.index.json`` under
    ``model_path`` to find which file stores it, then reads that tensor on
    CPU and converts it with ``.float()``.

    Args:
        model_path: Directory holding the index file and the weight files
            it references.
        embed_weight_key: Name of the embedding tensor inside the checkpoint.

    Returns:
        The embedding tensor, or ``None`` when the index file does not exist
        or when any step of loading fails (the failure is printed).
    """
    # Best-effort loader: every failure below is reported and turned into
    # ``None`` so the caller can decide what to do next.
    try:
        index_file = os.path.join(model_path, "model.safetensors.index.json")
        if not os.path.exists(index_file):
            # No index file: nothing to report, just signal "not loaded".
            return None

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(index_file, "r", encoding="utf-8") as index_fh:
            index_json = json.load(index_fh)

        # Single lookup instead of a membership test followed by indexing.
        emb_path = index_json["weight_map"].get(embed_weight_key)
        if emb_path is None:
            raise KeyError("Embedding weights key not found in index.")

        safetensors_file = os.path.join(model_path, emb_path)
        with safe_open(
            safetensors_file, framework="pt", device="cpu"
        ) as weights_fh:
            tensor_slice = weights_fh.get_slice(embed_weight_key)
            # Unpacking into two names fails for a non-2-D tensor, which the
            # handler below turns into ``None``.
            _, hidden_dim = tensor_slice.get_shape()
            # The slice spans every row and all ``hidden_dim`` columns.
            tensor = tensor_slice[:, :hidden_dim].float()
        return tensor
    except Exception as e:
        print(f"Failed to load from safetensors: {e}")
        return None
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels