Skip to content

Commit 8b17fd3

Browse files
committed
Fix missing keywords, device errors.
1 parent 8853346 commit 8b17fd3

File tree

5 files changed

+25
-4
lines changed

5 files changed

+25
-4
lines changed

Pilot3/P3B6/p3b6.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
{"name": "weight_decay", "action": "store", "type": float},
1111
{"name": "grad_clip", "action": "store", "type": int},
1212
{"name": "unrolled", "action": "store", "type": candle.str2bool},
13+
{"name": "device", "action": "store", "type": str},
14+
{"name": "num_train_samples", "action": "store", "type": int},
15+
{"name": "num_valid_samples", "action": "store", "type": int},
16+
{"name": "num_test_samples", "action": "store", "type": int},
17+
{"name": "num_classes", "action": "store", "type": int},
18+
{"name": "eps", "action": "store", "type": float},
19+
1320
]
1421

1522
required = [

Pilot3/P3B7/p3b7.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,14 @@
1111
{"name": "grad_clip", "action": "store", "type": int},
1212
{"name": "unrolled", "action": "store", "type": candle.str2bool},
1313
{"name": "use_synthetic_data", "action": "store", "type": candle.str2bool},
14+
{"name": "eps", "action": "store", "type": float},
15+
{"name": "device", "action": "store", "type": str},
16+
{"name": "embed_dim", "action": "store", "type": int},
17+
{"name": "n_filters", "action": "store", "type": int},
18+
{"name": "kernel1", "action": "store", "type": int},
19+
{"name": "kernel2", "action": "store", "type": int},
20+
{"name": "kernel3", "action": "store", "type": int},
21+
1422
]
1523

1624
required = [

Pilot3/P3B8/default_model.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ learning_rate = 2e-5
44
eps = 1e-8
55
weight_decay = 0.0
66
batch_size = 10
7-
num_epochs = 10
7+
epochs = 10
88
rng_seed = 13
99
num_train_samples = 10000
1010
num_valid_samples = 10000

Pilot3/P3B8/p3b8.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,20 @@
1010
{"name": "weight_decay", "action": "store", "type": float},
1111
{"name": "grad_clip", "action": "store", "type": int},
1212
{"name": "unrolled", "action": "store", "type": candle.str2bool},
13+
{"name": "device", "action": "store", "type": str},
14+
{"name": "num_train_samples", "action": "store", "type": int},
15+
{"name": "num_valid_samples", "action": "store", "type": int},
16+
{"name": "num_test_samples", "action": "store", "type": int},
17+
{"name": "num_classes", "action": "store", "type": int},
18+
{"name": "eps", "action": "store", "type": float},
1319
]
1420

1521
required = [
1622
"learning_rate",
1723
"weight_decay",
1824
"rng_seed",
1925
"batch_size",
20-
"num_epochs",
26+
"epochs",
2127
]
2228

2329

Pilot3/P3B8/p3b8_baseline_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def validate(dataloader, model, args, device, epoch):
9090
for idx, batch in enumerate(dataloader):
9191

9292
input_ids = batch["tokens"].to(device)
93-
labels = batch["label"].to(args.device)
93+
labels = batch["label"].to(device)
9494

9595
output = model(input_ids, labels=labels)
9696

@@ -137,7 +137,7 @@ def run(args):
137137
optimizer = torch.optim.Adam(params, lr=args.learning_rate, eps=args.eps)
138138
criterion = nn.BCEWithLogitsLoss()
139139

140-
for epoch in range(args.num_epochs):
140+
for epoch in range(args.epochs):
141141
train(train_loader, model, optimizer, criterion, args, epoch)
142142
validate(valid_loader, model, args, args.device, epoch)
143143

0 commit comments

Comments
 (0)