-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcnn_predictor.py
More file actions
33 lines (24 loc) · 1.09 KB
/
cnn_predictor.py
File metadata and controls
33 lines (24 loc) · 1.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from pathlib import Path
from typing import Tuple, Union

import numpy as np
import pandas as pd

from models.cnn_model import CnnModel
from preprocessing.preprocess_data import features
from preprocessing.utils import label_class_correspondence, scale_data
class CnnPredictor:
    """Given csv or xlsx file of data from LHCb, predicts the particle type."""

    # Output columns, in the order they are written to the predictions file.
    _CLASS_NAMES = ('Ghost', 'Electron', 'Muon', 'Pion', 'Kaon', 'Proton')

    def __init__(self):
        """Build the CNN model and load its weights."""
        self.model = CnnModel()
        self.model.load_weights()

    def predict(self, filename: Union[str, Path],
                output_path: str = 'predictions.csv.gz'):
        """Predict on the data file and write the result as a gzipped csv.

        Args:
            filename: Path to a ``.csv`` (optionally compressed, e.g.
                ``.csv.gz``) or ``.xlsx`` file. The format is chosen from the
                suffixes of the file name only, so a directory that happens
                to contain "csv" in its name no longer selects the csv reader.
            output_path: Where the gzip-compressed predictions csv is
                written. Defaults to the historical ``predictions.csv.gz``.

        Returns:
            An ``IPython.display.FileLink`` pointing at ``output_path`` when
            IPython is installed, otherwise ``output_path`` itself.

        Raises:
            ValueError: If ``filename`` is neither a csv nor an xlsx file.
        """
        suffixes = {suffix.lower() for suffix in Path(filename).suffixes}
        if '.csv' in suffixes:
            data = pd.read_csv(filename)
        elif '.xlsx' in suffixes:
            # pandas has no ``read_xlsx``; Excel files go through read_excel.
            data = pd.read_excel(filename)
        else:
            raise ValueError(
                f"Unsupported file type for {str(filename)!r}: "
                "expected a .csv or .xlsx file"
            )

        # NOTE(review): the original code used an undefined ``ids`` name.
        # This assumes the input carries an 'ID' column (falling back to the
        # row index otherwise) -- confirm against the real data schema.
        # Captured before scaling in case scaling drops non-feature columns.
        ids = np.asarray(data['ID'] if 'ID' in data.columns else data.index)

        # NOTE(review): the original called an undefined ``scale``; the
        # imported helper is ``scale_data`` and is assumed to take the same
        # arguments and return a DataFrame -- confirm.
        scaled = scale_data(data, features).values

        # The model consumes one (49, 1) block per event; the second value
        # returned by predict_mlp is not needed here.
        pred, _ = self.model.predict_mlp(scaled.reshape(-1, 49, 1))

        prediction = pd.DataFrame({'ID': ids})
        for name in self._CLASS_NAMES:
            prediction[name] = pred[:, label_class_correspondence[name]]

        prediction.to_csv(output_path, index=False, float_format='%.5f',
                          compression='gzip')

        # FileLink is only useful inside a notebook; IPython is treated as an
        # optional dependency so the predictor also works in plain scripts.
        try:
            from IPython.display import FileLink
        except ImportError:
            return output_path
        return FileLink(output_path)