Custom Preprocessors #2750
-
Hello. I have created a custom pre-processor, and I am curious how I should go about using it with an existing model. Currently, if I just pass it into the model with the |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 9 replies
-
@lucianchauvin, passing it to the model should ideally be the way. Can you provide minimal reproducible example? |
Beta Was this translation helpful? Give feedback.
-
Hello, this works just fine for me; maybe it'll help you: from anomalib.engine import Engine
from anomalib.models import Padim
from anomalib.data import MVTecAD
from anomalib.pre_processing import PreProcessor
from torchvision.transforms.v2 import RandomVerticalFlip, Compose
class DefaultProcessor(PreProcessor):
    """Pre-processor that applies ``transform`` to ``batch.image`` at the start
    of every train / validation / test / predict batch.

    NOTE(review): the base ``PreProcessor`` may already apply its transform via
    its own hooks — confirm this does not double-apply the transform.
    """

    def __init__(self, transform):
        super().__init__(transform=transform)

    def _transform_batch(self, batch):
        # Shared helper: replace the batch image with its transformed version.
        batch.image = self.transform(batch.image)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self._transform_batch(batch)

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        # Lightning's validation hook is ``on_validation_batch_start`` (and it
        # receives ``dataloader_idx``); the original ``on_val_batch_start`` is
        # not a Lightning Callback hook, so validation batches were never
        # transformed.
        self._transform_batch(batch)

    def on_val_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Backward-compatible alias for any direct callers of the original
        # (non-Lightning) method name; Lightning itself never invokes this.
        self._transform_batch(batch)

    def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        # ``dataloader_idx`` added to match the Lightning Callback signature.
        self._transform_batch(batch)

    def on_predict_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        # ``dataloader_idx`` added to match the Lightning Callback signature.
        self._transform_batch(batch)
# Build the MVTecAD datamodule for the "bottle" category.
datamodule = MVTecAD(
    category="bottle",       # MVTec category to use
    train_batch_size=32,     # images per training batch
    eval_batch_size=32,      # images per validation/test batch
    num_workers=8,           # parallel data-loading workers
)
datamodule.prepare_data()
datamodule.setup()

# Peek at the first training batch to confirm the image tensor shape.
first_batch = next(iter(datamodule.train_dataloader()))
print("Batch Image Shape", first_batch.image.shape)

# Prepend a deterministic vertical flip to Padim's default transform pipeline.
flip = RandomVerticalFlip(p=1.0)
default_transform = Padim.configure_pre_processor().transform
pipeline = Compose(
    [flip, *default_transform.transforms],
)

# Train Padim with the custom pre-processor, then evaluate the best checkpoint.
preprocessor = DefaultProcessor(transform=pipeline)
model = Padim(pre_processor=preprocessor)
engine = Engine()
engine.fit(model=model, datamodule=datamodule)
test_results = engine.test(
    model=model,
    datamodule=datamodule,
    ckpt_path=engine.trainer.checkpoint_callback.best_model_path,
)
You are also overwriting the model's pre-processor but just keeping the same transforms. You could in theory |
Beta Was this translation helpful? Give feedback.
It's unorthodox, but this works for me 😄: