diff --git a/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py b/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py index b225d17c6..15a2e22cd 100644 --- a/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py +++ b/examples/DynaCLR/DynaCLR-classical-sampling/create_pseudo_tracks.py @@ -9,14 +9,12 @@ # %% create training and validation dataset # TODO: Modify path to the input data -input_data_path = "/training_data.zarr" -# TODO: Modify path to the output data -track_data_path = "/training_data_tracks.zarr" - +input_track_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr" +output_track_path = "/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/classical/data/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_classical_fake_tracks.zarr" # TODO: Modify the channel name to the one you are using for the segmentation mask -segmentation_channel_name = "Nucl_mask" +segmentation_channel_name = "nuclei_prediction_labels_labels" # TODO: Modify the z-slice to the one you are using for the segmentation mask -Z_SLICE = 30 +Z_SLICE = 0 # %% """ Add csvs with fake tracking to tracking data. @@ -67,14 +65,13 @@ def save_track_df(track_df, well_id, pos_name, out_path): def main(): # Load the input segmentation data zarr_input = open_ome_zarr( - input_data_path, - layout="hcs", - mode="r+", + input_track_path, + mode="r", ) chan_names = zarr_input.channel_names - assert ( - segmentation_channel_name in chan_names - ), "Channel name not found in the input data" + assert segmentation_channel_name in chan_names, ( + "Channel name not found in the input data" + ) # Create the empty store for the tracking data position_names = [] @@ -82,23 +79,19 @@ def main(): position_names.append(tuple(ds.split("/"))) create_empty_plate( - store_path=track_data_path, + store_path=output_track_path, position_keys=position_names, - channel_names=segmentation_channel_name, + channel_names=[segmentation_channel_name], shape=(1, 1, 1, *position.data.shape[3:]), chunks=position.data.chunks, scale=position.scale, ) - + # # Populate the tracking data - with open_ome_zarr(track_data_path, layout="hcs", mode="r+") as track_store: + with open_ome_zarr(output_track_path, layout="hcs", mode="r+") as track_store: # Create progress bar for wells and positions for well_id, well_data in tqdm(zarr_input.wells(), desc="Processing wells"): - for pos_name, pos_data in tqdm( - well_data.positions(), - desc=f"Processing positions in {well_id}", - leave=False, - ): + for pos_name, pos_data in well_data.positions(): data = pos_data.data T, C, Z, Y, X = data.shape track_df_all = pd.DataFrame() @@ -110,7 +103,7 @@ def main(): track_pos["0"][0, 0, 0] = seg_mask track_df = create_track_df(seg_mask, time) track_df_all = pd.concat([track_df_all, track_df]) - save_track_df(track_df_all, well_id, pos_name, track_data_path) + save_track_df(track_df_all, well_id, pos_name, output_track_path) zarr_input.close()