|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import json |
| 4 | +from glob import glob |
| 5 | + |
| 6 | +from sklearn.model_selection import train_test_split |
| 7 | + |
| 8 | +sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge") |
| 9 | +sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge") |
| 10 | +sys.path.append("../synapse_marker_detection") |
| 11 | + |
| 12 | +from utils.training.training import supervised_training # noqa |
| 13 | +from detection_dataset import DetectionDataset, MinPointSampler # noqa |
| 14 | + |
# Root folder of the dataset, given relative to the current working directory.
ROOT = "./la-vision-sgn-new"  # noqa

# Image folders; they are searched for "*.tif" files by `_get_paths`.
TRAIN = os.path.join(ROOT, "images")
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")

# Label folders; they are searched for "*.csv" files by `_get_paths`.
LABEL = os.path.join(ROOT, "centroids")
LABEL_EMPTY = os.path.join(ROOT, "empty_centroids")
| 22 | + |
| 23 | + |
def _get_paths(split, train_folder, label_folder, n=None):
    """Collect the image/label file pairs of one folder pair for a data split.

    Images ("*.tif") and labels ("*.csv") are globbed, sorted by path and
    paired purely by their position in the sorted lists. One pair is held out
    for validation (`test_size=1`); the fixed `random_state` keeps the split
    reproducible between calls, so the "train" and "val" results of the same
    folders do not overlap.

    Args:
        split: "train" selects the training pairs; any other value selects
            the held-out validation pair.
        train_folder: Folder that is searched for "*.tif" images.
        label_folder: Folder that is searched for "*.csv" label files.
        n: If not None, only the first `n` sorted pairs are considered.

    Returns:
        Tuple `(image_paths, label_paths)` for the requested split.

    Raises:
        ValueError: If the two folders contain a different number of files.
    """
    image_paths = sorted(glob(os.path.join(train_folder, "*.tif")))
    label_paths = sorted(glob(os.path.join(label_folder, "*.csv")))

    # Validate explicitly instead of with `assert`: assertions are removed
    # under `python -O`, which would let mismatched image/label lists through
    # and silently pair images with the wrong labels.
    if len(image_paths) != len(label_paths):
        raise ValueError(
            f"Found {len(image_paths)} images in {train_folder!r} but "
            f"{len(label_paths)} label files in {label_folder!r}; "
            "every image needs exactly one label file."
        )

    if n is not None:
        image_paths, label_paths = image_paths[:n], label_paths[:n]

    train_images, val_images, train_labels, val_labels = train_test_split(
        image_paths, label_paths, test_size=1, random_state=42
    )

    if split == "train":
        return train_images, train_labels
    return val_images, val_labels
| 43 | + |
| 44 | + |
def get_paths(split):
    """Return the image and label paths used for the given split.

    The pairs from the regular folders (TRAIN / LABEL) come first, followed by
    the pairs from the "empty" folders (TRAIN_EMPTY / LABEL_EMPTY), of which
    only the first four sorted pairs are taken into account.

    Args:
        split: Name of the split, forwarded unchanged to `_get_paths`.

    Returns:
        Tuple `(image_paths, label_paths)`.
    """
    regular_images, regular_labels = _get_paths(split, TRAIN, LABEL)
    extra_images, extra_labels = _get_paths(split, TRAIN_EMPTY, LABEL_EMPTY, n=4)

    all_images = regular_images + extra_images
    all_labels = regular_labels + extra_labels
    return all_images, all_labels
| 49 | + |
| 50 | + |
def train():
    """Train the "sgn-low-res-detection-v1" detection model.

    Collects the train and validation paths, prints how many files each split
    contains, stores the split in "./checkpoints/<model name>/splits.json" and
    then hands everything over to `supervised_training`.
    """
    model_name = "sgn-low-res-detection-v1"

    train_paths, train_label_paths = get_paths("train")
    val_paths, val_label_paths = get_paths("val")

    print("Start training with:")
    print(len(train_paths), "tomograms for training")
    print(len(val_paths), "tomograms for validation")

    # Store the exact split next to the checkpoints so it can be looked up later.
    checkpoint_path = f"./checkpoints/{model_name}"
    os.makedirs(checkpoint_path, exist_ok=True)
    split_record = {
        "train": {"images": train_paths, "labels": train_label_paths},
        "val": {"images": val_paths, "labels": val_label_paths},
    }
    with open(os.path.join(checkpoint_path, "splits.json"), "w") as f:
        json.dump(split_record, f, indent=2, sort_keys=True)

    path_kwargs = {
        "train_paths": train_paths,
        "train_label_paths": train_label_paths,
        "val_paths": val_paths,
        "val_label_paths": val_label_paths,
        # We need to give the paths for the test loader, although it's never
        # used, so the validation paths are simply passed a second time.
        "test_paths": val_paths,
        "test_label_paths": val_label_paths,
    }

    # Note: `save_root` is deliberately not passed to `supervised_training`.
    supervised_training(
        name=model_name,
        raw_key=None,
        patch_shape=[48, 256, 256],
        batch_size=8,
        check=False,
        lr=1e-4,
        n_iterations=int(1e5),
        out_channels=1,
        augmentations=None,
        eps=1e-5,
        sigma=4,
        lower_bound=None,
        upper_bound=None,
        dataset_class=DetectionDataset,
        n_samples_train=3200,
        n_samples_val=160,
        sampler=MinPointSampler(min_points=1, p_reject=0.5),
        **path_kwargs,
    )
| 104 | + |
| 105 | + |
def main():
    """Entry point of the script: run the training."""
    train()


# Only start the training when the file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments