
Commit f086f05

committed
Improvements to validation and reduce duplication with loading
1 parent 18285d6 commit f086f05

File tree

3 files changed: +171 -300 lines changed


movement/io/load_bboxes.py

Lines changed: 31 additions & 157 deletions
@@ -1,6 +1,5 @@
 """Load bounding boxes tracking data into ``movement``."""
 
-import json
 from pathlib import Path
 from typing import Literal
 
@@ -341,9 +340,7 @@ def from_via_tracks_file(
     logger.info(f"Validated VIA tracks .csv file {via_file.path}.")
 
     # Create an xarray.Dataset from the data
-    bboxes_arrays = _numpy_arrays_from_via_tracks_file(
-        via_file.df, via_file.frame_regexp
-    )
+    bboxes_arrays = _numpy_arrays_from_valid_file_object(via_file)
     ds = from_numpy(
         position_array=bboxes_arrays["position_array"],
         shape_array=bboxes_arrays["shape_array"],
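
The public entry point from_via_tracks_file is unchanged by this refactor; only the helper it delegates to differs. A minimal usage sketch, assuming the path is passed as the first positional argument; "via_tracks.csv" is a placeholder, not a file from the repository:

    # Hedged sketch of calling the public loader touched in the hunk above.
    from movement.io import load_bboxes

    # Returns an xarray.Dataset assembled via from_numpy from the extracted arrays.
    ds = load_bboxes.from_via_tracks_file("via_tracks.csv")
    print(ds)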
@@ -368,10 +365,10 @@ def from_via_tracks_file(
     return ds
 
 
-def _numpy_arrays_from_via_tracks_file(
-    df_in: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
+def _numpy_arrays_from_valid_file_object(
+    valid_via_file: ValidVIATracksCSV,
 ) -> dict:
-    """Extract numpy arrays from VIA tracks dataframe.
+    """Extract numpy arrays from VIA tracks file object.
 
     The extracted numpy arrays are returned in a dictionary with the following
     keys:
@@ -390,15 +387,8 @@ def _numpy_arrays_from_via_tracks_file(
 
     Parameters
     ----------
-    df_in : pd.DataFrame
-        Input dataframe obtained from directly loading a valid
-        VIA tracks .csv file as a pandas dataframe.
-
-    frame_regexp : str
-        Regular expression pattern to extract the frame number from the frame
-        filename. By default, the frame number is expected to be encoded in
-        the filename as an integer number led by at least one zero, followed
-        by the file extension.
+    valid_via_file : ValidVIATracksCSV
+        A validated VIA tracks file object.
 
     Returns
     -------
@@ -409,9 +399,9 @@ def _numpy_arrays_from_via_tracks_file(
     # Extract 2D dataframe from input data
     # (sort data by ID and frame number, and
     # fill empty frame-ID pairs with nans)
-    df = _df_from_via_tracks_df(df_in, frame_regexp)
+    df = _parsed_df_from_valid_file_object(valid_via_file)
 
-    # Extract arrays
+    # Extract numpy arrays
     n_individuals = df["ID"].nunique()
     n_frames = df["frame_number"].nunique()
     all_data = df[["x", "y", "w", "h", "confidence"]].to_numpy()
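
The hunk above shows the helper counting individuals and frames before turning the sorted long-format table into per-variable arrays. A self-contained sketch of that reshape idea on toy data (an illustration of the pattern, not the library's exact code):

    import numpy as np
    import pandas as pd

    # Long-format table already sorted by ID, then frame_number.
    df = pd.DataFrame(
        {
            "ID": [1, 1, 2, 2],
            "frame_number": [0, 1, 0, 1],
            "x": [10.0, 12.0, 50.0, np.nan],
            "y": [20.0, 21.0, 60.0, np.nan],
        }
    )
    n_individuals = df["ID"].nunique()
    n_frames = df["frame_number"].nunique()

    # Stack the coordinate columns, reshape to (individuals, frames, 2),
    # then transpose so frames vary along the first axis.
    xy = df[["x", "y"]].to_numpy().reshape(n_individuals, n_frames, 2)
    position_array = xy.transpose(1, 0, 2)
    print(position_array.shape)  # (2, 2, 2): (n_frames, n_individuals, space)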
@@ -445,160 +435,44 @@ def _numpy_arrays_from_via_tracks_file(
     return array_dict
 
 
-def _df_from_via_tracks_df(
-    df_in: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
+def _parsed_df_from_valid_file_object(
+    valid_via_file: ValidVIATracksCSV,
 ) -> pd.DataFrame:
-    """Extract dataframe from VIA tracks dataframe.
-
-    The VIA tracks dataframe is obtained from directly loading a valid
-    VIA tracks .csv file as a pandas dataframe. The output dataframe contains
-    the following columns:
-    - ID: the integer ID of the tracked bounding box.
-    - frame_number: the frame number of the tracked bounding box.
-    - x: the x-coordinate of the tracked bounding box's top-left corner.
-    - y: the y-coordinate of the tracked bounding box's top-left corner.
-    - w: the width of the tracked bounding box.
-    - h: the height of the tracked bounding box.
-    - confidence: the confidence score of the tracked bounding box.
-
-    The dataframe is sorted by ID and frame number, and for each ID,
-    empty frames are filled in with NaNs. The coordinates of the bboxes
-    are assumed to be in the image coordinate system (i.e., the top-left
-    corner of a bbox is its corner with minimum x and y coordinates).
-
-    The frame number is extracted from the filename using the provided
-    regexp if it is not defined as a 'file_attribute' in the VIA tracks .csv
-    file.
-    """
-    # Parse input dataframe
-    logger.info(
-        "Parsing dataframe (this may take a few minutes for large files)..."
-    )
-    df = _parsed_df_from_via_tracks_df(df_in, frame_regexp)
-    logger.info("Parsing complete.")
+    """Build a sorted DataFrame from a validated VIA file object.
 
-    # Fill in missing combinations of ID and
-    # frame number if required
-    df = _fill_in_missing_rows(df)
-
-    return df
-
-
-def _parsed_df_from_via_tracks_df(
-    df: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
-) -> pd.DataFrame:
-    """Parse VIA tracks dataframe.
-
-    Parses dictionary-like string columns in VIA tracks dataframe, and casts
-    columns to the expected types. It returns a copy of the relevant subset
-    of columns.
+    Creates a DataFrame with ID, frame_number, x, y, w, h, and confidence
+    columns. Missing (ID, frame_number) combinations are filled with NaNs,
+    and the result is sorted by ID and frame_number.
 
     Parameters
     ----------
-    df : pd.DataFrame
-        Input dataframe obtained from directly loading a valid
-        VIA tracks .csv file as a pandas dataframe.
-
-    frame_regexp : str, optional
-        The regular expression to extract the frame number from the filename.
+    valid_via_file : ValidVIATracksCSV
+        A validated VIA tracks file object.
 
     Returns
     -------
     pd.DataFrame
-        The parsed dataframe with the following columns:
-        - ID: the integer ID of the tracked bounding box.
-        - frame_number: the frame number of the tracked bounding box.
-        - x: the x-coordinate of the tracked bounding box's top-left corner.
-        - y: the y-coordinate of the tracked bounding box's top-left corner.
-        - w: the width of the tracked bounding box.
-        - h: the height of the tracked bounding box.
-        - confidence: the confidence score of the tracked bounding box, filled
-          with NaN where not defined.
+        Sorted DataFrame with all ID/frame combinations.
 
     """
-    # Read VIA tracks .csv file as a pandas dataframe
-    # df = pd.read_csv(file_path, sep=",", header=0)
-
-    # Loop thru rows of columns with dict-like data
-    # (this is typically faster than iterrows())
-    region_shapes = df["region_shape_attributes"].tolist()
-    region_attrs = df["region_attributes"].tolist()
-    file_attrs = df["file_attributes"].tolist()
-
-    # Initialize lists for results
-    x, y, w, h, ids, confidences, frame_numbers = [], [], [], [], [], [], []
-    for rs, ra, fa in zip(
-        region_shapes, region_attrs, file_attrs, strict=True
-    ):
-        # Regions shape
-        region_shape_dict = json.loads(rs)
-        x.append(region_shape_dict.get("x"))
-        y.append(region_shape_dict.get("y"))
-        w.append(region_shape_dict.get("width"))
-        h.append(region_shape_dict.get("height"))
-
-        # Region attributes
-        region_attrs_dict = json.loads(ra)
-        ids.append(region_attrs_dict.get("track"))
-        confidences.append(region_attrs_dict.get("confidence", np.nan))
-
-        # File attributes
-        file_attrs_dict = json.loads(fa)
-        # assigns None if not defined under "frame"
-        frame_numbers.append(file_attrs_dict.get("frame"))
-
-    # Assign lists to dataframe
-    df["x"], df["y"], df["w"], df["h"] = x, y, w, h
-    df["ID"], df["confidence"] = ids, confidences
-
-    # After loop, handle frame_number for entire column at once
-    if None not in frame_numbers:
-        df["frame_number"] = frame_numbers
-    else:
-        df["frame_number"] = df["filename"].str.extract(
-            frame_regexp, expand=False
-        )
-
-    # Remove string columns to free memory
-    df = df.drop(
-        columns=[
-            "region_shape_attributes",
-            "region_attributes",
-            "file_attributes",
-        ]
+    # Build dataframe from file validator object data, then sort and reindex
+    df = pd.DataFrame(
+        {
+            "ID": valid_via_file.ids,
+            "frame_number": valid_via_file.frame_numbers,
+            "x": np.array(valid_via_file.x, dtype=np.float32),
+            "y": np.array(valid_via_file.y, dtype=np.float32),
+            "w": np.array(valid_via_file.w, dtype=np.float32),
+            "h": np.array(valid_via_file.h, dtype=np.float32),
+            "confidence": np.array(
+                valid_via_file.confidence_values, dtype=np.float32
+            ),
+        }
     )
 
-    # Apply type conversions
-    df["ID"] = df["ID"].astype(int)
-    df["frame_number"] = df["frame_number"].astype(int)
-    df[["x", "y", "w", "h", "confidence"]] = df[
-        ["x", "y", "w", "h", "confidence"]
-    ].astype(np.float32)
-
-    # Return relevant subset of columns as copy
-    return df[["ID", "frame_number", "x", "y", "w", "h", "confidence"]].copy()
-
-
-def _fill_in_missing_rows(df: pd.DataFrame) -> pd.DataFrame:
-    """Add rows for missing (ID, frame_number) combinations and fill with NaNs.
-
-    Parameters
-    ----------
-    df : pd.DataFrame
-        The dataframe to fill in missing rows in.
-
-    Returns
-    -------
-    pd.DataFrame
-        The dataframe with rows for previously missing (ID, frame_number)
-        combinations added and filled in with NaNs. The dataframe is sorted
-        by ID and frame number.
-
-    """
-    # Fill in missing rows if required
     # If every ID is defined for every frame:
     # just sort and reindex (does not add rows)
-    if len(df) == len(df["ID"].unique()) * len(df["frame_number"].unique()):
+    if len(df) == df["ID"].nunique() * df["frame_number"].nunique():
         df = df.sort_values(
             by=["ID", "frame_number"],
             axis=0,

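The new docstring promises that missing (ID, frame_number) combinations are filled with NaNs and the result is sorted. The branch handling missing rows is cut off in the truncated hunk above, so here is a self-contained pandas sketch of that general pattern on toy data (an illustration of the idea, not the function's actual code):

    import pandas as pd

    # Individual 2 has no row for frame 1.
    df = pd.DataFrame(
        {
            "ID": [1, 1, 2],
            "frame_number": [0, 1, 0],
            "x": [10.0, 12.0, 50.0],
        }
    )

    # Full product of all IDs and all frame numbers.
    full_index = pd.MultiIndex.from_product(
        [sorted(df["ID"].unique()), sorted(df["frame_number"].unique())],
        names=["ID", "frame_number"],
    )

    # Reindexing onto the full product adds NaN rows for missing pairs;
    # the result is then sorted by ID and frame_number.
    filled = (
        df.set_index(["ID", "frame_number"])
        .reindex(full_index)
        .reset_index()
        .sort_values(by=["ID", "frame_number"])
    )
    print(filled)
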
Comments (0)