Skip to content

Commit 305b96b

Browse files
Reduce gpu memory usage for sleap model...
1 parent 90e891d commit 305b96b

File tree

5 files changed

+44
-44
lines changed

5 files changed

+44
-44
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
This version of diplomat:
2-
- Enhanced skeletal transition curve that takes into account spread of body parts.
3-
- Support for converting DLC hdf5 and sleap hdf5 pose files to diplomat csvs, and passing them directly to the tweak and annotate commands.
1+
Reduces GPU memory usage when running inference on top of sleap.

diplomat/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
33
"""
44

5-
__version__ = "0.3.6"
5+
__version__ = "0.3.7"
66
# Can be used by functions to determine if diplomat was invoked through its CLI interface.
77
CLI_RUN = False
88

diplomat/frontends/deeplabcut/load_model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import functools
21
import tempfile
32
from io import BytesIO
43
from pathlib import Path
@@ -153,7 +152,6 @@ def _get_dlc_inputs_and_outputs(meta_path):
153152

154153
def _load_and_convert_model(model_dir: Path, device_index: Optional[int], use_cpu: bool):
155154
import tensorflow as tf
156-
from tensorflow.python.training import py_checkpoint_reader
157155
import tensorflow.compat.v1 as tf_v1
158156
import tf2onnx
159157
tf.compat.v1.disable_eager_execution()
@@ -166,7 +164,6 @@ def _load_and_convert_model(model_dir: Path, device_index: Optional[int], use_cp
166164
latest_meta_file = max(meta_files, key=lambda k: int(k.stem.split("-")[-1]))
167165

168166
inputs, outputs = _get_dlc_inputs_and_outputs(str(latest_meta_file))
169-
print(inputs, outputs)
170167

171168
graph_def, inputs, outputs = from_checkpoint(
172169
str(latest_meta_file), inputs, outputs

diplomat/frontends/sleap/run_utils.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
from .sleap_imports import tf, h5py
23
import json
34
import zipfile
@@ -127,24 +128,26 @@ def _load_configs_from_zip(z: zipfile.ZipFile, include_model = True):
127128

128129
@resolve_lazy_imports
129130
def _load_config_and_model(path, include_model = True):
130-
path = Path(path)
131-
if(zipfile.is_zipfile(path)):
132-
with zipfile.ZipFile(path, "r") as z:
133-
return _load_configs_from_zip(z, include_model)
134-
135-
if(path.is_dir()):
136-
path = path / "training_config.json"
137-
path = path.resolve()
138-
139-
with path.open("rb") as f:
140-
cfg = json.load(f)
141-
_correct_skeletons_in_config(cfg)
142-
model_path = _resolve_model_path(path.parent.iterdir())
143-
if(include_model):
144-
model = tf.keras.models.load_model(model_path, compile=False)
145-
return [(cfg, model)]
146-
else:
147-
return [cfg]
131+
device_ctx = tf.device('/cpu:0') if include_model else contextlib.nullcontext()
132+
with device_ctx:
133+
path = Path(path)
134+
if(zipfile.is_zipfile(path)):
135+
with zipfile.ZipFile(path, "r") as z:
136+
return _load_configs_from_zip(z, include_model)
137+
138+
if(path.is_dir()):
139+
path = path / "training_config.json"
140+
path = path.resolve()
141+
142+
with path.open("rb") as f:
143+
cfg = json.load(f)
144+
_correct_skeletons_in_config(cfg)
145+
model_path = _resolve_model_path(path.parent.iterdir())
146+
if(include_model):
147+
model = tf.keras.models.load_model(model_path, compile=False)
148+
return [(cfg, model)]
149+
else:
150+
return [cfg]
148151

149152

150153
def _load_configs(paths, include_models: bool = True):

diplomat/frontends/sleap/sleap_providers.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -239,31 +239,33 @@ def _reset_input_layer(
239239
Returns:
240240
A copy of `keras_model` with input shape `new_shape`.
241241
"""
242-
243-
if new_shape is None:
244-
new_shape = (None, None, None, keras_model.input_shape[-1])
245-
246-
model_config = keras_model.get_config()
247-
model_config["layers"][0]["config"]["batch_input_shape"] = new_shape
248-
new_model = tf.keras.Model.from_config(
249-
model_config, custom_objects={}
250-
) # Change custom objects if necessary
251-
252-
# Iterate over all the layers that we want to get weights from
253-
weights = [layer.get_weights() for layer in keras_model.layers]
254-
for layer, weight in zip(new_model.layers, weights):
255-
if len(weight) > 0:
256-
layer.set_weights(weight)
242+
with tf.device("/cpu:0"):
243+
if new_shape is None:
244+
new_shape = (None, None, None, keras_model.input_shape[-1])
245+
246+
model_config = keras_model.get_config()
247+
model_config["layers"][0]["config"]["batch_input_shape"] = new_shape
248+
new_model = tf.keras.Model.from_config(
249+
model_config, custom_objects={}
250+
) # Change custom objects if necessary
251+
252+
# Iterate over all the layers that we want to get weights from
253+
weights = [layer.get_weights() for layer in keras_model.layers]
254+
for layer, weight in zip(new_model.layers, weights):
255+
if len(weight) > 0:
256+
layer.set_weights(weight)
257257

258258
return new_model
259259

260260

261261
@resolve_lazy_imports
262262
def _keras_to_onnx_model(keras_model) -> onnx.ModelProto:
263-
input_signature = [
264-
tf.TensorSpec(keras_model.input_shape, tf.float32, name="image")
265-
]
266-
return tf2onnx.convert.from_keras(keras_model, input_signature, opset=17)[0]
263+
with tf.device("/cpu:0"):
264+
input_signature = [
265+
tf.TensorSpec(keras_model.input_shape, tf.float32, name="image")
266+
]
267+
mdl = tf2onnx.convert.from_keras(keras_model, input_signature, opset=17)[0]
268+
return mdl
267269

268270

269271
def _find_model_output(model: ort.InferenceSession, name: str, required: bool = True):

0 commit comments

Comments
 (0)