Skip to content

Commit 1d4ac0e

Browse files
committed
Reformatted with ruff, changed requirements
1 parent 5bc8866 commit 1d4ac0e

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
--extra-index-url https://download.pytorch.org/whl/cpu
2-
torch~=2.6.0
3-
torchvision~=0.21.0
4-
transformers~=4.55.0
5-
huggingface-hub~=0.34.0
2+
torch
3+
torchvision
4+
transformers
5+
huggingface-hub
6+
accelerate

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,16 @@
3838
# apertus_mod.apply_rotary_pos_emb = debug_rope
3939
### == END ROPE DEBUG ===
4040

41+
4142
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
4243
"""
4344
Print a tensor in llama.cpp debug style.
44-
45+
4546
Supports:
4647
- 2D tensors (seq, hidden)
4748
- 3D tensors (batch, seq, hidden)
4849
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
49-
50+
5051
Shows first and last max_vals of each vector per sequence position.
5152
"""
5253
t = tensor.detach().to(torch.float32).cpu()
@@ -85,24 +86,24 @@ def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int =
8586
# If no overlap, we'll add a separator between first and last sequences
8687
indices = first_indices + last_indices
8788
separator_index = len(first_indices)
88-
89+
8990
for i, si in enumerate(indices):
9091
# Add separator if needed
9192
if separator_index is not None and i == separator_index:
9293
print(" ...")
93-
94+
9495
# Extract appropriate slice
9596
vec = t[0, si]
9697
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
9798
flat = vec.flatten().tolist()
98-
else: # 2D or 3D case
99+
else: # 2D or 3D case
99100
flat = vec.tolist()
100101

101102
# First and last slices
102103
first = flat[:max_vals]
103104
last = flat[-max_vals:] if len(flat) >= max_vals else flat
104105
first_str = ", ".join(f"{v:12.4f}" for v in first)
105-
last_str = ", ".join(f"{v:12.4f}" for v in last)
106+
last_str = ", ".join(f"{v:12.4f}" for v in last)
106107

107108
print(f" [{first_str}, ..., {last_str}]")
108109

@@ -125,15 +126,17 @@ def fn(_m, input, output):
125126
return fn
126127

127128

128-
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
129+
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
129130

130-
parser = argparse.ArgumentParser(description='Process model with specified path')
131-
parser.add_argument('--model-path', '-m', help='Path to the model')
131+
parser = argparse.ArgumentParser(description="Process model with specified path")
132+
parser.add_argument("--model-path", "-m", help="Path to the model")
132133
args = parser.parse_args()
133134

134-
model_path = os.environ.get('MODEL_PATH', args.model_path)
135+
model_path = os.environ.get("MODEL_PATH", args.model_path)
135136
if model_path is None:
136-
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
137+
parser.error(
138+
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
139+
)
137140

138141
config = AutoConfig.from_pretrained(model_path)
139142

@@ -150,18 +153,26 @@ def fn(_m, input, output):
150153

151154
if unreleased_model_name:
152155
model_name_lower = unreleased_model_name.lower()
153-
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
156+
unreleased_module_path = (
157+
f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
158+
)
154159
class_name = f"{unreleased_model_name}ForCausalLM"
155160
print(f"Importing unreleased model module: {unreleased_module_path}")
156161

157162
try:
158-
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
159-
model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained
163+
model_class = getattr(
164+
importlib.import_module(unreleased_module_path), class_name
165+
)
166+
model = model_class.from_pretrained(
167+
model_path
168+
) # Note: from_pretrained, not fromPretrained
160169
except (ImportError, AttributeError) as e:
161170
print(f"Failed to import or load model: {e}")
162171
exit(1)
163172
else:
164-
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", offload_folder="offload")
173+
model = AutoModelForCausalLM.from_pretrained(
174+
model_path, device_map="auto", offload_folder="offload"
175+
)
165176

166177
for name, module in model.named_modules():
167178
if len(list(module.children())) == 0: # only leaf modules

0 commit comments

Comments
 (0)