Skip to content
This repository was archived by the owner on Aug 30, 2023. It is now read-only.

Commit ce93f3a

Browse files
authored
Merge pull request #27 from mlgill/memory_mapped_data_loading
Memory mapped data loading
2 parents b7503d8 + 93a4a07 commit ce93f3a

File tree

5 files changed

+37
-8
lines changed

5 files changed

+37
-8
lines changed

prosit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import io
1+
from . import io_local
22
from . import constants
33
from . import model
44
from . import alignment

prosit/io_local.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from . import utils
2+
3+
4+
def get_array(tensor, keys):
    """Return the entries of `tensor` selected by `keys`, as a list.

    `utils.check_mandatory_keys(tensor, keys)` is invoked first
    (presumably it rejects missing keys -- confirm in `utils`).
    The result preserves the order of `keys`.
    """
    utils.check_mandatory_keys(tensor, keys)
    selected = []
    for name in keys:
        selected.append(tensor[name])
    return selected
7+
8+
9+
def to_hdf5(dictionary, path):
    """Write each item of `dictionary` as a dataset in an HDF5 file.

    The file at `path` is opened in "w" mode. Every key becomes a dataset
    name and every value is stored with its own `.dtype` and gzip
    compression, so values must expose a `dtype` attribute.
    """
    # Imported lazily so that h5py is only required when this is called.
    import h5py

    with h5py.File(path, "w") as out_file:
        for name, values in dictionary.items():
            out_file.create_dataset(
                name,
                data=values,
                dtype=values.dtype,
                compression="gzip",
            )
15+
16+
17+
def from_hdf5(path, n_samples=None):
    """Map every top-level key of an HDF5 file to a Keras `HDF5Matrix`.

    Each matrix is created with `start=0`, `end=n_samples` and no
    normalizer. `n_samples=None` is forwarded unchanged as the end bound
    (presumably meaning "up to the last row" -- confirm against the Keras
    `HDF5Matrix` documentation).

    Returns a dict keyed by the names found at the top level of the file.
    """
    from keras.utils import HDF5Matrix
    import h5py

    # Open the file only long enough to collect its top-level key names.
    with h5py.File(path, "r") as handle:
        names = list(handle.keys())

    # Build one HDF5Matrix per name, all pointing at the same file path.
    return {
        name: HDF5Matrix(path, name, start=0, end=n_samples, normalizer=None)
        for name in names
    }

prosit/prediction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import numpy as np
44

55
from . import model as model_lib
6-
from . import io
6+
from . import io_local
77
from . import constants
88
from . import sanitize
99

1010

1111
def predict(data, d_model):
1212
# check for mandatory keys
13-
x = io.get_array(data, d_model["config"]["x"])
13+
x = io_local.get_array(data, d_model["config"]["x"])
1414

1515
keras.backend.set_session(d_model["session"])
1616
with d_model["graph"].as_default():

prosit/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tensorflow as tf
88

99
from . import model
10-
from . import io
10+
from . import io_local
1111
from . import constants
1212
from . import tensorize
1313
from . import prediction

prosit/training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22

3-
from . import io
3+
from . import io_local
44
from . import losses
55
from . import model as model_lib
66
from . import constants
@@ -28,8 +28,8 @@ def train(tensor, model, model_config, callbacks):
2828
else:
2929
loss = losses.get(model_config["loss"])
3030
optimizer = model_config["optimizer"]
31-
x = io.get_array(tensor, model_config["x"])
32-
y = io.get_array(tensor, model_config["y"])
31+
x = io_local.get_array(tensor, model_config["x"])
32+
y = io_local.get_array(tensor, model_config["y"])
3333
model.compile(optimizer=optimizer, loss=loss)
3434
model.fit(
3535
x=x,
@@ -48,6 +48,6 @@ def train(tensor, model, model_config, callbacks):
4848
model_dir = constants.MODEL_DIR
4949

5050
model, model_config = model_lib.load(model_dir, trained=True)
51-
tensor = io.from_hdf5(data_path)
51+
tensor = io_local.from_hdf5(data_path)
5252
callbacks = get_callbacks(model_dir)
5353
train(tensor, model, model_config, callbacks)

0 commit comments

Comments
 (0)