Skip to content

Commit 93ccab9

Browse files
committed
Add clock training
1 parent 1ac8029 commit 93ccab9

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# %%
2+
from tqdm.auto import tqdm
3+
4+
from gbmi.exp_modular_arithmetic import SEEDS
5+
from gbmi.exp_modular_arithmetic.train import CLOCK_CONFIG, train_or_load_model
6+
from gbmi.training_tools.logging import ModelMatrixLoggingOptions
7+
from gbmi.utils import set_params
8+
9+
with tqdm(SEEDS, desc="Seed") as pbar:
10+
for seed in pbar:
11+
pbar.set_postfix({"seed": seed})
12+
runtime, model = train_or_load_model(
13+
set_params(
14+
CLOCK_CONFIG,
15+
{
16+
"seed": seed,
17+
("experiment", "train_for"): (10000, "epochs"),
18+
("experiment", "logging_options"): ModelMatrixLoggingOptions.none(),
19+
},
20+
post_init=True,
21+
),
22+
# force="load",
23+
# force="train",
24+
)
25+
26+
27+
# %%
28+
import shutil
29+
from pathlib import Path
30+
31+
import torch
32+
33+
base = Path(".").resolve()
34+
wandbs = (base / "artifacts").glob("*/*.pth")
35+
total = len(list((base / "artifacts").glob("*/*.pth")))
36+
model_base = base / "models"
37+
model_base.mkdir(exist_ok=True, parents=True)
38+
# %%
39+
with tqdm(wandbs, total=total) as pbar:
40+
for p in pbar:
41+
total -= 1
42+
cache = torch.load(p, map_location="cpu")
43+
pbar.set_postfix(
44+
{
45+
"seed": cache["run_config"]["seed"],
46+
"orig_name": p.name,
47+
"suffix_drop": "-".join(p.name.split("-")[-6:]),
48+
}
49+
)
50+
seed = cache["run_config"]["seed"]
51+
shutil.copy(
52+
p,
53+
model_base / f"{'-'.join(p.name.split('-')[:-6])}-{seed}{p.suffix}",
54+
)
55+
# break
56+
# %%

0 commit comments

Comments
 (0)