Skip to content

Commit 8fd4e97

Browse files
committed
Update _dump_tools.py
1 parent 065e4aa commit 8fd4e97

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

autointent/_dump_tools.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
158158
"name": val.__class__.__name__,
159159
}
160160
# Save configuration if available
161-
if hasattr(val, 'get_config'):
162-
class_info['config'] = val.get_config()
161+
if hasattr(val, "get_config"):
162+
class_info["config"] = val.get_config()
163163
with (model_path / "class_info.json").open("w") as f:
164164
json.dump(class_info, f)
165165
except Exception as e:
@@ -280,24 +280,24 @@ def load( # noqa: C901, PLR0912, PLR0915
280280
msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
281281
logger.exception(msg)
282282
elif child.name == Dumper.torch_models:
283-
for model_dir in child.iterdir():
284-
try:
283+
try:
284+
for model_dir in child.iterdir():
285285
with (model_dir / "class_info.json").open("r") as f:
286286
class_info = json.load(f)
287287
module = __import__(class_info["module"], fromlist=[class_info["name"]])
288288
model_class = getattr(module, class_info["name"])
289-
config = class_info.get('config', {})
289+
config = class_info.get("config", {})
290290
# Initialize model with config if available
291291
model = model_class(**config)
292292
model.load_state_dict(torch.load(model_dir / "model.pt"))
293293
model.eval()
294294
torch_models[model_dir.name] = model
295-
except Exception as e:
296-
logger.exception(f"Error loading torch model {model_dir.name}: {e}")
295+
except Exception as e:
296+
msg = f"Error loading torch model {model_dir.name}: {e}"
297+
logger.exception(msg)
297298
elif child.name == Dumper.containers:
298299
try:
299300
for container_file in child.iterdir():
300-
print(container_file)
301301
with container_file.open("r") as f:
302302
containers = json.load(f)
303303
except Exception as e:
@@ -307,8 +307,6 @@ def load( # noqa: C901, PLR0912, PLR0915
307307
msg = f"Found unexpected child {child}"
308308
logger.error(msg)
309309

310-
print(containers)
311-
312310
obj.__dict__.update(
313311
tags
314312
| simple_attrs

0 commit comments

Comments
 (0)