|
| 1 | +import torch |
| 2 | + |
| 3 | +from torch_staintools.functional.conversion.od import rgb2od |
| 4 | +from torch_staintools.functional.optimization.solver import coord_descent, ista, fista |
| 5 | +from torch_staintools.functional.optimization.sparse_util import initialize_code, METHOD_FACTORIZE, _batch_supported |
| 6 | +from torch_staintools.functional.utility import transpose_trailing |
| 7 | + |
| 8 | + |
| 9 | +def get_concentrations_single(od_flatten, stain_matrix, regularizer=0.01, |
| 10 | + method: METHOD_FACTORIZE = 'fista', |
| 11 | + rng: torch.Generator = None, |
| 12 | + positive: bool = False, |
| 13 | + ): |
| 14 | + """Helper function to estimate concentration matrix given an image and stain matrix with shape: 2 x (H*W) |
| 15 | +
|
| 16 | + For solvers without batch support. Inputs are individual data points from a batch |
| 17 | +
|
| 18 | + Args: |
| 19 | + od_flatten: Flattened optical density vectors in shape of (H*W) x C (H and W dimensions flattened). |
| 20 | + stain_matrix: the computed stain matrices in shape of num_stain x input channel |
| 21 | + regularizer: regularization term if ISTA algorithm is used |
| 22 | + method: which method to compute the concentration: coordinate descent ('cd') or iterative-shrinkage soft |
| 23 | + thresholding algorithm ('ista') |
| 24 | + rng: torch.Generator for random initializations |
| 25 | + positive: enforce positive concentration |
| 26 | + Returns: |
| 27 | + computed concentration: num_stains x num_pixel_in_tissue_mask |
| 28 | + """ |
| 29 | + z0 = initialize_code(od_flatten, stain_matrix.T, 'zero', rng=rng) |
| 30 | + match method: |
| 31 | + case 'cd': |
| 32 | + return coord_descent(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T |
| 33 | + case 'ista': |
| 34 | + return ista(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T |
| 35 | + case 'fista': |
| 36 | + return fista(od_flatten, z0, stain_matrix.T, alpha=regularizer, positive_code=positive).T |
| 37 | + case 'ls': |
| 38 | + return torch.linalg.lstsq(stain_matrix.T, od_flatten.T)[0].T |
| 39 | + |
| 40 | + raise NotImplementedError(f"{method} is not a valid optimizer") |
| 41 | + |
| 42 | + |
| 43 | +def get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng): |
| 44 | + result = list() |
| 45 | + for od_single, stain_mat_single in zip(od_flatten, stain_matrix): |
| 46 | + result.append(get_concentrations_single(od_single, stain_mat_single, regularizer, algorithm, rng=rng)) |
| 47 | + # get_concentrations_helper(od_flatten, stain_matrix, regularizer, method) |
| 48 | + return torch.stack(result) |
| 49 | + |
| 50 | + |
| 51 | +def _ls_batch(od_flatten, stain_matrix): |
| 52 | + """Use least square to solve the factorization for concentration. |
| 53 | +
|
| 54 | + Warnings: |
| 55 | + May fail on GPU for individual large input in cuSolver backend (e.g., 1000 x 1000), regardless of batch size. |
| 56 | + Better for multiple small inputs in terms of H and W. |
| 57 | + Magma backend may work: torch.backends.cuda.preferred_linalg_library('magma') |
| 58 | +
|
| 59 | + Args: |
| 60 | + od_flatten: B * (HW) x num_input_channel |
| 61 | + stain_matrix: B x num_stains x num_input_channel |
| 62 | +
|
| 63 | + Returns: |
| 64 | + concentration B x num_stains x (HW) |
| 65 | + """ |
| 66 | + return torch.linalg.lstsq(transpose_trailing(stain_matrix), transpose_trailing(od_flatten))[0] |
| 67 | + |
| 68 | + |
| 69 | +def get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng): |
| 70 | + assert algorithm in _batch_supported |
| 71 | + if not _batch_supported[algorithm]: |
| 72 | + return get_concentration_one_by_one(od_flatten, stain_matrix, regularizer, algorithm, rng) |
| 73 | + match algorithm: |
| 74 | + case 'ls': |
| 75 | + return _ls_batch(od_flatten, stain_matrix) |
| 76 | + case _: |
| 77 | + ... |
| 78 | + |
| 79 | + raise NotImplementedError('Currently only least-square (ls) is implemented as batch concentration solver') |
| 80 | + |
| 81 | + |
| 82 | +def get_concentrations(image, stain_matrix, regularizer=0.01, |
| 83 | + algorithm: METHOD_FACTORIZE = 'fista', |
| 84 | + rng: torch.Generator = None): |
| 85 | + """Estimate concentration matrix given an image and stain matrix. |
| 86 | +
|
| 87 | + Warnings: |
| 88 | + algorithm = 'ls' May fail on GPU for individual large input (e.g., 1000 x 1000), regardless of batch size. |
| 89 | + Better for multiple small inputs in terms of H and W. |
| 90 | + Args: |
| 91 | + image: batched image(s) in shape of BxCxHxW |
| 92 | + stain_matrix: B x num_stain x input channel |
| 93 | + regularizer: regularization term if ISTA algorithm is used |
| 94 | + algorithm: which method to compute the concentration: Solve min||HExC - OD||p |
| 95 | + support 'ista', 'cd', and 'ls'. 'ls' simply solves the least square problem for factorization of |
| 96 | + min||HExC - OD||F (Frobenius norm) but is faster. 'ista'/cd enforce the sparse penalty (L1 norm) but slower. |
| 97 | + rng: torch.Generator for random initializations |
| 98 | + Returns: |
| 99 | + concentration matrix: B x num_stains x num_pixel_in_tissue_mask |
| 100 | + """ |
| 101 | + device = image.device |
| 102 | + stain_matrix = stain_matrix.to(device) |
| 103 | + # BCHW |
| 104 | + od = rgb2od(image).to(device) |
| 105 | + # B (H*W) C |
| 106 | + od_flatten = od.flatten(start_dim=2, end_dim=-1).permute(0, 2, 1) |
| 107 | + return get_concentration_batch(od_flatten, stain_matrix, regularizer, algorithm, rng) |
0 commit comments