Skip to content

Commit 407c237

Browse files
authored
model-conversion : fix pyright errors (ggml-org#15770)
This commit addresses type errors reported by pyright in the model conversion scripts.
1 parent cdedb70 commit 407c237

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import argparse
44
import os
55
import importlib
6-
import sys
76
import torch
87
import numpy as np
98

10-
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
9+
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
1110
from pathlib import Path
1211

1312
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@@ -43,6 +42,8 @@
4342
model = model_class.from_pretrained(model_path)
4443
except (ImportError, AttributeError) as e:
4544
print(f"Failed to import or load model: {e}")
45+
print("Falling back to AutoModelForCausalLM")
46+
model = AutoModelForCausalLM.from_pretrained(model_path)
4647
else:
4748
model = AutoModelForCausalLM.from_pretrained(model_path)
4849
print(f"Model class: {type(model)}")

examples/model-conversion/scripts/utils/inspect-org-model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
file_path = os.path.join(model_path, file_name)
4141
print(f"\n--- From {file_name} ---")
4242

43-
with safe_open(file_path, framework="pt") as f:
43+
with safe_open(file_path, framework="pt") as f: # type: ignore
4444
for tensor_name in sorted(tensor_names):
4545
tensor = f.get_tensor(tensor_name)
4646
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
@@ -49,7 +49,7 @@
4949
# Single file model (original behavior)
5050
print("Single-file model detected")
5151

52-
with safe_open(single_file_path, framework="pt") as f:
52+
with safe_open(single_file_path, framework="pt") as f: # type: ignore
5353
keys = f.keys()
5454
print("Tensors in model:")
5555
for key in sorted(keys):

0 commit comments

Comments
 (0)