diff --git a/.gitignore b/.gitignore
index b6e4761..36892d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,3 +127,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+#Prithvi geospatial package
+prithvi_geospatial_extractor/geospatial_fm.egg-info/
diff --git a/PDG_MAPLE/requirements.txt b/PDG_MAPLE/requirements.txt
index 03c947a..a96145f 100644
--- a/PDG_MAPLE/requirements.txt
+++ b/PDG_MAPLE/requirements.txt
@@ -1 +1 @@
-pyclowder==2.4.0
\ No newline at end of file
+pyclowder==3.0.7
\ No newline at end of file
diff --git a/event_driven_ml_inference/requirements.txt b/event_driven_ml_inference/requirements.txt
index 06653d2..768ca22 100644
--- a/event_driven_ml_inference/requirements.txt
+++ b/event_driven_ml_inference/requirements.txt
@@ -1,4 +1,4 @@
-pyclowder==2.6.0
+pyclowder==3.0.7
 numpy
 ray[default]==1.13.0
 keras
diff --git a/parallel-batch-ml-inference-huggingface/requirements.txt b/parallel-batch-ml-inference-huggingface/requirements.txt
index 4defea1..81bd385 100644
--- a/parallel-batch-ml-inference-huggingface/requirements.txt
+++ b/parallel-batch-ml-inference-huggingface/requirements.txt
@@ -1,4 +1,4 @@
-pyclowder==2.6.0
+pyclowder==3.0.7
 ray[default]==1.13.0
 numpy
 scipy
diff --git a/parallel-batch-ml-inference-pytorch/requirements.txt b/parallel-batch-ml-inference-pytorch/requirements.txt
index ed825b7..5cd3d0c 100644
--- a/parallel-batch-ml-inference-pytorch/requirements.txt
+++ b/parallel-batch-ml-inference-pytorch/requirements.txt
@@ -1,4 +1,4 @@
-pyclowder==2.6.0
+pyclowder==3.0.7
 ray[default]==1.13.0
 torchvision
 pillow
\ No newline at end of file
diff --git a/parallel_batch_ml_inference/parallel_ml_inference_extractor.py b/parallel_batch_ml_inference/parallel_ml_inference_extractor.py
index 8892a78..7f2d201 100755
--- a/parallel_batch_ml_inference/parallel_ml_inference_extractor.py
+++ b/parallel_batch_ml_inference/parallel_ml_inference_extractor.py
@@ -8,6 +8,8 @@ import numpy as np
 import ray
 from ray.util.queue import Queue
+import os
+from PIL import UnidentifiedImageError
 
 import pyclowder.files
 from pyclowder.extractors import Extractor
@@ -26,7 +28,11 @@ def process_file(self, filepaths):
         from tensorflow.keras.preprocessing import image
 
         # pre-process image
-        original = image.load_img(filepaths, target_size=(224, 224))
+        try:
+            original = image.load_img(filepaths, target_size=(224, 224))
+        except UnidentifiedImageError:
+            print("Unidentified Image Error")
+            return "Unidentified Image Error. Possible corrupted image, please replace image."
         numpy_image = image.img_to_array(original)
         image_batch = np.expand_dims(numpy_image, axis=0)
         processed_image = preprocess_input(image_batch, mode='caffe')
@@ -57,14 +63,19 @@ def __init__(self):
     def process_message(self, connector, host, secret_key, resource, parameters):
         """Dataset extractor. We get all filenames at once."""
         logger = logging.getLogger(__name__)
-        
+
         # Get list of all files in dataset
         filelist = pyclowder.datasets.get_file_list(connector, host, secret_key, parameters['datasetId'])
         localfiles = []
+        clowder_version = int(os.getenv('CLOWDER_VERSION', '1'))
 
         # # Loop through dataset and download all file "locally"
         for file_dict in filelist:
-            extension = "." + file_dict['contentType'].split("/")[1]
+            # Use the correct key depending on the Clowder version
+            if clowder_version == 2:
+                extension = "." + file_dict['content_type']['content_type'].split("/")[1]
+            else:
+                extension = "." + file_dict['contentType'].split("/")[1]
             localfiles.append(pyclowder.files.download(connector, host, secret_key, file_dict['id'], ext=extension))
 
         # These process messages will appear in the Clowder UI under Extractions.
diff --git a/parallel_batch_ml_inference/requirements.txt b/parallel_batch_ml_inference/requirements.txt
index 06653d2..768ca22 100644
--- a/parallel_batch_ml_inference/requirements.txt
+++ b/parallel_batch_ml_inference/requirements.txt
@@ -1,4 +1,4 @@
-pyclowder==2.6.0
+pyclowder==3.0.7
 numpy
 ray[default]==1.13.0
 keras
diff --git a/prithvi_geospatial_extractor/.dockerignore b/prithvi_geospatial_extractor/.dockerignore
new file mode 100644
index 0000000..60889c7
--- /dev/null
+++ b/prithvi_geospatial_extractor/.dockerignore
@@ -0,0 +1,3 @@
+geospatial_fm.egg-info/
+build/
+dist/
\ No newline at end of file
diff --git a/prithvi_geospatial_extractor/Dockerfile b/prithvi_geospatial_extractor/Dockerfile
new file mode 100644
index 0000000..bcfe7ca
--- /dev/null
+++ b/prithvi_geospatial_extractor/Dockerfile
@@ -0,0 +1,44 @@
+#Source - https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-sen1floods11-demo/blob/main/Dockerfile
+
+FROM python:3.8
+
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    wget \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH \
+    PYTHONPATH=$HOME/app \
+    PYTHONUNBUFFERED=1 \
+    GRADIO_ALLOW_FLAGGING=never \
+    GRADIO_NUM_PORTS=1 \
+    GRADIO_SERVER_NAME=0.0.0.0 \
+    GRADIO_THEME=huggingface \
+    SYSTEM=spaces
+
+# RUN conda install python=3.8
+
+RUN pip install setuptools-rust
+RUN pip install torch==1.11.0+cu115 torchvision==0.12.0+cu115 --extra-index-url https://download.pytorch.org/whl/cu115
+RUN pip install gradio scikit-image pillow openmim
+RUN pip install --upgrade setuptools
+
+WORKDIR /home/user
+
+WORKDIR /extractor
+
+COPY . .
+
+RUN pip install -e .
+
+RUN mim install mmcv-full==1.6.2 -f https://download.openmmlab.com/mmcv/dist/cu115/torch1.11.0/index.html
+
+CMD ["python3", "-u", "prithvi_finetuned_extractor.py", "--heartbeat", "15"]
diff --git a/prithvi_geospatial_extractor/README.md b/prithvi_geospatial_extractor/README.md
new file mode 100644
index 0000000..382f0bc
--- /dev/null
+++ b/prithvi_geospatial_extractor/README.md
@@ -0,0 +1,58 @@
+# This extractor is best run via CodeFlare; see the top-level README for more.
+
+# Manual Docker (no CodeFlare)
+
+This extractor is ready to be run as a Docker container; the only dependency is a running Clowder instance. Simply build and run.
+
+1. Start Clowder. For help starting Clowder, see our [getting started guide](https://github.com/clowder-framework/clowder/blob/develop/doc/src/sphinx/userguide/installing_clowder.rst).
+
+2. First build the extractor Docker container:
+
+```
+# from this directory, run:
+
+docker build -t prithvi-finetuned-extractor .
+```
+
+3. Finally run the extractor:
+
+```
+docker run -t -i --rm --net clowder_clowder -e "RABBITMQ_URI=amqp://guest:guest@rabbitmq:5672/%2f" --name "prithvi-finetuned-extractor" prithvi-finetuned-extractor
+```
+
+Then open the Clowder web app and run this extractor on a .tif file. Done.
+
+### Python and Docker details
+
+You may use any version of Python 3. Simply edit the first line of the `Dockerfile`; by default it uses `FROM python:3.8`.
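+
+For example, to build against Python 3.10 instead (an untested sketch; the pinned `torch==1.11.0+cu115` and `mmcv-full==1.6.2` wheels must also exist for whichever interpreter you pick):
+
+```
+# first line of the Dockerfile
+FROM python:3.10
+```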
+
+Docker flags:
+
+- `--net` links the extractor to the Clowder Docker network (run `docker network ls` to identify your own).
+- `-e RABBITMQ_URI=` sets the environment variable that controls which RabbitMQ server and exchange the extractor binds itself to. Setting `RABBITMQ_EXCHANGE` may also help.
+  - You can also use `--link` to link the extractor to a RabbitMQ container.
+- `--name` assigns the container a name visible in Docker Desktop.
+
+## Troubleshooting
+
+**If you run into _any_ trouble**, please reach out on our Clowder Slack in the [#pyclowder channel](https://clowder-software.slack.com/archives/CNC2UVBCP).
+
+Alternate methods of running extractors are below.
+
+# Commandline Execution
+
+To execute the extractor from the command line, you will need the required packages installed. It is highly recommended to use a Python virtual environment for this: create the virtual environment first, then activate it, and finally install all required packages.
+
+```
+virtualenv /home/clowder/virtualenv/clowder2
+. /home/clowder/virtualenv/clowder2/bin/activate
+cd prithvi_geospatial_extractor/
+pip install -e .
+```
+
+To start the extractor, load the virtual environment and run the extractor script.
+
+```
+. /home/clowder/virtualenv/clowder2/bin/activate
+/home/clowder/extractors/prithvi_geospatial_extractor/prithvi_finetuned_extractor.py
+```
diff --git a/prithvi_geospatial_extractor/extractor_info.json b/prithvi_geospatial_extractor/extractor_info.json
new file mode 100644
index 0000000..97a49a8
--- /dev/null
+++ b/prithvi_geospatial_extractor/extractor_info.json
@@ -0,0 +1,48 @@
+{
+  "@context": "http://clowder.ncsa.illinois.edu/contexts/extractors.jsonld",
+  "name": "prithvi.finetuned.inference.file",
+  "version": "1.0",
+  "description": "Choose from fine-tuned Prithvi models to run inference on .tif files",
+  "author": "Vismayak Mohanarajan",
+  "contributors": [
+  ],
+  "contexts": [],
+  "repository": [
+    {
+      "repType": "git",
+      "repUrl": "https://opensource.ncsa.illinois.edu/stash/scm/cats/pyclowder.git"
+    }
+  ],
+  "process": {
+    "file": [
+      "manual.submission"
+    ]
+  },
+  "max_retry": 1,
+  "external_services": [],
+  "dependencies": [],
+  "bibtex": [],
+  "parameters": {
+    "schema": {
+      "APPLICATION_TYPE": {
+        "type": "string",
+        "title": "Choose the fine-tuned model by application",
+        "enum": [
+          "flood_mapping",
+          "burn_scars",
+          "cover_crop"
+        ],
+        "default": "flood_mapping"
+      },
+      "SAVE_IMAGE": {
+        "type": "string",
+        "title": "Save an image of the inference as a mask overlaid on the input image",
+        "enum": [
+          "True",
+          "False"
+        ],
+        "default": "True"
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/prithvi_geospatial_extractor/geospatial_fm/__init__.py b/prithvi_geospatial_extractor/geospatial_fm/__init__.py
new file mode 100644
index 0000000..85247af
--- /dev/null
+++ b/prithvi_geospatial_extractor/geospatial_fm/__init__.py
@@ -0,0 +1,27 @@
+from .geospatial_fm import ConvTransformerTokensToEmbeddingNeck, TemporalViTEncoder, GeospatialNeck
+from .geospatial_pipelines import (
+    TorchRandomCrop,
+    LoadGeospatialAnnotations,
+    LoadGeospatialImageFromFile,
+    Reshape,
+    CastTensor,
+    CollectTestList,
+    TorchPermute
+)
+from .datasets import GeospatialDataset
+from .temporal_encoder_decoder import TemporalEncoderDecoder
+
+__all__ = [
+    "GeospatialDataset",
+    "TemporalViTEncoder",
+    "ConvTransformerTokensToEmbeddingNeck",
+    "LoadGeospatialAnnotations",
+    "LoadGeospatialImageFromFile",
+    "TorchRandomCrop",
+    "TemporalEncoderDecoder",
+    "Reshape",
+
"CastTensor", + "CollectTestList", + "GeospatialNeck", + "TorchPermute" +] diff --git a/prithvi_geospatial_extractor/geospatial_fm/datasets.py b/prithvi_geospatial_extractor/geospatial_fm/datasets.py new file mode 100644 index 0000000..76a63eb --- /dev/null +++ b/prithvi_geospatial_extractor/geospatial_fm/datasets.py @@ -0,0 +1,25 @@ +from mmseg.datasets.builder import DATASETS +from mmseg.datasets.custom import CustomDataset +from .geospatial_pipelines import LoadGeospatialAnnotations + + +@DATASETS.register_module() +class GeospatialDataset(CustomDataset): + """GeospatialDataset dataset. + """ + + def __init__(self, CLASSES=(0, 1), PALETTE=None, **kwargs): + + self.CLASSES = CLASSES + + self.PALETTE = PALETTE + + gt_seg_map_loader_cfg = kwargs.pop('gt_seg_map_loader_cfg') if 'gt_seg_map_loader_cfg' in kwargs else dict() + reduce_zero_label = kwargs.pop('reduce_zero_label') if 'reduce_zero_label' in kwargs else False + + super(GeospatialDataset, self).__init__( + reduce_zero_label=reduce_zero_label, + # ignore_index=2, + **kwargs) + + self.gt_seg_map_loader = LoadGeospatialAnnotations(reduce_zero_label=reduce_zero_label, **gt_seg_map_loader_cfg) \ No newline at end of file diff --git a/prithvi_geospatial_extractor/geospatial_fm/geospatial_fm.py b/prithvi_geospatial_extractor/geospatial_fm/geospatial_fm.py new file mode 100644 index 0000000..de94c90 --- /dev/null +++ b/prithvi_geospatial_extractor/geospatial_fm/geospatial_fm.py @@ -0,0 +1,504 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from mmcv.runner import load_checkpoint +from mmseg.models.builder import BACKBONES, NECKS +from timm.models.layers import to_2tuple +from timm.models.vision_transformer import Block +from typing import List + + +def _convTranspose2dOutput( + input_size: int, + stride: int, + padding: int, + dilation: int, + kernel_size: int, + output_padding: int, +): + """ + Calculate the output size of a ConvTranspose2d. + Taken from: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html + """ + return ( + (input_size - 1) * stride + - 2 * padding + + dilation * (kernel_size - 1) + + output_padding + + 1 + ) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def get_3d_sincos_pos_embed(embed_dim: int, grid_size: tuple, cls_token: bool = False): + # Copyright (c) Meta Platforms, Inc. and affiliates. + # All rights reserved. + + # This source code is licensed under the license found in the + # LICENSE file in the root directory of this source tree. 
+ # -------------------------------------------------------- + # Position embedding utils + # -------------------------------------------------------- + """ + grid_size: 3d tuple of grid size: t, h, w + return: + pos_embed: L, D + """ + + assert embed_dim % 16 == 0 + + t_size, h_size, w_size = grid_size + + w_embed_dim = embed_dim // 16 * 6 + h_embed_dim = embed_dim // 16 * 6 + t_embed_dim = embed_dim // 16 * 4 + + w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size)) + h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size)) + t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size)) + + w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1)) + h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1)) + t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0) + + pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1) + + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +class PatchEmbed(nn.Module): + """Frames of 2D Images to Patch Embedding + The 3D version of timm.models.vision_transformer.PatchEmbed + """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + num_frames: int = 3, + tubelet_size: int = 1, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: nn.Module = None, + flatten: bool = True, + bias: bool = True, + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.grid_size = ( + num_frames // tubelet_size, + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2] + self.flatten = flatten + + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=(tubelet_size, patch_size[0], patch_size[1]), + stride=(tubelet_size, patch_size[0], patch_size[1]), + bias=bias, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, T, H, W = x.shape + assert ( + H == self.img_size[0] + ), f"Input image height ({H}) doesn't match model ({self.img_size[0]})." + assert ( + W == self.img_size[1] + ), f"Input image width ({W}) doesn't match model ({self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[3], x.shape[4] + if self.flatten: + x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C + x = self.norm(x) + return x, Hp, Wp + + +class Norm2d(nn.Module): + def __init__(self, embed_dim: int): + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=1e-6) + + def forward(self, x): + x = x.permute(0, 2, 3, 1) + x = self.ln(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + +@NECKS.register_module() +class GeospatialNeck(nn.Module): + """ + Neck that transforms the token-based output of transformer into a single embedding suitable for processing with standard layers. 
+ Performs 4 ConvTranspose2d operations on the rearranged input with kernel_size=2 and stride=2 + """ + + def __init__( + self, + embed_dim: int, + first_conv_channels: int, + Hp: int = 14, + Wp: int = 14, + channel_reduction_factor: int = 2, + num_convs: int = 4, + num_convs_per_upscale: int = 1, + dropout: bool = False, + drop_cls_token: bool = True, + ): + """ + + Args: + embed_dim (int): Input embedding dimension + first_conv_channel (int): Number of channels for first dimension + Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14. + Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14. + channel_reduction_factor (int): Factor that each convolutional block reduces number of channels by. + num_convs (int): Number of convolutional upscaling blocks. Each upscales 2x. + drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True. + """ + super().__init__() + self.drop_cls_token = drop_cls_token + self.Hp = Hp + self.Wp = Wp + self.H_out = Hp + self.W_out = Wp + self.dropout = dropout + + conv_kernel_size = 3 + conv_padding = 1 + + kernel_size = 2 + stride = 2 + dilation = 1 + padding = 0 + output_padding = 0 + + self.embed_dim = embed_dim + self.channels = [first_conv_channels // (channel_reduction_factor ** i) for i in range(num_convs)] + self.channels = [embed_dim] + self.channels + + for _ in range(len(self.channels) - 1): + self.H_out = _convTranspose2dOutput( + self.H_out, stride, padding, dilation, kernel_size, output_padding + ) + self.W_out = _convTranspose2dOutput( + self.W_out, stride, padding, dilation, kernel_size, output_padding + ) + + def _build_upscale_block(channels_in, channels_out): + layers = [] + layers.append(nn.ConvTranspose2d( + channels_in, + channels_out, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + output_padding=output_padding, + )) + + layers += [nn.Sequential( + nn.Conv2d(channels_out, + channels_out, + kernel_size=conv_kernel_size, + padding=conv_padding), + nn.BatchNorm2d(channels_out), + nn.Dropout() if self.dropout else nn.Identity(), + nn.ReLU()) for _ in range(num_convs_per_upscale)] + + return nn.Sequential(*layers) + + self.layers = nn.ModuleList([ + _build_upscale_block(self.channels[i], self.channels[i+1]) + for i in range(len(self.channels) - 1) + ]) + + def forward(self, x): + x = x[0] + if self.drop_cls_token: + x = x[:, 1:, :] + x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp) + + for layer in self.layers: + x = layer(x) + + x = x.reshape((x.shape[0], self.channels[-1], self.H_out, self.W_out)) + + out = tuple([x]) + + return out + +@NECKS.register_module() +class ConvTransformerTokensToEmbeddingNeck(nn.Module): + """ + Neck that transforms the token-based output of transformer into a single embedding suitable for processing with standard layers. + Performs 4 ConvTranspose2d operations on the rearranged input with kernel_size=2 and stride=2 + """ + + def __init__( + self, + embed_dim: int, + output_embed_dim: int, + # num_frames: int = 1, + Hp: int = 14, + Wp: int = 14, + drop_cls_token: bool = True, + ): + """ + + Args: + embed_dim (int): Input embedding dimension + output_embed_dim (int): Output embedding dimension + Hp (int, optional): Height (in patches) of embedding to be upscaled. Defaults to 14. + Wp (int, optional): Width (in patches) of embedding to be upscaled. Defaults to 14. 
+ drop_cls_token (bool, optional): Whether there is a cls_token, which should be dropped. This assumes the cls token is the first token. Defaults to True. + """ + super().__init__() + self.drop_cls_token = drop_cls_token + self.Hp = Hp + self.Wp = Wp + self.H_out = Hp + self.W_out = Wp + # self.num_frames = num_frames + + kernel_size = 2 + stride = 2 + dilation = 1 + padding = 0 + output_padding = 0 + for _ in range(4): + self.H_out = _convTranspose2dOutput( + self.H_out, stride, padding, dilation, kernel_size, output_padding + ) + self.W_out = _convTranspose2dOutput( + self.W_out, stride, padding, dilation, kernel_size, output_padding + ) + + self.embed_dim = embed_dim + self.output_embed_dim = output_embed_dim + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d( + self.embed_dim, + self.output_embed_dim, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + output_padding=output_padding, + ), + Norm2d(self.output_embed_dim), + nn.GELU(), + nn.ConvTranspose2d( + self.output_embed_dim, + self.output_embed_dim, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + output_padding=output_padding, + ), + ) + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d( + self.output_embed_dim, + self.output_embed_dim, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + output_padding=output_padding, + ), + Norm2d(self.output_embed_dim), + nn.GELU(), + nn.ConvTranspose2d( + self.output_embed_dim, + self.output_embed_dim, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + output_padding=output_padding, + ), + ) + + def forward(self, x): + x = x[0] + if self.drop_cls_token: + x = x[:, 1:, :] + x = x.permute(0, 2, 1).reshape(x.shape[0], -1, self.Hp, self.Wp) + + x = self.fpn1(x) + x = self.fpn2(x) + + x = x.reshape((-1, self.output_embed_dim, self.H_out, self.W_out)) + + out = tuple([x]) + + return out + + +@BACKBONES.register_module() +class TemporalViTEncoder(nn.Module): + """Encoder from an ViT with capability to take in temporal input. + + This class defines an encoder taken from a ViT architecture. + """ + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + num_frames: int = 1, + tubelet_size: int = 1, + in_chans: int = 3, + embed_dim: int = 1024, + depth: int = 24, + num_heads: int = 16, + mlp_ratio: float = 4.0, + norm_layer: nn.Module = nn.LayerNorm, + norm_pix_loss: bool = False, + pretrained: str = None + ): + """ + + Args: + img_size (int, optional): Input image size. Defaults to 224. + patch_size (int, optional): Patch size to be used by the transformer. Defaults to 16. + num_frames (int, optional): Number of frames (temporal dimension) to be input to the encoder. Defaults to 1. + tubelet_size (int, optional): Tubelet size used in patch embedding. Defaults to 1. + in_chans (int, optional): Number of input channels. Defaults to 3. + embed_dim (int, optional): Embedding dimension. Defaults to 1024. + depth (int, optional): Encoder depth. Defaults to 24. + num_heads (int, optional): Number of heads used in the encoder blocks. Defaults to 16. + mlp_ratio (float, optional): Ratio to be used for the size of the MLP in encoder blocks. Defaults to 4.0. + norm_layer (nn.Module, optional): Norm layer to be used. Defaults to nn.LayerNorm. + norm_pix_loss (bool, optional): Whether to use Norm Pix Loss. Defaults to False. + pretrained (str, optional): Path to pretrained encoder weights. Defaults to None. 
+ """ + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.embed_dim = embed_dim + self.patch_embed = PatchEmbed( + img_size, patch_size, num_frames, tubelet_size, in_chans, embed_dim + ) + num_patches = self.patch_embed.num_patches + self.num_frames = num_frames + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + self.blocks = nn.ModuleList( + [ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for _ in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + self.norm_pix_loss = norm_pix_loss + self.pretrained = pretrained + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_3d_sincos_pos_embed( + self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + if isinstance(self.pretrained, str): + self.apply(self._init_weights) + print(f"load from {self.pretrained}") + load_checkpoint(self, self.pretrained, strict=False, map_location="cpu") + elif self.pretrained is None: + # # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + # embed patches + x, _, _ = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return tuple([x]) diff --git a/prithvi_geospatial_extractor/geospatial_fm/geospatial_pipelines.py b/prithvi_geospatial_extractor/geospatial_fm/geospatial_pipelines.py new file mode 100644 index 0000000..fc6c4fb --- /dev/null +++ b/prithvi_geospatial_extractor/geospatial_fm/geospatial_pipelines.py @@ -0,0 +1,364 @@ +""" +This file holds pipeline components useful for loading remote sensing images and annotations. +""" +import os.path as osp + +import numpy as np +import rioxarray +import torchvision.transforms.functional as F +from mmcv.parallel import DataContainer as DC +from mmseg.datasets.builder import PIPELINES +from torchvision import transforms + + +def open_tiff(fname): + data = rioxarray.open_rasterio(fname) + return data.to_numpy() + + +@PIPELINES.register_module() +class ConstantMultiply(object): + """Multiply image by constant. + + It multiplies an image by a constant + + Args: + constant (float, optional): The constant to multiply by. 1.0 (e.g. 
no alteration if not specified) + """ + + def __init__(self, constant=1.0): + self.constant = constant + + def __call__(self, results): + """Call function to multiply by constant input img + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Results with image multiplied by constant + """ + + results["img"] = results["img"] * self.constant + + return results + + +@PIPELINES.register_module() +class BandsExtract(object): + + """Extract bands from image. Assumes channels last + + It extracts bands from an image. Assumes channels last. + + Args: + bands (list, optional): The list of indexes to use for extraction. If not provided nothing will happen. + """ + + def __init__(self, bands=None): + self.bands = bands + + def __call__(self, results): + """Call function to multiply extract bands + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Results with extracted bands + """ + + if self.bands is not None: + results["img"] = results["img"][..., self.bands] + + return results + + +@PIPELINES.register_module() +class TorchRandomCrop(object): + + """ + + It randomly crops a multichannel tensor. + + Args: + crop_size (tuple): the size to use to crop + """ + + def __init__(self, crop_size=(224, 224)): + self.crop_size = crop_size + + def __call__(self, results): + i, j, h, w = transforms.RandomCrop.get_params(results["img"], self.crop_size) + results["img"] = F.crop(results["img"], i, j, h, w).float() + results["gt_semantic_seg"] = F.crop(results["gt_semantic_seg"], i, j, h, w) + + return results + + +@PIPELINES.register_module() +class TorchNormalize(object): + """Normalize the image. + + It normalises a multichannel image using torch + + Args: + mean (sequence): Mean values . + std (sequence): Std values of 3 channels. + """ + + def __init__(self, means, stds): + self.means = means + self.stds = stds + + def __call__(self, results): + """Call function to normalize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Normalized results, 'img_norm_cfg' key is added into + result dict. + """ + results["img"] = F.normalize(results["img"], self.means, self.stds, False) + results["img_norm_cfg"] = dict(mean=self.means, std=self.stds) + return results + + +@PIPELINES.register_module() +class Reshape(object): + """ + It reshapes a tensor. + Args: + new_shape (tuple): tuple with new shape + keys (list): list with keys to apply reshape to + look_up (dict): dictionary to use to look up dimensions when more than one is to be inferred from the original image, which have to be inputed as -1s in the new_shape argument. eg {'2': 1, '3': 2} would infer the new 3rd and 4th dimensions from the 2nd and 3rd from the original image. + """ + + def __init__(self, new_shape, keys, look_up=None): + self.new_shape = new_shape + self.keys = keys + self.look_up = look_up + + def __call__(self, results): + dim_to_infer = np.where(np.array(self.new_shape) == -1)[0] + + for key in self.keys: + if (len(dim_to_infer) > 1) & (self.look_up is not None): + old_shape = results[key].shape + tmp = np.array(self.new_shape) + for i in range(len(dim_to_infer)): + tmp[dim_to_infer[i]] = old_shape[self.look_up[str(dim_to_infer[i])]] + self.new_shape = tuple(tmp) + results[key] = results[key].reshape(self.new_shape) + + return results + + +@PIPELINES.register_module() +class CastTensor(object): + """ + + It casts a tensor. 
+ + Args: + new_type (str): torch type + keys (list): list with keys to apply reshape to + """ + + def __init__(self, new_type, keys): + self.new_type = new_type + self.keys = keys + + def __call__(self, results): + for key in self.keys: + results[key] = results[key].type(self.new_type) + + return results + + +@PIPELINES.register_module() +class CollectTestList(object): + """ + + It processes the data in a way that conforms with inference and test pipelines. + + Args: + + keys (list): keys to collect (eg img/gt_semantic_seg) + meta_keys (list): additional meta to collect and add to img_metas + + """ + + def __init__( + self, + keys, + meta_keys=( + "filename", + "ori_filename", + "ori_shape", + "img_shape", + "pad_shape", + "scale_factor", + "flip", + "flip_direction", + "img_norm_cfg", + ), + ): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + data = {} + img_meta = {} + for key in self.meta_keys: + img_meta[key] = results[key] + img_meta = [img_meta] + data["img_metas"] = DC(img_meta, cpu_only=True) + for key in self.keys: + data[key] = [results[key]] + return data + + def __repr__(self): + return ( + self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})" + ) + + +@PIPELINES.register_module() +class TorchPermute(object): + """Permute dimensions. + + Particularly useful in going from channels_last to channels_first + + Args: + keys (Sequence[str]): Keys of results to be permuted. + order (Sequence[int]): New order of dimensions. + """ + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + for key in self.keys: + results[key] = results[key].permute(self.order) + + return results + + def __repr__(self): + return self.__class__.__name__ + f"(keys={self.keys}, order={self.order})" + + +@PIPELINES.register_module() +class LoadGeospatialImageFromFile(object): + """ + + It loads a tiff image. Returns in channels last format. + + Args: + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. 
+ nodata (float/int): no data value to substitute to nodata_replace + nodata_replace (float/int): value to use to replace no data + """ + + def __init__(self, to_float32=False, nodata=None, nodata_replace=0.0): + self.to_float32 = to_float32 + self.nodata = nodata + self.nodata_replace = nodata_replace + + def __call__(self, results): + if results.get("img_prefix") is not None: + filename = osp.join(results["img_prefix"], results["img_info"]["filename"]) + else: + filename = results["img_info"]["filename"] + img = open_tiff(filename) + # to channels last format + img = np.transpose(img, (1, 2, 0)) + + if self.to_float32: + img = img.astype(np.float32) + + if self.nodata is not None: + img = np.where(img == self.nodata, self.nodata_replace, img) + + results["filename"] = filename + results["ori_filename"] = results["img_info"]["filename"] + results["img"] = img + results["img_shape"] = img.shape + results["ori_shape"] = img.shape + # Set initial values for default meta_keys + results["pad_shape"] = img.shape + results["scale_factor"] = 1.0 + results["flip"] = False + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results["img_norm_cfg"] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, dtype=np.float32), + to_rgb=False, + ) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(to_float32={self.to_float32}" + return repr_str + + +@PIPELINES.register_module() +class LoadGeospatialAnnotations(object): + """Load annotations for semantic segmentation. + + Args: + to_uint8 (bool): Whether to convert the loaded label to a uint8 + reduce_zero_label (bool): Whether reduce all label value by 1. + Usually used for datasets where 0 is background label. + Default: False. + nodata (float/int): no data value to substitute to nodata_replace + nodata_replace (float/int): value to use to replace no data + + + """ + + def __init__( + self, + reduce_zero_label=False, + nodata=None, + nodata_replace=-1, + ): + self.reduce_zero_label = reduce_zero_label + self.nodata = nodata + self.nodata_replace = nodata_replace + + def __call__(self, results): + if results.get("seg_prefix", None) is not None: + filename = osp.join(results["seg_prefix"], results["ann_info"]["seg_map"]) + else: + filename = results["ann_info"]["seg_map"] + + gt_semantic_seg = open_tiff(filename).squeeze() + + if self.nodata is not None: + gt_semantic_seg = np.where( + gt_semantic_seg == self.nodata, self.nodata_replace, gt_semantic_seg + ) + # reduce zero_label + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + if results.get("label_map", None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results["label_map"].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + + results["gt_semantic_seg"] = gt_semantic_seg + results["seg_fields"].append("gt_semantic_seg") + return results diff --git a/prithvi_geospatial_extractor/geospatial_fm/temporal_encoder_decoder.py b/prithvi_geospatial_extractor/geospatial_fm/temporal_encoder_decoder.py new file mode 100644 index 0000000..717fdd5 --- /dev/null +++ b/prithvi_geospatial_extractor/geospatial_fm/temporal_encoder_decoder.py @@ -0,0 +1,215 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmseg.core import add_prefix +from mmseg.ops import resize +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.models.segmentors.encoder_decoder import EncoderDecoder + + +@SEGMENTORS.register_module() +class TemporalEncoderDecoder(EncoderDecoder): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, neck, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + + The backbone should return plain embeddings. + The neck can process these to make them suitable for the chosen heads. + The heads perform the final processing that will return the output. + """ + + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None, + frozen_backbone=False): + super(EncoderDecoder, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + + if frozen_backbone: + for param in self.backbone.parameters(): + param.requires_grad = False + + if neck is not None: + self.neck = builder.build_neck(neck) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + assert self.with_decode_head + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + + #### size calculated over last two dimensions ### + size = img.shape[-2:] + + out = resize( + input=out, + size=size, + mode='bilinear', + align_corners=self.align_corners) + return out + + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. 
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + + #### size and bactch size over last two dimensions ### + img_size = img.size() + batch_size = img_size[0] + h_img = img_size[-2] + w_img = img_size[-1] + out_channels = self.out_channels + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + + if len(img_size) == 4: + + crop_img = img[:, :, y1:y2, x1:x2] + + elif len(img_size) == 5: + + crop_img = img[:, :, :, y1:y2, x1:x2] + + + + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + + if rescale: + # remove padding area + #### size over last two dimensions ### + resize_shape = img_meta[0]['img_shape'][:2] + preds = preds[:, :, :resize_shape[0], :resize_shape[1]] + preds = resize( + preds, + size=img_meta[0]['ori_shape'][:2], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[-2:] + else: + # remove padding area + resize_shape = img_meta[0]['img_shape'][:2] + seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]] + size = img_meta[0]['ori_shape'][:2] + seg_logit = resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. 
+ """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + + if self.out_channels == 1: + output = F.sigmoid(seg_logit) + else: + output = F.softmax(seg_logit, dim=1) + + flip = ( + img_meta[0]["flip"] if "flip" in img_meta[0] else False + ) ##### if flip key is not there d not apply it + if flip: + flip_direction = img_meta[0]["flip_direction"] + assert flip_direction in ["horizontal", "vertical"] + if flip_direction == "horizontal": + output = output.flip(dims=(3,)) + elif flip_direction == "vertical": + output = output.flip(dims=(2,)) + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + if self.out_channels == 1: + seg_pred = (seg_logit > self.decode_head.threshold).to(seg_logit).squeeze(1) + else: + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/prithvi_geospatial_extractor/model_inference.py b/prithvi_geospatial_extractor/model_inference.py new file mode 100644 index 0000000..1988ccd --- /dev/null +++ b/prithvi_geospatial_extractor/model_inference.py @@ -0,0 +1,236 @@ +import argparse +import glob +import os +import time + +import numpy as np +import rasterio +import torch +from mmcv import Config +from mmcv.parallel import collate, scatter +from mmseg.apis import init_segmentor +from mmseg.datasets.pipelines import Compose, LoadImageFromFile + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Inference on flood detection fine-tuned model" + ) + parser.add_argument("-config", help="path to model configuration file") + parser.add_argument("-ckpt", help="path to model checkpoint") + parser.add_argument("-input", help="path to input images folder for inference") + parser.add_argument("-output", help="path to save output image") + parser.add_argument("-input_type", help="file type of input images", default="tif") + parser.add_argument( + "-bands", + help="bands in the file where to find the relevant data", + type=int, + nargs="+", + ) + parser.add_argument("-device", help="device", default="cuda", type=str) + + args = parser.parse_args() + + return args + + +def open_tiff(fname): + with rasterio.open(fname, "r") as src: + data = src.read() + + return data + + +def write_tiff(img_wrt, filename, metadata): + """ + It writes a raster image to file. + + :param img_wrt: numpy array containing the data (can be 2D for single band or 3D for multiple bands) + :param filename: file path to the output file + :param metadata: metadata to use to write the raster to disk + :return: + """ + + with rasterio.open(filename, "w", **metadata) as dest: + if len(img_wrt.shape) == 2: + img_wrt = img_wrt[None] + + for i in range(img_wrt.shape[0]): + dest.write(img_wrt[i, :, :], i + 1) + + return filename + + +def get_meta(fname): + with rasterio.open(fname, "r") as src: + meta = src.meta + + return meta + + +def inference_segmentor(model, imgs, custom_test_pipeline=None): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. 
+ imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = ( + [LoadImageFromFile()] + cfg.data.test.pipeline[1:] + if custom_test_pipeline == None + else custom_test_pipeline + ) + test_pipeline = Compose(test_pipeline) + # prepare data + data = [] + imgs = imgs if isinstance(imgs, list) else [imgs] + for img in imgs: + img_data = {"img_info": {"filename": img}} + img_data = test_pipeline(img_data) + data.append(img_data) + # print(data.shape) + + data = collate(data, samples_per_gpu=len(imgs)) + if next(model.parameters()).is_cuda: + # data = collate(data, samples_per_gpu=len(imgs)) + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + # img_metas = scatter(data['img_metas'],'cpu') + # data['img_metas'] = [i.data[0] for i in data['img_metas']] + + img_metas = data["img_metas"].data[0] + img = data["img"] + data = {"img": img, "img_metas": img_metas} + + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def inference_on_file(model, target_image, output_image, custom_test_pipeline): + time_taken = -1 + try: + st = time.time() + print("Running inference...") + result = inference_segmentor(model, target_image, custom_test_pipeline) + print("Output has shape: " + str(result[0].shape)) + + ##### get metadata mask + mask = open_tiff(target_image) + meta = get_meta(target_image) + mask = np.where(mask == meta["nodata"], 1, 0) + mask = np.max(mask, axis=0)[None] + + result[0] = np.where(mask == 1, -1, result[0]) + + ##### Save file to disk + meta["count"] = 1 + meta["dtype"] = "int16" + meta["compress"] = "lzw" + meta["nodata"] = -1 + print("Saving output...") + write_tiff(result[0], output_image, meta) + et = time.time() + time_taken = np.round(et - st, 1) + print( + f"Inference completed in {str(time_taken)} seconds. Output available at: " + + output_image + ) + + except: + print(f"Error on image {target_image} \nContinue to next input") + + return time_taken + + +def process_test_pipeline(custom_test_pipeline, bands=None): + # change extracted bands if necessary + if bands is not None: + extract_index = [ + i for i, x in enumerate(custom_test_pipeline) if x["type"] == "BandsExtract" + ] + + if len(extract_index) > 0: + custom_test_pipeline[extract_index[0]]["bands"] = bands + + collect_index = [ + i for i, x in enumerate(custom_test_pipeline) if x["type"].find("Collect") > -1 + ] + + # adapt collected keys if necessary + if len(collect_index) > 0: + keys = [ + "img_info", + "filename", + "ori_filename", + "img", + "img_shape", + "ori_shape", + "pad_shape", + "scale_factor", + "img_norm_cfg", + ] + custom_test_pipeline[collect_index[0]]["meta_keys"] = keys + + return custom_test_pipeline + + +def inference_on_files( + config_path, ckpt, input_type, input_path, output_path, bands, device +): + # load model + config = Config.fromfile(config_path) + config.model.backbone.pretrained = None + model = init_segmentor(config, ckpt, device) + + # identify images to predict on + target_images = glob.glob(os.path.join(input_path, "*." 
+ input_type)) + + print("Identified images to predict on: " + str(len(target_images))) + + # check if output folder available + if not os.path.isdir(output_path): + os.mkdir(output_path) + + # modify test pipeline if necessary + custom_test_pipeline = process_test_pipeline(model.cfg.data.test.pipeline, bands) + + # for each image predict and save to disk + for i, target_image in enumerate(target_images): + print(f"Working on Image {i}") + output_image = os.path.join( + output_path, + target_image.split("/")[-1].replace( + "." + input_type, "_pred." + input_type + ), + ) + + inference_on_file(model, target_image, output_image, custom_test_pipeline) + + +def main(): + # unpack args + args = parse_args() + config_path = args.config + ckpt = args.ckpt + input_type = args.input_type + input_path = args.input + output_path = args.output + bands = args.bands + device = args.device + + inference_on_files( + config_path, ckpt, input_type, input_path, output_path, bands, device + ) + + +if __name__ == "__main__": + main() diff --git a/prithvi_geospatial_extractor/prithvi_finetuned_extractor.py b/prithvi_geospatial_extractor/prithvi_finetuned_extractor.py new file mode 100644 index 0000000..e628ec6 --- /dev/null +++ b/prithvi_geospatial_extractor/prithvi_finetuned_extractor.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +"""Prithvi fine-tuned extractor based on the example extractor clowder code.""" + +import logging +import json +import subprocess +from typing import Dict + +from pyclowder.utils import CheckMessage +from pyclowder.extractors import Extractor +import pyclowder.files + +from prithvi_finetuned_model import PrithviFineTunedModel + +class PrithviFineTunedExtractor(Extractor): + """Prithvi Fine-Tuned Extractor""" + + def __init__(self): + Extractor.__init__(self) + # parse command line and load default logging configuration + self.setup() + + # setup logging for the extractor + logging.getLogger('pyclowder').setLevel(logging.DEBUG) + logging.getLogger('__main__').setLevel(logging.DEBUG) + + + def process_message(self, connector, host, secret_key, resource, parameters): + + input_file = resource["local_paths"][0] + file_name = resource['name'] + dataset_id = resource['parent']['id'] + + # Load user-defined params from the GUI. 
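+        # Per the parameter schema in extractor_info.json, two user-set values can arrive here:
+        #   APPLICATION_TYPE: "flood_mapping", "burn_scars", or "cover_crop" (default "flood_mapping")
+        #   SAVE_IMAGE: "True" or "False" (default "True"); when true, a *_masked.png overlay is also uploaded below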
+ APPLICATION_TYPE = '' + SAVE_IMAGE = False + if 'parameters' in parameters: + params = None + try: + params = json.loads(parameters['parameters']) + except TypeError as e: + print(f"Failed to load parameters, it's not compatible with json.loads().\nError:{e}") + if type(parameters == Dict): + params = parameters['parameters'] + + if 'APPLICATION_TYPE' in params: + APPLICATION_TYPE = params['APPLICATION_TYPE'] + print(f"Received APPLICATION_TYPE: {APPLICATION_TYPE}") + + if 'SAVE_IMAGE' in params: + SAVE_IMAGE = params['SAVE_IMAGE'] + print(f"Received SAVE_IMAGE: {SAVE_IMAGE}") + + # Load tif file + connector.message_process(resource, "Loading contents of file...") + + # PREDICT + model = PrithviFineTunedModel() + model.get_model(APPLICATION_TYPE) + output_file = file_name.replace(".tif", "_pred.tif") + model.inference(input_file, output_file, save_image=SAVE_IMAGE) + print("Inference successful") + + # Upload predicted tiff file to Clowder dataset as a new file + # if save_image is true upload the combined image as well + connector.message_process(resource, "Uploading predicted tiff file...") + pyclowder.files.upload_to_dataset(connector, host, secret_key, dataset_id, output_file) + if SAVE_IMAGE: + combined_image = file_name.replace(".tif", "_masked.png") + pyclowder.files.upload_to_dataset(connector, host, secret_key, dataset_id, combined_image) + + +if __name__ == "__main__": + extractor = PrithviFineTunedExtractor() + extractor.start() \ No newline at end of file diff --git a/prithvi_geospatial_extractor/prithvi_finetuned_model.py b/prithvi_geospatial_extractor/prithvi_finetuned_model.py new file mode 100644 index 0000000..f182d0a --- /dev/null +++ b/prithvi_geospatial_extractor/prithvi_finetuned_model.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +"""Parallel Machine Learning Extractor""" +from mmcv import Config + +from mmseg.apis import init_segmentor +from model_inference import inference_segmentor, process_test_pipeline, inference_on_file +from huggingface_hub import hf_hub_download +from viz_helpers import load_raster, enhance_raster_for_visualization + +import matplotlib +matplotlib.use('Agg') # Since we are not using a GUI +import matplotlib.pyplot as plt +import numpy as np + + +class PrithviFineTunedModel: + """Prithvi Fine-Tuned Model""" + + def __init__(self): + self.model = None + + + def get_model(self,application_name): + """This function takes in the application name and assigns the finetuned model for that application to the object's model. + There are 3 applications available: Flood Mapping, Burn Scars detection and Multi-temporal-crop classification. + These models are available on Hugging Face Hub and are downloaded using the hf_hub_download function. + Args: + application_name (str): The name of the application for which the model is required. + """ + + if application_name == "flood_mapping": + repo_id = "ibm-nasa-geospatial/Prithvi-100M-sen1floods11" + config_filename = "sen1floods11_Prithvi_100M.py" + ckpt_filename = "sen1floods11_Prithvi_100M.pth" + elif application_name == "burn_scars": + repo_id = "ibm-nasa-geospatial/Prithvi-100M-burn-scar" + config_filename = "burn_scars_Prithvi_100M.py" + ckpt_filename = "burn_scars_Prithvi_100M.pth" + elif application_name == "cover_crop": + repo_id = "ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification" + config_filename = "multi_temporal_crop_classification_Prithvi_100M.py" + ckpt_filename = "multi_temporal_crop_classification_Prithvi_100M.pth" + else: + raise ValueError("Invalid application name. 
Please choose from flood_mapping, burn_scars or cover_crop") + + config_path = hf_hub_download(repo_id=repo_id, filename=config_filename) + ckpt = hf_hub_download(repo_id=repo_id, filename=ckpt_filename) + self.model = init_segmentor(Config.fromfile(config_path), ckpt, device="cpu") + + def inference(self, input_data, output_file_name, save_image = False): + """This function takes in the input data and mask image and save the inferred TIFF image and + if set to true, the combined image is saved as well. + Args: + input_data (np.ndarray): The input data + output_file_name (str): The name of the output file + save_metadata_image (bool): A flag which would return a image to save as metadata if set to true + + """ + if self.model is None: + raise ValueError("Model not found. Please load the model using get_model function.") + + custom_test_pipeline = process_test_pipeline(self.model.cfg.data.test.pipeline) + inference_on_file(self.model, input_data, output_file_name, custom_test_pipeline ) + if save_image: + input_data_inference = enhance_raster_for_visualization(load_raster(input_data)) + output_data_inference = enhance_raster_for_visualization(load_raster(output_file_name)) + norm = matplotlib.colors.Normalize(vmin=0, vmax=2) + # Combine the input and output images and save as metadata + fig, ax = plt.subplots() + ax.imshow(input_data_inference) + ax.imshow(output_data_inference, cmap="jet", alpha=0.3, norm=norm) + ax.axis('off') + fig.savefig(output_file_name.replace("_pred.tif", "_masked.png"), bbox_inches='tight', + pad_inches=0, transparent=True) + + +# Test the model +if __name__ == "__main__": + model = PrithviFineTunedModel() + model.get_model("flood_mapping") + model.inference("Spain_7370579_S2Hand.tif", "output.tif", save_metadata_image=True) + print("Inference successful") \ No newline at end of file diff --git a/prithvi_geospatial_extractor/setup.py b/prithvi_geospatial_extractor/setup.py new file mode 100644 index 0000000..06f8d65 --- /dev/null +++ b/prithvi_geospatial_extractor/setup.py @@ -0,0 +1,22 @@ +# Source - https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/setup.py +from setuptools import setup + +setup( + name="geospatial_fm", + version="0.1.0", + description="MMSegmentation classes for geospatial-fm finetuning", + author="Paolo Fraccaro, Carlos Gomes, Johannes Jakubik", + packages=["geospatial_fm"], + license="Apache 2", + install_requires=[ + "mmsegmentation @ git+https://github.com/open-mmlab/mmsegmentation.git@186572a3ce64ac9b6b37e66d58c76515000c3280", + "rasterio", + "rioxarray", + "einops", + "timm==0.4.12", + "tensorboard", + "imagecodecs", + "yapf==0.40.1", + "pyclowder==3.0.7" + ], +) diff --git a/prithvi_geospatial_extractor/viz_helpers.py b/prithvi_geospatial_extractor/viz_helpers.py new file mode 100644 index 0000000..1f2d4c7 --- /dev/null +++ b/prithvi_geospatial_extractor/viz_helpers.py @@ -0,0 +1,61 @@ +# Description: Helper functions for visualizing the raster data + + +import matplotlib.pyplot as plt +import numpy as np +import rasterio + + +NO_DATA = -9999 +NO_DATA_FLOAT = 0.0001 +PERCENTILES = (0.1, 99.9) + + +def load_raster(path, crop=None): + with rasterio.open(path) as src: + img = src.read() + + # load first 6 bands + img = img[:6] + + img = np.where(img == NO_DATA, NO_DATA_FLOAT, img) + if crop: + img = img[:, -crop[0]:, -crop[1]:] + return img + +def enhance_raster_for_visualization(raster, ref_img=None): + if ref_img is None: + ref_img = raster + channels = [] + for channel in range(raster.shape[0]): + valid_mask = 
np.ones_like(ref_img[channel], dtype=bool) + valid_mask[ref_img[channel] == NO_DATA_FLOAT] = False + mins, maxs = np.percentile(ref_img[channel][valid_mask], PERCENTILES) + normalized_raster = (raster[channel] - mins) / (maxs - mins) + normalized_raster[~valid_mask] = 0 + clipped = np.clip(normalized_raster, 0, 1) + channels.append(clipped) + clipped = np.stack(channels) + channels_last = np.moveaxis(clipped, 0, -1)[..., :3] + rgb = channels_last[..., ::-1] + return rgb + +def plot_image_mask_reconstruction(normalized, mask_img, pred_img): + # Mix visible and predicted patches + rec_img = normalized.clone() + rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove + + mask_img_np = mask_img.numpy().reshape(6, 224, 224).transpose((1, 2, 0))[..., :3] + + rec_img_np = (rec_img.numpy().reshape(6, 224, 224) * stds) + means + + fig, ax = plt.subplots(1, 3, figsize=(15, 6)) + + for subplot in ax: + subplot.axis('off') + + ax[0].imshow(enhance_raster_for_visualization(input_data)) + masked_img_np = enhance_raster_for_visualization(input_data).copy() + masked_img_np[mask_img_np[..., 0] == 1] = 0 + ax[1].imshow(masked_img_np) + ax[2].imshow(enhance_raster_for_visualization(rec_img_np, ref_img=input_data)) \ No newline at end of file diff --git a/ray_workflows/requirements.txt b/ray_workflows/requirements.txt index 16994e9..1b20b51 100644 --- a/ray_workflows/requirements.txt +++ b/ray_workflows/requirements.txt @@ -1,3 +1,3 @@ -pyclowder==2.4.0 +pyclowder==3.0.7 codeflare scikit-learn \ No newline at end of file diff --git a/template_for_custom_parallel_batch_extractors/requirements.txt b/template_for_custom_parallel_batch_extractors/requirements.txt index 06653d2..768ca22 100644 --- a/template_for_custom_parallel_batch_extractors/requirements.txt +++ b/template_for_custom_parallel_batch_extractors/requirements.txt @@ -1,4 +1,4 @@ -pyclowder==2.6.0 +pyclowder==3.0.7 numpy ray[default]==1.13.0 keras