Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@

# %% create training and validation dataset
# TODO: Modify path to the input data
input_data_path = "/training_data.zarr"
# TODO: Modify path to the output data
track_data_path = "/training_data_tracks.zarr"

input_track_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr"
output_track_path = "/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/classical/data/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_classical_fake_tracks.zarr"
# TODO: Modify the channel name to the one you are using for the segmentation mask
segmentation_channel_name = "Nucl_mask"
segmentation_channel_name = "nuclei_prediction_labels_labels"
# TODO: Modify the z-slice to the one you are using for the segmentation mask
Z_SLICE = 30
Z_SLICE = 0
# %%
"""
Add csvs with fake tracking to tracking data.
Expand Down Expand Up @@ -67,38 +65,33 @@ def save_track_df(track_df, well_id, pos_name, out_path):
def main():
# Load the input segmentation data
zarr_input = open_ome_zarr(
input_data_path,
layout="hcs",
mode="r+",
input_track_path,
mode="r",
)
chan_names = zarr_input.channel_names
assert (
segmentation_channel_name in chan_names
), "Channel name not found in the input data"
assert segmentation_channel_name in chan_names, (
"Channel name not found in the input data"
)

# Create the empty store for the tracking data
position_names = []
for ds, position in zarr_input.positions():
position_names.append(tuple(ds.split("/")))

create_empty_plate(
store_path=track_data_path,
store_path=output_track_path,
position_keys=position_names,
channel_names=segmentation_channel_name,
channel_names=[segmentation_channel_name],
shape=(1, 1, 1, *position.data.shape[3:]),
chunks=position.data.chunks,
scale=position.scale,
)

#
# Populate the tracking data
with open_ome_zarr(track_data_path, layout="hcs", mode="r+") as track_store:
with open_ome_zarr(output_track_path, layout="hcs", mode="r+") as track_store:
# Create progress bar for wells and positions
for well_id, well_data in tqdm(zarr_input.wells(), desc="Processing wells"):
for pos_name, pos_data in tqdm(
well_data.positions(),
desc=f"Processing positions in {well_id}",
leave=False,
):
for pos_name, pos_data in well_data.positions():
data = pos_data.data
T, C, Z, Y, X = data.shape
track_df_all = pd.DataFrame()
Expand All @@ -110,7 +103,7 @@ def main():
track_pos["0"][0, 0, 0] = seg_mask
track_df = create_track_df(seg_mask, time)
track_df_all = pd.concat([track_df_all, track_df])
save_track_df(track_df_all, well_id, pos_name, track_data_path)
save_track_df(track_df_all, well_id, pos_name, output_track_path)
zarr_input.close()


Expand Down