
eagle3 training: model loading does not handle small 2B models that have no model.safetensors.index.json #173

@wincle

Description


The relevant code snippet is:
angelslim/compressor/speculative/train/models/draft/base_model.py

def _load_from_safetensors(
    self, model_path, embed_weight_key="model.embed_tokens.weight"
):
    """Load embedding weights from safetensors format."""
    try:
        index_file = os.path.join(model_path, "model.safetensors.index.json")
        if not os.path.exists(index_file):
            return None

        with open(index_file, "r") as f:
            index_json = json.load(f)

        if embed_weight_key in index_json["weight_map"]:
            emb_path = index_json["weight_map"][embed_weight_key]
        else:
            raise KeyError("Embedding weights key not found in index.")

        safetensors_file = os.path.join(model_path, emb_path)

        with safe_open(safetensors_file, framework="pt", device="cpu") as f:
            tensor_slice = f.get_slice(embed_weight_key)
            _, hidden_dim = tensor_slice.get_shape()
            tensor = tensor_slice[:, :hidden_dim].float()

        return tensor
    except Exception as e:
        print(f"Failed to load from safetensors: {e}")
        return None
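
For context, the snippet above returns None as soon as `model.safetensors.index.json` is missing. Checkpoints small enough to be saved as a single file typically ship only `model.safetensors` with no index, so the embedding is never loaded for them. One possible fix is sketched below, not a definitive patch: `resolve_embedding_file` is a hypothetical helper that does not exist in AngelSlim, and the single-file name `model.safetensors` is an assumption based on the usual Hugging Face layout.

```python
import json
import os


def resolve_embedding_file(model_path, embed_weight_key="model.embed_tokens.weight"):
    """Return the safetensors file expected to hold the embedding, or None.

    Hypothetical helper for illustration; not part of AngelSlim.
    """
    index_file = os.path.join(model_path, "model.safetensors.index.json")
    if os.path.exists(index_file):
        # Sharded checkpoint: the index maps each weight name to its shard file.
        with open(index_file, "r") as f:
            weight_map = json.load(f)["weight_map"]
        if embed_weight_key not in weight_map:
            raise KeyError("Embedding weights key not found in index.")
        return os.path.join(model_path, weight_map[embed_weight_key])

    # Unsharded checkpoint (assumed layout): a single model.safetensors file.
    single_file = os.path.join(model_path, "model.safetensors")
    if os.path.exists(single_file):
        return single_file

    # Neither an index nor a single-file checkpoint was found.
    return None
```

The returned path could then be passed to `safe_open` exactly as in the original function. In the single-file case the key should still be checked against the opened file's `keys()` before calling `get_slice`, since there is no index to confirm that it is present.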
