- TIA Toolbox version: 1.5.1
- Python version: 3.11
- Operating System: Linux
Description
I think it would be useful to integrate pre-trained pathology foundation models from other labs into tiatoolbox/models/architecture/vanilla.py.
Currently, the _get_architecture() function allows the use of models from torchvision.models.
A second function, _get_timm_architecture(), could be added to load foundation models that are available through timm with weights hosted on the Hugging Face Hub. All the timm models I have used require users to accept a licence agreement with the authors, so the licensing question seems to resolve itself: users cannot obtain the weights through TIA Toolbox without first having their access request approved by the authors.
What I Did
To add them myself, I copied the definition of CNNBackbone and made two changes:
- replaced the feature extractor with self.feat_extract = _get_timm_architecture(backbone);
- removed the global average pooling, because these pathology foundation models already output a feature vector of shape (batch_size, embedding_size) for a batch of images.
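To illustrate why the pooling step can be dropped: CNNBackbone pools a 4D (N, C, H, W) feature map down to (N, C), whereas a timm model created with num_classes=0 typically returns (N, D) already. nn.AdaptiveAvgPool2d expects a 3D or 4D input, so it cannot simply stay in place. The helper to_embedding below is hypothetical and only contrasts the two shape contracts; it is not tiatoolbox code.

```python
import torch
import torch.nn.functional as F


def to_embedding(feat: torch.Tensor) -> torch.Tensor:
    """Reduce a backbone output to shape (batch_size, embedding_size).

    Hypothetical helper for illustration only.
    """
    if feat.ndim == 4:
        # torchvision CNN: (N, C, H, W) feature map -> global average
        # pooling to (N, C, 1, 1), then flatten to (N, C), as CNNBackbone does.
        return torch.flatten(F.adaptive_avg_pool2d(feat, (1, 1)), 1)
    if feat.ndim == 2:
        # timm model with num_classes=0: already (N, D), nothing to do.
        return feat
    raise ValueError(f"unexpected feature shape {tuple(feat.shape)}")


# A resnet50-sized feature map (512x512 input) and a ViT-sized embedding batch.
cnn_feat = torch.rand(4, 2048, 16, 16)
vit_feat = torch.rand(4, 1024)
print(to_embedding(cnn_feat).shape)  # torch.Size([4, 2048])
print(to_embedding(vit_feat).shape)  # torch.Size([4, 1024])
```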
tiatoolbox/tiatoolbox/models/architecture/vanilla.py, lines 176 to 270 in 015652c:
```python
class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch.

    Args:
        backbone (str):
            Model name. Currently, the tool supports the following
            model names and their default associated weights from pytorch.
                - "alexnet"
                - "resnet18"
                - "resnet34"
                - "resnet50"
                - "resnet101"
                - "resnext50_32x4d"
                - "resnext101_32x8d"
                - "wide_resnet50_2"
                - "wide_resnet101_2"
                - "densenet121"
                - "densenet161"
                - "densenet169"
                - "densenet201"
                - "inception_v3"
                - "googlenet"
                - "mobilenet_v2"
                - "mobilenet_v3_large"
                - "mobilenet_v3_small"

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`."""
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        return torch.flatten(gap_feat, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray, ...]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # to NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
```
Suggestion
Would you be interested in adding this functionality? If so, I can open a pull request.