Skip to content

Commit 2b53aae

Browse files
committed
Add another trainer
1 parent 1ecf3c1 commit 2b53aae

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# %%
2+
from tqdm.auto import tqdm
3+
4+
from gbmi.exp_modular_arithmetic import SEEDS
5+
from gbmi.exp_modular_arithmetic.train import FAST_CLOCK_CONFIG, train_or_load_model
6+
from gbmi.training_tools.logging import ModelMatrixLoggingOptions
7+
from gbmi.utils import set_params
8+
9+
# Call train_or_load_model once for every seed in SEEDS, each time with
# FAST_CLOCK_CONFIG overridden to use that seed.
with tqdm(SEEDS, desc="Seed") as pbar:
    for seed in pbar:
        pbar.set_postfix({"seed": seed})
        # Derive the per-seed config first so the call below stays short.
        seeded_config = set_params(
            FAST_CLOCK_CONFIG,
            {"seed": seed},
            post_init=True,
        )
        runtime, model = train_or_load_model(
            seeded_config,
            # force="load",
            # force="train",
        )
23+
24+
25+
# %%
26+
import shutil
27+
from pathlib import Path
28+
29+
import torch
30+
31+
# Collect every checkpoint matching artifacts/*/*.pth under the current
# working directory and make sure the models/ output directory exists.
base = Path(".").resolve()
# Materialise the glob once: a list can be counted and iterated from a single
# filesystem scan, and (unlike a generator) survives re-running the cell below.
wandbs = list((base / "artifacts").glob("*/*.pth"))
total = len(wandbs)
model_base = base / "models"
model_base.mkdir(exist_ok=True, parents=True)
# %%
# Copy each checkpoint into models/, replacing the last six "-"-separated
# fields of its file name with the seed stored in the checkpoint's run_config.
with tqdm(wandbs, total=total) as pbar:
    for p in pbar:
        # NOTE: torch.load unpickles the file; only run this on trusted artifacts.
        cache = torch.load(p, map_location="cpu")
        seed = cache["run_config"]["seed"]
        name_parts = p.name.split("-")
        pbar.set_postfix(
            {
                "seed": seed,
                "orig_name": p.name,
                "suffix_drop": "-".join(name_parts[-6:]),
            }
        )
        shutil.copy(
            p,
            model_base / f"{'-'.join(name_parts[:-6])}-{seed}{p.suffix}",
        )
54+
# %%

gbmi/exp_modular_arithmetic/train.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,20 @@ def log_softmax(self, x: Tensor, **kwargs) -> Tensor:
187187
checkpoint_every=(500, "epochs"),
188188
)
189189

190+
# Preset Config for a ModularArithmetic experiment with p=113.
# NOTE(review): presumably a quicker-to-run "clock" variant, judging by the
# name only -- confirm against the neighbouring presets in this module.
FAST_CLOCK_CONFIG = Config(
    experiment=ModularArithmetic(
        p=113,
        training_ratio=0.8,
        logging_options=ModelMatrixLoggingOptions.none(),
    ),
    seed=0,
    deterministic=False,
    # Schedule: total run length, then logging / validation / checkpoint cadence.
    train_for=(10000, "epochs"),
    log_every_n_steps=1,
    validate_every=(100, "epochs"),
    checkpoint_every=(2000, "epochs"),
)
203+
190204
PIZZA_CONFIG = Config(
191205
experiment=ModularArithmetic(
192206
p=59,

0 commit comments

Comments
 (0)