Skip to content

Commit df08771

Browse files
authored
Create specformer.ipynb
1 parent 843a09f commit df08771

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

notebooks/specformer.ipynb

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from fillm.run.model import *
2+
from datasets import load_dataset
3+
4+
CACHE_DIR = '/mnt/ceph/users/lparker/datasets_astroclip'
5+
dataset = load_dataset('/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP/astroclip_datasets/legacy_survey.py', cache_dir=CACHE_DIR)
6+
dataset.set_format(type='torch', columns=['spectrum'])
7+
8+
def load_model_from_ckpt(ckpt_path: str):
9+
"""
10+
Load a model from a checkpoint.
11+
"""
12+
if Path(ckpt_path).is_dir():
13+
ckpt_path = Path(ckpt_path) / "ckpt.pt"
14+
15+
chkpt = torch.load(ckpt_path)
16+
config = chkpt["config"]
17+
state_dict = chkpt["model"]
18+
model_name = config["model"]['kind']
19+
model_keys = get_model_keys(model_name)
20+
21+
model_args = {k: config['model'][k] for k in model_keys}
22+
23+
model_ctr, config_cls = model_registry[model_name]
24+
model_config = config_cls(**model_args)
25+
model_ = model_ctr(model_config)
26+
model_.load_state_dict(state_dict)
27+
28+
return {"model": model_, "config": config}
29+
30+
model_path = "/mnt/home/sgolkar/ceph/saves/fillm/run-seqformer-2708117"
31+
out = load_model_from_ckpt(model_path)
32+
33+
config = out['config']
34+
spec_model = out['model']

0 commit comments

Comments
 (0)