|
| 1 | +"""Retention time calibration.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import logging |
| 6 | +from abc import ABC, abstractmethod |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from sklearn.linear_model import LinearRegression |
| 10 | +from sklearn.pipeline import make_pipeline |
| 11 | +from sklearn.preprocessing import SplineTransformer |
| 12 | + |
| 13 | +from deeplc._exceptions import CalibrationError |
| 14 | + |
| 15 | +LOGGER = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
class Calibrator(ABC):
    """
    Abstract base class for retention time calibrators.

    Concrete calibrators must implement ``__init__``, ``fit`` and ``transform``.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        super().__init__()

    @abstractmethod
    def fit(self, measured_tr: np.ndarray, predicted_tr: np.ndarray) -> None:
        """
        Fit the calibrator on paired measured and predicted retention times.

        Parameters
        ----------
        measured_tr
            Measured retention times.
        predicted_tr
            Predicted retention times.

        """
        ...

    @abstractmethod
    def transform(self, tr: np.ndarray) -> np.ndarray:
        """
        Calibrate retention times with the fitted model.

        Parameters
        ----------
        tr
            Retention times to be transformed.

        Returns
        -------
        np.ndarray
            Transformed retention times.

        """
        ...
| 30 | + |
| 31 | + |
class PiecewiseLinearCalibrator(Calibrator):
    """Piece-wise linear retention time calibrator backed by a rounded-value lookup table."""

    def __init__(
        self,
        split_cal: int = 50,
        bin_distance: float = 2.0,
        dict_cal_divider: int = 50,
        use_median: bool = True,
    ):
        """
        Piece-wise linear calibration for retention time.

        Parameters
        ----------
        split_cal
            Number of equal-width splits of the range ``[0, max(predicted_tr))``; each split
            with data contributes one anchor point to the calibration curve.
        bin_distance
            Number of decimals that retention times are rounded to when building and querying
            the lookup table. It is truncated to an integer before being passed to ``round``.
        dict_cal_divider
            Together with ``bin_distance`` this sets the step of the lookup grid:
            ``1 / (bin_distance * dict_cal_divider)``.
        use_median
            If True, summarize each split by its median; otherwise by its mean.

        """
        super().__init__()
        self.split_cal = split_cal
        self.bin_distance = bin_distance
        self.dict_cal_divider = dict_cal_divider
        self.use_median = use_median

        self._calibrate_min = None
        self._calibrate_max = None
        self._calibrate_dict = None
        self._fit = False

    def _bin_key(self, value: float) -> str:
        """Return the lookup-table key for a single retention time."""
        # ``round`` requires an integer number of digits, and converting to a built-in float
        # makes float32 and float64 inputs map to the same key string.
        return str(round(float(value), int(self.bin_distance)))

    def fit(self, measured_tr: np.ndarray, predicted_tr: np.ndarray) -> None:
        """
        Fit a piece-wise linear model to the measured and predicted retention times.

        Parameters
        ----------
        measured_tr
            Measured retention times.
        predicted_tr
            Predicted retention times.

        Raises
        ------
        CalibrationError
            If there are not more data points than ``split_cal``, if the largest predicted
            retention time is not positive, or if fewer than two splits contain data so that
            no line segment can be fitted.

        """
        measured_tr, predicted_tr = _process_arrays(measured_tr, predicted_tr)

        # Validate before touching the data so empty input fails with a clear error instead
        # of an IndexError on ``predicted_tr[-1]``.
        if self.split_cal >= len(measured_tr):
            raise CalibrationError(
                f"Not enough measured tr ({len(measured_tr)}) for the chosen number of splits "
                f"({self.split_cal}). Choose a smaller split_cal parameter or provide more "
                "peptides for fitting the calibration curve."
            )

        LOGGER.debug(
            "Selecting the data points for calibration (used to fit the linear models between)"
        )
        # smooth between observed and predicted
        split_val = predicted_tr[-1] / self.split_cal
        if not split_val > 0:
            raise CalibrationError(
                "The largest predicted tr must be positive, not able to calibrate"
            )

        mtr_mean = []
        ptr_mean = []
        for range_calib_number in np.arange(0.0, predicted_tr[-1], split_val):
            # np.argmax returns 0 when no element satisfies the condition; such a split is
            # then skipped by the start >= end check below.
            ptr_index_start = np.argmax(predicted_tr >= range_calib_number)
            ptr_index_end = np.argmax(predicted_tr >= range_calib_number + split_val)

            # no points so no cigar... use previous points
            if ptr_index_start >= ptr_index_end:
                LOGGER.debug(
                    "Skipping calibration step, due to no points in the "
                    "predicted range (are you sure about the split size?): "
                    "%s,%s",
                    range_calib_number,
                    range_calib_number + split_val,
                )
                continue

            mtr = measured_tr[ptr_index_start:ptr_index_end]
            ptr = predicted_tr[ptr_index_start:ptr_index_end]

            summarize = np.median if self.use_median else np.mean
            mtr_mean.append(summarize(mtr))
            ptr_mean.append(summarize(ptr))

        LOGGER.debug("Fitting the linear models between the points")

        if len(mtr_mean) == 0:
            raise CalibrationError("The measured tr list is empty, not able to calibrate")

        calibrate_dict = {}
        calibrate_min = float("inf")
        calibrate_max = float("-inf")
        n_digits = int(self.bin_distance)
        grid_step = 1 / (self.bin_distance * self.dict_cal_divider)

        # calculate calibration curves: one line segment per pair of consecutive anchors
        anchors = list(zip(ptr_mean, mtr_mean))
        for (ptr_low, mtr_low), (ptr_high, mtr_high) in zip(anchors, anchors[1:]):
            delta_ptr = ptr_high - ptr_low
            if delta_ptr == 0:
                # Identical anchors span no grid points and have no defined slope.
                continue

            slope = (mtr_high - mtr_low) / delta_ptr
            intercept = (-1 * (ptr_low * slope)) + mtr_low

            # optimized predictions using a dict to find calibration curve very fast
            for v in np.arange(
                round(float(ptr_low), n_digits),
                round(float(ptr_high), n_digits),
                grid_step,
            ):
                calibrate_min = min(calibrate_min, v)
                calibrate_max = max(calibrate_max, v)
                calibrate_dict[self._bin_key(v)] = (slope, intercept)

        if not calibrate_dict:
            raise CalibrationError(
                "Could not fit any calibration segment; at least two splits with data points "
                "are required."
            )

        self._calibrate_min = calibrate_min
        self._calibrate_max = calibrate_max
        self._calibrate_dict = calibrate_dict

        self._fit = True

    def transform(self, tr: np.ndarray) -> np.ndarray:
        """
        Transform the predicted retention times using the fitted piece-wise linear model.

        Parameters
        ----------
        tr
            Retention times to be transformed.

        Returns
        -------
        np.ndarray
            Transformed retention times as a 1D array.

        Raises
        ------
        CalibrationError
            If ``fit`` has not been called successfully.

        """
        if not self._fit:
            raise CalibrationError(
                "The model has not been fitted yet. Please call fit() before transform()."
            )

        tr = np.asarray(tr)
        if tr.shape[0] == 0:
            return np.array([])

        lookup = self._calibrate_dict
        lowest_curve = lookup[self._bin_key(self._calibrate_min)]
        highest_curve = lookup[self._bin_key(self._calibrate_max)]

        # TODO: Can this be vectorized?
        cal_preds = []
        for uncal_pred in tr.ravel():
            curve = lookup.get(self._bin_key(uncal_pred))
            if curve is None:
                # Not on the lookup grid: values at or below the fitted range use the first
                # segment, everything else (above the range or in a gap) uses the last one.
                curve = lowest_curve if uncal_pred <= self._calibrate_min else highest_curve
            slope, intercept = curve
            cal_preds.append(slope * uncal_pred + intercept)

        return np.array(cal_preds)
| 203 | + |
| 204 | + |
class SplineTransformerCalibrator(Calibrator):
    """Spline-based retention time calibrator with linear extrapolation outside the fit range."""

    def __init__(self):
        """SplineTransformer calibration for retention time."""
        super().__init__()
        self._calibrate_min = None
        self._calibrate_max = None
        self._linear_model_left = None
        self._spline_model = None
        self._linear_model_right = None

        self._fit = False

    def fit(
        self,
        measured_tr: np.ndarray,
        predicted_tr: np.ndarray,
        simplified: bool = False,  # TODO: Move to __init__?
    ) -> None:
        """
        Fit the SplineTransformer model to the measured and predicted retention times.

        Parameters
        ----------
        measured_tr
            Measured retention times.
        predicted_tr
            Predicted retention times.
        simplified
            If True, fit a single linear regression and use it both inside and outside the
            fitted range. If False, fit a degree-4 spline pipeline for the fitted range plus
            separate linear models on the lowest and highest 10% of the data for
            extrapolation.

        """
        measured_tr, predicted_tr = _process_arrays(measured_tr, predicted_tr)
        features = predicted_tr.reshape(-1, 1)

        if simplified:
            # TODO @RobbinBouwmeester: Should this be the spline model?
            linear_model = LinearRegression()
            linear_model.fit(features, measured_tr)

            linear_model_left = linear_model
            spline_model = linear_model
            linear_model_right = linear_model
        else:
            spline = SplineTransformer(degree=4, n_knots=int(len(measured_tr) / 500) + 5)
            spline_model = make_pipeline(spline, LinearRegression())
            spline_model.fit(features, measured_tr)

            # Determine the top 10% of data on either end. Use at least one point: with
            # n_top == 0 the left slice is empty and ``[-0:]`` selects the whole array.
            n_top = max(1, int(len(predicted_tr) * 0.1))

            # Fit a linear model on the bottom 10% (left-side extrapolation)
            linear_model_left = LinearRegression()
            linear_model_left.fit(features[:n_top], measured_tr[:n_top])

            # Fit a linear model on the top 10% (right-side extrapolation)
            linear_model_right = LinearRegression()
            linear_model_right.fit(features[-n_top:], measured_tr[-n_top:])

        self._calibrate_min = predicted_tr.min()
        self._calibrate_max = predicted_tr.max()
        self._linear_model_left = linear_model_left
        self._spline_model = spline_model
        self._linear_model_right = linear_model_right

        self._fit = True

    def transform(self, tr: np.ndarray) -> np.ndarray:
        """
        Transform the predicted retention times using the fitted SplineTransformer model.

        Parameters
        ----------
        tr
            Retention times to be transformed.

        Returns
        -------
        np.ndarray
            Transformed retention times.

        Raises
        ------
        CalibrationError
            If ``fit`` has not been called successfully.

        """
        if not self._fit:
            raise CalibrationError(
                "The model has not been fitted yet. Please call fit() before transform()."
            )

        tr = np.asarray(tr)
        if tr.shape[0] == 0:
            return np.array([])

        features = tr.reshape(-1, 1)
        flat_tr = features.ravel()

        # Start from the spline predictions, which are used within the fitted range
        cal_preds = np.copy(self._spline_model.predict(features))

        # Replace predictions outside the fitted range with the linear model predictions
        below = flat_tr < self._calibrate_min
        above = flat_tr > self._calibrate_max
        if below.any():
            cal_preds[below] = self._linear_model_left.predict(features[below])
        if above.any():
            cal_preds[above] = self._linear_model_right.predict(features[above])

        return cal_preds
| 319 | + |
| 320 | + |
def _process_arrays(
    measured_tr: np.ndarray,
    predicted_tr: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Flatten, validate and sort paired retention times.

    Parameters
    ----------
    measured_tr
        Measured retention times (array-like).
    predicted_tr
        Predicted retention times (array-like).

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The measured and predicted retention times as 1D float32 arrays, both ordered by
        ascending predicted retention time.

    Raises
    ------
    ValueError
        If the two inputs do not contain the same number of values.

    """
    # Convert and flatten first so the length check compares the values that are actually
    # paired up; checking ``len`` on the raw inputs lets e.g. (n, 2) vs (n,) slip through.
    measured = np.asarray(measured_tr).ravel()
    predicted = np.asarray(predicted_tr).ravel()

    if measured.size != predicted.size:
        raise ValueError(
            f"Measured and predicted retention times must have the same length. "
            f"Got {measured.size} and {predicted.size}."
        )

    # Sort both arrays by predicted retention time
    order = np.argsort(predicted)
    return (
        measured[order].astype(np.float32),
        predicted[order].astype(np.float32),
    )
0 commit comments