File tree Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Expand file tree Collapse file tree 1 file changed +34
-0
lines changed Original file line number Diff line number Diff line change 1+ from fillm.run.model import *
2+ from datasets import load_dataset
3+
4+ CACHE_DIR = '/mnt/ceph/users/lparker/datasets_astroclip'
5+ dataset = load_dataset('/mnt/home/lparker/Documents/AstroFoundationModel/AstroCLIP/astroclip_datasets/legacy_survey.py', cache_dir=CACHE_DIR)
6+ dataset.set_format(type='torch', columns=['spectrum'])
7+
8+ def load_model_from_ckpt(ckpt_path: str):
9+ " " "
10+ Load a model from a checkpoint.
11+ " " "
12+ if Path(ckpt_path).is_dir():
13+ ckpt_path = Path(ckpt_path) / "ckpt.pt"
14+
15+ chkpt = torch.load(ckpt_path)
16+ config = chkpt["config"]
17+ state_dict = chkpt["model"]
18+ model_name = config["model"]['kind']
19+ model_keys = get_model_keys(model_name)
20+
21+ model_args = {k: config['model'][k] for k in model_keys}
22+
23+ model_ctr, config_cls = model_registry[model_name]
24+ model_config = config_cls(**model_args)
25+ model_ = model_ctr(model_config)
26+ model_.load_state_dict(state_dict)
27+
28+ return {"model": model_, "config": config}
29+
30+ model_path = "/mnt/home/sgolkar/ceph/saves/fillm/run-seqformer-2708117"
31+ out = load_model_from_ckpt(model_path)
32+
33+ config = out['config']
34+ spec_model = out['model']
You can’t perform that action at this time.
0 commit comments