Skip to content

Commit bea76cd

Browse files
committed
Add --force
1 parent a70a3be commit bea76cd

File tree

4 files changed

+28
-6
lines changed

4 files changed

+28
-6
lines changed

gbmi/exp_argmax_of_n/run_train_argmax_of_10.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
default=",".join(sorted(map(str, SEEDS))),
1212
help="Comma-separated list of seeds to use",
1313
)
14+
parser.add_argument(
15+
"--force",
16+
choices=["train", "load", "none"],
17+
default="train",
18+
help="Force training or loading",
19+
)
1420
args = parser.parse_args()
1521

1622
for d_vocab in tqdm((64, 128), desc="d_vocab"):
1723
with tqdm(map(int, args.seeds.split(",")), desc="Seed", leave=False) as pbar:
1824
for seed in pbar:
1925
cfg = ARGMAX_OF_10_CONFIG(seed, d_vocab=d_vocab, deterministic=False)
2026
pbar.set_postfix({"seed": seed, "d_vocab": d_vocab, "cfg": cfg})
21-
runtime, model = train_or_load_model(cfg) # , force="train"
27+
runtime, model = train_or_load_model(cfg, force=args.force)

gbmi/exp_argmax_of_n/run_train_argmax_of_20.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
default=",".join(sorted(map(str, SEEDS))),
1212
help="Comma-separated list of seeds to use",
1313
)
14+
parser.add_argument(
15+
"--force",
16+
choices=["train", "load", "none"],
17+
default="train",
18+
help="Force training or loading",
19+
)
1420
args = parser.parse_args()
1521

1622
for d_vocab in tqdm((64, 512), desc="d_vocab"):
1723
with tqdm(map(int, args.seeds.split(",")), desc="Seed", leave=False) as pbar:
1824
for seed in pbar:
1925
cfg = ARGMAX_OF_20_CONFIG(seed, d_vocab=d_vocab, deterministic=False)
2026
pbar.set_postfix({"seed": seed, "d_vocab": d_vocab, "cfg": cfg})
21-
runtime, model = train_or_load_model(cfg) # , force="train"
27+
runtime, model = train_or_load_model(cfg, force=args.force)

gbmi/exp_argmax_of_n/run_train_argmax_of_4.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,15 @@
1111
default=",".join(sorted(map(str, SEEDS))),
1212
help="Comma-separated list of seeds to use",
1313
)
14+
parser.add_argument(
15+
"--force",
16+
choices=["train", "load", "none"],
17+
default="train",
18+
help="Force training or loading",
19+
)
1420
args = parser.parse_args()
1521

1622
with tqdm(map(int, args.seeds.split(",")), desc="Seed") as pbar:
1723
for seed in pbar:
1824
pbar.set_postfix({"seed": seed})
19-
runtime, model = train_or_load_model(
20-
ARGMAX_OF_4_CONFIG(seed)
21-
) # , force="train"
25+
runtime, model = train_or_load_model(ARGMAX_OF_4_CONFIG(seed), force=args.force)

gbmi/exp_argmax_of_n/run_train_argmax_of_5.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@
1111
default=",".join(sorted(map(str, SEEDS))),
1212
help="Comma-separated list of seeds to use",
1313
)
14+
parser.add_argument(
15+
"--force",
16+
choices=["train", "load", "none"],
17+
default="train",
18+
help="Force training or loading",
19+
)
1420
args = parser.parse_args()
1521

1622
with tqdm(map(int, args.seeds.split(",")), desc="Seed") as pbar:
1723
for seed in pbar:
1824
cfg = ARGMAX_OF_5_CONFIG(seed, deterministic=False)
1925
pbar.set_postfix({"seed": seed, "cfg": cfg})
20-
runtime, model = train_or_load_model(cfg) # , force="train"
26+
runtime, model = train_or_load_model(cfg, force=args.force)

0 commit comments

Comments
 (0)