Skip to content

Commit 54d675e

Browse files
committed
Appease flake8. Remove deprecated data loading procedure
1 parent aa50f2b commit 54d675e

File tree

1 file changed

+17
-32
lines changed

1 file changed

+17
-32
lines changed

Pilot3/P3B7/p3b7_baseline_pytorch.py

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import p3b7 as bmk
33
import candle
44

5-
import numpy as np
5+
import pandas as pd
6+
from pathlib import Path
67

7-
import torch.nn as nn
88
from torch.utils.data import DataLoader
99

1010
from data import P3B3, Egress
@@ -13,10 +13,7 @@
1313
from meters import AccuracyMeter
1414
from metrics import F1Meter
1515

16-
from prune import (
17-
negative_prune, min_max_prune,
18-
create_prune_masks, remove_prune_masks
19-
)
16+
from prune import create_prune_masks, remove_prune_masks
2017

2118

2219
TASKS = {
@@ -134,47 +131,35 @@ def evaluate(model, loader, device):
134131

135132
accmeter.update_accuracy()
136133

137-
print(f'Validation accuracy:')
134+
print(f'{"Validation accuracy:"}')
138135
accmeter.print_task_accuracies()
139136

140137
loss /= len(loader.dataset)
141138

142139
return loss
143140

144141

145-
def save_dataframe(metrics, filename):
142+
def save_dataframe(metrics, filename, args):
146143
"""Save F1 metrics"""
147144
df = pd.DataFrame(metrics, index=[0])
148-
path = Path(ARGS.savepath).joinpath(f'f1/{filename}.csv')
145+
path = Path(args.savepath).joinpath(f'f1/{filename}.csv')
149146
df.to_csv(path, index=False)
150147

151148

152149
def run(args):
153150
args = candle.ArgumentStruct(**args)
154151
args.cuda = torch.cuda.is_available()
155-
args.device = torch.device(f"cuda" if args.cuda else "cpu")
156-
157-
if args.use_synthetic_data:
158-
train_data, valid_data = get_synthetic_data(args)
159-
160-
hparams = Hparams(
161-
kernel1=args.kernel1,
162-
kernel2=args.kernel2,
163-
kernel3=args.kernel3,
164-
embed_dim=args.embed_dim,
165-
n_filters=args.n_filters,
166-
)
167-
else:
168-
train_data, valid_data = get_egress_data(tasks)
169-
170-
hparams = Hparams(
171-
kernel1=args.kernel1,
172-
kernel2=args.kernel2,
173-
kernel3=args.kernel3,
174-
embed_dim=args.embed_dim,
175-
n_filters=args.n_filters,
176-
vocab_size=len(train_data.vocab)
177-
)
152+
args.device = torch.device(f'{"cuda"}' if args.cuda else "cpu")
153+
154+
train_data, valid_data = get_synthetic_data(args)
155+
156+
hparams = Hparams(
157+
kernel1=args.kernel1,
158+
kernel2=args.kernel2,
159+
kernel3=args.kernel3,
160+
embed_dim=args.embed_dim,
161+
n_filters=args.n_filters,
162+
)
178163

179164
train_loader = DataLoader(train_data, batch_size=args.batch_size)
180165
valid_loader = DataLoader(valid_data, batch_size=args.batch_size)

0 commit comments

Comments
 (0)