Integrate foundation models available through timm: UNI, Virchow, Hibou, H-optimus-0, etc. #855

@GeorgeBatch

  • TIA Toolbox version: 1.5.1
  • Python version: 3.11
  • Operating System: Linux

Description

I think it would be useful to integrate pre-trained foundation models from other labs into tiatoolbox.models.architecture.vanilla.py.

Currently, the _get_architecture() function allows the use of models from torchvision.models.

But another function, _get_timm_architecture(), could be added to incorporate foundation models that are available from timm with weights on the Hugging Face Hub. All the models from timm that I've used require users to sign a licence agreement with the authors, so the licensing question seems to resolve itself: users cannot get access to the model weights through TIAToolbox without first having their access request approved by the authors.

What I Did

To add them myself, I copied the definition of CNNBackbone (reproduced below) and made two changes:

  1. replaced the feature extractor with self.feat_extract = _get_timm_architecture(backbone)
  2. removed the global average pooling, because, given a batch of images, these pathology foundation models already output a feature vector of shape (batch_size, embedding_size)

class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch.

    Args:
        backbone (str):
            Model name. Currently, the tool supports following
            model names and their default associated weights from pytorch.
                - "alexnet"
                - "resnet18"
                - "resnet34"
                - "resnet50"
                - "resnet101"
                - "resnext50_32x4d"
                - "resnext101_32x8d"
                - "wide_resnet50_2"
                - "wide_resnet101_2"
                - "densenet121"
                - "densenet161"
                - "densenet169"
                - "densenet201"
                - "inception_v3"
                - "googlenet"
                - "mobilenet_v2"
                - "mobilenet_v3_large"
                - "mobilenet_v3_small"

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`."""
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        return torch.flatten(gap_feat, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray, ...]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # NHWC -> NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()
        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
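For comparison, the two changes described above could be sketched roughly as follows. This is a simplified illustration, not the code from my working copy: the class name `TimmBackbone` is hypothetical, it subclasses `nn.Module` rather than ModelABC so that it runs on its own, and the feature extractor is passed in directly where the real version would call _get_timm_architecture(backbone).

```python
import torch
from torch import nn


class TimmBackbone(nn.Module):
    """Hypothetical wrapper around a model that already returns embeddings."""

    def __init__(self, feat_extract: nn.Module) -> None:
        super().__init__()
        # In the proposal: self.feat_extract = _get_timm_architecture(backbone)
        self.feat_extract = feat_extract

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # No AdaptiveAvgPool2d or flatten step: the extractor is assumed to
        # return a tensor of shape (batch_size, embedding_size) directly.
        return self.feat_extract(imgs)


# Stand-in extractor for demonstration only: maps (N, 3, H, W) -> (N, 16).
# A real foundation model would be loaded through timm instead.
dummy_extractor = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(3, 16),
)

model = TimmBackbone(dummy_extractor)
model.eval()
with torch.inference_mode():
    features = model(torch.rand(2, 3, 224, 224))  # NCHW input
```

The infer_batch() method can stay as it is in CNNBackbone, since it only moves the batch to the device, permutes it to NCHW, and calls the model.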

Suggestion

Would you be interested in adding this functionality? If yes, I can make a pull request.
