Skip to content

Commit 18285d6

Browse files
committed
Reuse df read from validator
1 parent 1784015 commit 18285d6

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

movement/io/load_bboxes.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def from_via_tracks_file(
342342

343343
# Create an xarray.Dataset from the data
344344
bboxes_arrays = _numpy_arrays_from_via_tracks_file(
345-
via_file.path, via_file.frame_regexp
345+
via_file.df, via_file.frame_regexp
346346
)
347347
ds = from_numpy(
348348
position_array=bboxes_arrays["position_array"],
@@ -369,9 +369,9 @@ def from_via_tracks_file(
369369

370370

371371
def _numpy_arrays_from_via_tracks_file(
372-
file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP
372+
df_in: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
373373
) -> dict:
374-
"""Extract numpy arrays from the input VIA tracks .csv file.
374+
"""Extract numpy arrays from VIA tracks dataframe.
375375
376376
The extracted numpy arrays are returned in a dictionary with the following
377377
keys:
@@ -390,8 +390,9 @@ def _numpy_arrays_from_via_tracks_file(
390390
391391
Parameters
392392
----------
393-
file_path : pathlib.Path
394-
Path to the VIA tracks .csv file containing the bounding box tracks.
393+
df_in : pd.DataFrame
394+
Input dataframe obtained from directly loading a valid
395+
VIA tracks .csv file as a pandas dataframe.
395396
396397
frame_regexp : str
397398
Regular expression pattern to extract the frame number from the frame
@@ -408,7 +409,7 @@ def _numpy_arrays_from_via_tracks_file(
408409
# Extract 2D dataframe from input data
409410
# (sort data by ID and frame number, and
410411
# fill empty frame-ID pairs with nans)
411-
df = _df_from_via_tracks_file(file_path, frame_regexp)
412+
df = _df_from_via_tracks_df(df_in, frame_regexp)
412413

413414
# Extract arrays
414415
n_individuals = df["ID"].nunique()
@@ -444,12 +445,14 @@ def _numpy_arrays_from_via_tracks_file(
444445
return array_dict
445446

446447

447-
def _df_from_via_tracks_file(
448-
file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP
448+
def _df_from_via_tracks_df(
449+
df_in: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
449450
) -> pd.DataFrame:
450-
"""Load VIA tracks .csv file as a dataframe.
451+
"""Extract dataframe from VIA tracks dataframe.
451452
452-
Read the VIA tracks .csv file as a pandas dataframe with columns:
453+
The VIA tracks dataframe is obtained from directly loading a valid
454+
VIA tracks .csv file as a pandas dataframe. The output dataframe contains
455+
the following columns:
453456
- ID: the integer ID of the tracked bounding box.
454457
- frame_number: the frame number of the tracked bounding box.
455458
- x: the x-coordinate of the tracked bounding box's top-left corner.
@@ -471,7 +474,7 @@ def _df_from_via_tracks_file(
471474
logger.info(
472475
"Parsing dataframe (this may take a few minutes for large files)..."
473476
)
474-
df = _parsed_df_from_file(file_path, frame_regexp)
477+
df = _parsed_df_from_via_tracks_df(df_in, frame_regexp)
475478
logger.info("Parsing complete.")
476479

477480
# Fill in missing combinations of ID and
@@ -481,21 +484,20 @@ def _df_from_via_tracks_file(
481484
return df
482485

483486

484-
def _parsed_df_from_file(
485-
file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP
487+
def _parsed_df_from_via_tracks_df(
488+
df: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
486489
) -> pd.DataFrame:
487-
"""Compute parsed dataframe from input VIA tracks .csv file.
490+
"""Parse VIA tracks dataframe.
488491
489-
Parses dictionary-like string columns in input file, and casts
492+
Parses dictionary-like string columns in VIA tracks dataframe, and casts
490493
columns to the expected types. It returns a copy of the relevant subset
491-
of columns. Note that this function should run after validation of the
492-
input file with ValidVIATracksCSV.
494+
of columns.
493495
494496
Parameters
495497
----------
496-
file_path : pathlib.Path
497-
Path to the valid VIA tracks .csv file containing the bounding box
498-
tracks.
498+
df : pd.DataFrame
499+
Input dataframe obtained from directly loading a valid
500+
VIA tracks .csv file as a pandas dataframe.
499501
500502
frame_regexp : str, optional
501503
The regular expression to extract the frame number from the filename.
@@ -515,7 +517,7 @@ def _parsed_df_from_file(
515517
516518
"""
517519
# Read VIA tracks .csv file as a pandas dataframe
518-
df = pd.read_csv(file_path, sep=",", header=0)
520+
# df = pd.read_csv(file_path, sep=",", header=0)
519521

520522
# Loop thru rows of columns with dict-like data
521523
# (this is typically faster than iterrows())

movement/validators/files.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ class ValidVIATracksCSV:
390390

391391
path: Path = field(validator=validators.instance_of(Path))
392392
frame_regexp: str = DEFAULT_FRAME_REGEXP
393+
df: pd.DataFrame = field(init=False, factory=pd.DataFrame)
393394

394395
@path.validator
395396
def _file_contains_valid_header(self, attribute, value):
@@ -416,6 +417,9 @@ def _file_contains_valid_header(self, attribute, value):
416417
)
417418
)
418419

420+
# Read CSV once and store for later use
421+
self.df = pd.read_csv(value, sep=",", header=0)
422+
419423
@path.validator
420424
def _file_contains_valid_frame_numbers(self, attribute, value):
421425
"""Ensure that the VIA tracks .csv file contains valid frame numbers.
@@ -435,7 +439,7 @@ def _file_contains_valid_frame_numbers(self, attribute, value):
435439
file extension.
436440
437441
"""
438-
df = pd.read_csv(value, sep=",", header=0)
442+
df = self.df
439443

440444
# Extract list of file attributes (dicts)
441445
file_attributes_dicts = [json.loads(d) for d in df.file_attributes]
@@ -533,7 +537,7 @@ def _file_contains_tracked_bboxes(self, attribute, value):
533537
- Checking that the bounding boxes have a track ID defined.
534538
- Checking that the track ID can be cast as an integer.
535539
"""
536-
df = pd.read_csv(value, sep=",", header=0)
540+
df = self.df
537541

538542
for row in df.itertuples():
539543
row_region_shape_attrs = json.loads(row.region_shape_attributes)
@@ -596,7 +600,7 @@ def _file_contains_unique_track_ids_per_filename(self, attribute, value):
596600
597601
It checks that bounding boxes IDs are defined once per image file.
598602
"""
599-
df = pd.read_csv(value, sep=",", header=0)
603+
df = self.df
600604

601605
list_unique_filenames = list(set(df.filename))
602606
for file in list_unique_filenames:

0 commit comments

Comments
 (0)