|
1 | 1 | """Several image processing utilities.""" |
| 2 | + |
2 | 3 | from pathlib import Path |
3 | 4 | from warnings import warn |
4 | 5 |
|
@@ -837,3 +838,137 @@ def _start(self): |
837 | 838 | ) |
838 | 839 | else: |
839 | 840 | logger.warning("Please specify a layer or a folder") |
| 841 | + |
| 842 | + |
| 843 | +class ThresholdGridSearchUtils(BasePluginUtils): |
| 844 | + """Widget to run a grid search for thresholding.""" |
| 845 | + |
| 846 | + save_path = Path.home() / "cellseg3d" / "threshold_grid_search" |
| 847 | + |
| 848 | + def __init__(self, viewer: "napari.viewer.Viewer", parent=None): |
| 849 | + """Creates a ThresholdGridSearchUtils widget. |
| 850 | +
|
| 851 | + Args: |
| 852 | + viewer: viewer in which to process data |
| 853 | + parent: parent widget |
| 854 | + """ |
| 855 | + super().__init__( |
| 856 | + viewer, |
| 857 | + parent=parent, |
| 858 | + ) |
| 859 | + self.do_binarize = False |
| 860 | + self.result_text = "" |
| 861 | + self.values = {} |
| 862 | + |
| 863 | + self.data_panel = self._build_io_panel() |
| 864 | + # disable folder choice |
| 865 | + self.radio_buttons.setVisible(False) |
| 866 | + self.radio_buttons.setEnabled(False) |
| 867 | + |
| 868 | + self.image_layer_loader.layer_list.label.setText("Prediction :") |
| 869 | + self.label_layer_loader.layer_list.label.setText("Labels :") |
| 870 | + |
| 871 | + self.results_path = str(self.save_path) |
| 872 | + self.results_filewidget.text_field.setText(self.results_path) |
| 873 | + self.results_filewidget.check_ready() |
| 874 | + |
| 875 | + self.start_btn = ui.Button("Start", self._start) |
| 876 | + self.result_display = ui.make_label(self.result_text, self) |
| 877 | + self.image_layer_loader.layer_list.currentIndexChanged.connect( |
| 878 | + self._reset |
| 879 | + ) |
| 880 | + self.label_layer_loader.layer_list.currentIndexChanged.connect( |
| 881 | + self._reset |
| 882 | + ) |
| 883 | + |
| 884 | + self.container = self._build() |
| 885 | + |
| 886 | + def _build(self): |
| 887 | + container = ui.ContainerWidget() |
| 888 | + |
| 889 | + container.layout.addWidget(self.data_panel) |
| 890 | + ui.add_widgets( |
| 891 | + container.layout, |
| 892 | + [ |
| 893 | + self.start_btn, |
| 894 | + self.result_display, |
| 895 | + ], |
| 896 | + ) |
| 897 | + |
| 898 | + ui.ScrollArea.make_scrollable( |
| 899 | + container.layout, self, max_wh=[MAX_W, MAX_H] |
| 900 | + ) |
| 901 | + self._set_io_visibility() |
| 902 | + return container |
| 903 | + |
| 904 | + def _reset(self): |
| 905 | + self.values = {} |
| 906 | + self.result_text = "" |
| 907 | + self.result_display.setText("") |
| 908 | + |
| 909 | + def _check_ready(self): |
| 910 | + image_data = self.image_layer_loader.layer_data() |
| 911 | + label_data = self.label_layer_loader.layer_data() |
| 912 | + if image_data is None: |
| 913 | + self.result_display.setText("Please load a prediction layer") |
| 914 | + return False |
| 915 | + if label_data is None: |
| 916 | + self.result_display.setText("Please load a labels layer") |
| 917 | + return False |
| 918 | + if label_data.shape != image_data.shape: |
| 919 | + self.result_display.setText( |
| 920 | + "Prediction and labels must have the same shape" |
| 921 | + ) |
| 922 | + return False |
| 923 | + if ( |
| 924 | + label_data.min() < 0 |
| 925 | + or label_data.max() > 1 |
| 926 | + or len(np.unique(label_data)) != 2 |
| 927 | + ): |
| 928 | + self.do_binarize = True |
| 929 | + return True |
| 930 | + |
| 931 | + def _get_dice_graph(self): |
| 932 | + max_dice = max(self.values.values()) |
| 933 | + self.result_text += "Thre | Dice | Graph\n" |
| 934 | + for tr, dice in self.values.items(): |
| 935 | + bar = "°" * int((dice / max_dice) * 25) |
| 936 | + self.result_text += f"{tr:.2f} | {dice:.3f} | {bar}\n" |
| 937 | + |
| 938 | + def _start(self): |
| 939 | + utils.mkdir_from_str(self.results_path) |
| 940 | + if not self._check_ready(): |
| 941 | + return |
| 942 | + |
| 943 | + pred_data = self.image_layer_loader.layer_data().copy() |
| 944 | + label_data = self.label_layer_loader.layer_data().copy() |
| 945 | + if self.do_binarize: |
| 946 | + logger.info("Labels values are not binary, binarizing") |
| 947 | + label_data = to_semantic(label_data) |
| 948 | + # find best threshold |
| 949 | + search_space = np.arange(0, 1, 0.05) |
| 950 | + for i in search_space: |
| 951 | + i = i.round(2) |
| 952 | + binarized = threshold(pred_data, i) |
| 953 | + binarized = np.where(binarized > 0, 1, 0) |
| 954 | + dice = utils.dice_coeff(binarized, label_data).round(3) |
| 955 | + self.values[i] = dice |
| 956 | + logger.info(f"Threshold : {i}, Dice : {dice}") |
| 957 | + |
| 958 | + best_threshold = max(self.values, key=self.values.get) |
| 959 | + binarized = threshold(pred_data, best_threshold) |
| 960 | + utils.save_layer( |
| 961 | + self.results_path, |
| 962 | + f"binarized_{utils.get_date_time()}.tif", |
| 963 | + binarized, |
| 964 | + ) |
| 965 | + self.layer = utils.show_result( |
| 966 | + self._viewer, |
| 967 | + self.image_layer_loader.layer(), |
| 968 | + binarized, |
| 969 | + "binarized", |
| 970 | + existing_layer=self.layer, |
| 971 | + ) |
| 972 | + self.result_test = f"Best threshold : {best_threshold}, Dice : {self.values[best_threshold]}\n" |
| 973 | + self._get_dice_graph() |
| 974 | + self.result_display.setText(self.result_text) |
0 commit comments