Skip to content

Commit 5a5709d

Browse files
committed
Write machine labels when extracting annotated frames
1 parent 084437b commit 5a5709d

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

src/napari_deeplabcut/_widgets.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from datetime import datetime
44
from functools import partial
55
from math import ceil, log10
6+
import pandas as pd
7+
from pathlib import Path
68
from types import MethodType
79
from typing import Optional, Sequence, Union
810

@@ -36,8 +38,12 @@
3638

3739
from napari_deeplabcut import keypoints
3840
from napari_deeplabcut._reader import _load_config
39-
from napari_deeplabcut._writer import _write_config, _write_image
40-
from napari_deeplabcut.misc import encode_categories, to_os_dir_sep
41+
from napari_deeplabcut._writer import _write_config, _write_image, _form_df
42+
from napari_deeplabcut.misc import (
43+
encode_categories,
44+
to_os_dir_sep,
45+
guarantee_multiindex_rows,
46+
)
4147

4248

4349
def _get_and_try_preferred_reader(
@@ -257,19 +263,41 @@ def _form_video_action_menu(self):
257263
return extract_button, crop_button
258264

259265
def _extract_single_frame(self, *args):
260-
layer = None
261-
for layer_ in self.viewer.layers:
262-
if isinstance(layer_, Image):
263-
layer = layer_
264-
break
265-
if layer is not None:
266+
image_layer = None
267+
points_layer = None
268+
for layer in self.viewer.layers:
269+
if isinstance(layer, Image):
270+
image_layer = layer
271+
elif isinstance(layer, Points):
272+
points_layer = layer
273+
if image_layer is not None:
266274
ind = self.viewer.dims.current_step[0]
267-
frame = layer.data[ind]
268-
n_frames = layer.data.shape[0]
275+
frame = image_layer.data[ind]
276+
n_frames = image_layer.data.shape[0]
269277
name = f"img{str(ind).zfill(int(ceil(log10(n_frames))))}.png"
270-
output_path = os.path.join(layer.metadata["root"], name)
278+
output_path = os.path.join(image_layer.metadata["root"], name)
271279
_write_image(frame, str(output_path))
272280

281+
# If annotations were loaded, they should be written to a machinefile.h5 file
282+
if points_layer is not None:
283+
df = _form_df(
284+
points_layer.data,
285+
{
286+
"metadata": points_layer.metadata,
287+
"properties": points_layer.properties,
288+
},
289+
)
290+
df = df.iloc[ind:ind + 1]
291+
df.index = pd.MultiIndex.from_tuples([Path(output_path).parts[-3:]])
292+
filepath = os.path.join(image_layer.metadata["root"], "machinelabels-iter0.h5")
293+
if Path(filepath).is_file():
294+
df_prev = pd.read_hdf(filepath)
295+
guarantee_multiindex_rows(df_prev)
296+
df = pd.concat([df_prev, df])
297+
df = df[~df.index.duplicated(keep="first")]
298+
df.to_hdf(filepath, key="machinelabels")
299+
300+
273301
def _store_crop_coordinates(self, *args):
274302
if not (project_path := self._images_meta.get("project")):
275303
return

src/napari_deeplabcut/_writer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@ def _write_config(config_path: str, params: dict):
1717
yaml.safe_dump(params, file)
1818

1919

20-
def write_hdf(filename, data, metadata):
21-
file, _ = os.path.splitext(filename) # FIXME Unused currently
22-
temp = pd.DataFrame(data[:, -1:0:-1], columns=["x", "y"])
20+
def _form_df(points_data, metadata):
21+
temp = pd.DataFrame(points_data[:, -1:0:-1], columns=["x", "y"])
2322
properties = metadata["properties"]
2423
meta = metadata["metadata"]
2524
temp["bodyparts"] = properties["label"]
2625
temp["individuals"] = properties["id"]
27-
temp["inds"] = data[:, 0].astype(int)
26+
temp["inds"] = points_data[:, 0].astype(int)
2827
temp["likelihood"] = properties["likelihood"]
2928
temp["scorer"] = meta["header"].scorer
3029
df = temp.set_index(["scorer", "individuals", "bodyparts", "inds"]).stack()
@@ -40,7 +39,13 @@ def write_hdf(filename, data, metadata):
4039
if meta["paths"]:
4140
df.index = [meta["paths"][i] for i in df.index]
4241
misc.guarantee_multiindex_rows(df)
42+
return df
43+
4344

45+
def write_hdf(filename, data, metadata):
46+
file, _ = os.path.splitext(filename) # FIXME Unused currently
47+
df = _form_df(data, metadata)
48+
meta = metadata["metadata"]
4449
name = metadata["name"]
4550
root = meta["root"]
4651
if "machine" in name: # We are attempting to save refined model predictions

0 commit comments

Comments
 (0)