|
2 | 2 | import p3b7 as bmk
|
3 | 3 | import candle
|
4 | 4 |
|
5 |
| -import numpy as np |
| 5 | +import pandas as pd |
| 6 | +from pathlib import Path |
6 | 7 |
|
7 |
| -import torch.nn as nn |
8 | 8 | from torch.utils.data import DataLoader
|
9 | 9 |
|
10 | 10 | from data import P3B3, Egress
|
|
13 | 13 | from meters import AccuracyMeter
|
14 | 14 | from metrics import F1Meter
|
15 | 15 |
|
16 |
| -from prune import ( |
17 |
| - negative_prune, min_max_prune, |
18 |
| - create_prune_masks, remove_prune_masks |
19 |
| -) |
| 16 | +from prune import create_prune_masks, remove_prune_masks |
20 | 17 |
|
21 | 18 |
|
22 | 19 | TASKS = {
|
@@ -134,47 +131,35 @@ def evaluate(model, loader, device):
|
134 | 131 |
|
135 | 132 | accmeter.update_accuracy()
|
136 | 133 |
|
137 |
| - print(f'Validation accuracy:') |
| 134 | + print(f'{"Validation accuracy:"}') |
138 | 135 | accmeter.print_task_accuracies()
|
139 | 136 |
|
140 | 137 | loss /= len(loader.dataset)
|
141 | 138 |
|
142 | 139 | return loss
|
143 | 140 |
|
144 | 141 |
|
145 |
| -def save_dataframe(metrics, filename): |
| 142 | +def save_dataframe(metrics, filename, args): |
146 | 143 | """Save F1 metrics"""
|
147 | 144 | df = pd.DataFrame(metrics, index=[0])
|
148 |
| - path = Path(ARGS.savepath).joinpath(f'f1/{filename}.csv') |
| 145 | + path = Path(args.savepath).joinpath(f'f1/{filename}.csv') |
149 | 146 | df.to_csv(path, index=False)
|
150 | 147 |
|
151 | 148 |
|
152 | 149 | def run(args):
|
153 | 150 | args = candle.ArgumentStruct(**args)
|
154 | 151 | args.cuda = torch.cuda.is_available()
|
155 |
| - args.device = torch.device(f"cuda" if args.cuda else "cpu") |
156 |
| - |
157 |
| - if args.use_synthetic_data: |
158 |
| - train_data, valid_data = get_synthetic_data(args) |
159 |
| - |
160 |
| - hparams = Hparams( |
161 |
| - kernel1=args.kernel1, |
162 |
| - kernel2=args.kernel2, |
163 |
| - kernel3=args.kernel3, |
164 |
| - embed_dim=args.embed_dim, |
165 |
| - n_filters=args.n_filters, |
166 |
| - ) |
167 |
| - else: |
168 |
| - train_data, valid_data = get_egress_data(tasks) |
169 |
| - |
170 |
| - hparams = Hparams( |
171 |
| - kernel1=args.kernel1, |
172 |
| - kernel2=args.kernel2, |
173 |
| - kernel3=args.kernel3, |
174 |
| - embed_dim=args.embed_dim, |
175 |
| - n_filters=args.n_filters, |
176 |
| - vocab_size=len(train_data.vocab) |
177 |
| - ) |
| 152 | + args.device = torch.device(f'{"cuda"}' if args.cuda else "cpu") |
| 153 | + |
| 154 | + train_data, valid_data = get_synthetic_data(args) |
| 155 | + |
| 156 | + hparams = Hparams( |
| 157 | + kernel1=args.kernel1, |
| 158 | + kernel2=args.kernel2, |
| 159 | + kernel3=args.kernel3, |
| 160 | + embed_dim=args.embed_dim, |
| 161 | + n_filters=args.n_filters, |
| 162 | + ) |
178 | 163 |
|
179 | 164 | train_loader = DataLoader(train_data, batch_size=args.batch_size)
|
180 | 165 | valid_loader = DataLoader(valid_data, batch_size=args.batch_size)
|
|
0 commit comments