Hi @andife, thanks for the question! Here's an example showing how you can use custom transforms from albumentations with the semantic segmentation task. The key is that `train_per_sample_transform` is used only during training, while the unprefixed `per_sample_transform` is the default for the other stages, so the random augmentations are not applied during validation or prediction:

```python
from dataclasses import dataclass
from functools import partial
from typing import Tuple

import albumentations as A
import torch

import flash
from flash import InputTransform
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)


@dataclass
class CustomTransform(InputTransform):
    image_size: Tuple[int, int] = (300, 300)
    crop_size: Tuple[int, int] = (256, 256)

    def __post_init__(self):
        # Training pipeline: resize, random crop, and random augmentations.
        self.train_transform = A.Compose([
            A.Resize(width=self.image_size[0], height=self.image_size[1]),
            A.RandomCrop(width=self.crop_size[0], height=self.crop_size[1]),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ])
        # Validation / prediction pipeline: deterministic resize and center crop.
        self.transform = A.Compose([
            A.Resize(width=self.image_size[0], height=self.image_size[1]),
            A.CenterCrop(width=self.crop_size[0], height=self.crop_size[1]),
        ])
        super().__post_init__()

    @staticmethod
    def _apply_transform(transform, sample):
        # Albumentations expects channels-last numpy arrays, so convert from the
        # (C, H, W) tensors Flash provides and convert back again afterwards.
        if "target" in sample:
            kwargs = {"mask": sample["target"].numpy()}
        else:
            kwargs = {}
        transformed = transform(image=sample["input"].permute(1, 2, 0).numpy(), **kwargs)
        sample["input"] = torch.from_numpy(transformed["image"]).permute(2, 0, 1)
        if "mask" in transformed:
            sample["target"] = torch.from_numpy(transformed["mask"])
        return sample

    def per_sample_transform(self):
        return partial(self._apply_transform, self.transform)

    def train_per_sample_transform(self):
        return partial(self._apply_transform, self.train_transform)

    @staticmethod
    def _prepare_target(target) -> torch.Tensor:
        """Convert the target mask to long and remove the channel dimension."""
        return target.long().squeeze(1)

    def target_per_batch_transform(self):
        return self._prepare_target


datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    train_transform=CustomTransform,
    val_transform=CustomTransform,
    transform_kwargs=dict(image_size=(300, 300)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count(), fast_dev_run=True)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
datamodule = SemanticSegmentationData.from_files(
    predict_transform=CustomTransform,
    predict_files=[
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")
```

Hope that helps 😃
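If you later want to run prediction from the saved checkpoint (for example in a separate script), something along these lines should work. This is just a sketch: `load_from_checkpoint` is the standard Lightning classmethod inherited by Flash tasks, and it assumes the `CustomTransform` class above is in scope (e.g. imported from your own module):

```python
import flash
from flash.image import SemanticSegmentation, SemanticSegmentationData

# Restore the finetuned task from the checkpoint saved above.
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")

# Reuse the same custom transform for prediction-time preprocessing.
# NOTE: assumes CustomTransform (defined above) is importable here.
datamodule = SemanticSegmentationData.from_files(
    predict_transform=CustomTransform,
    predict_files=["data/CameraRGB/F61-1.png"],
    batch_size=1,
)

trainer = flash.Trainer()
predictions = trainer.predict(model, datamodule=datamodule)
```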
Hello,
I'm trying to use a more advanced/specific image augmentation setup for a pytorch-flash task.
As a starting point, I looked at the following examples and combined them:
#1107
https://lightning-flash.readthedocs.io/en/stable/reference/semantic_segmentation.html
The error I get is, as I understand it, because icevision is not implemented for the semantic segmentation task. Since it seems to be implemented for instance_segmentation, shouldn't the code be transferable to the semantic segmentation functionality?
Another option I tried was https://github.com/PyTorchLightning/lightning-flash#flash-transforms, but I realized that segmentation/input_transform.py does not yet have the same functionality as flash.image.classification.input_transform.
What would be the next recommended steps?