diff --git a/saicinpainting/training/data/aug.py b/saicinpainting/training/data/aug.py
index b1246250..63db575d 100644
--- a/saicinpainting/training/data/aug.py
+++ b/saicinpainting/training/data/aug.py
@@ -1,7 +1,7 @@
-from albumentations import DualIAATransform, to_tuple
-import imgaug.augmenters as iaa
+from albumentations import DualTransform, Affine, Perspective
+from albumentations.core.utils import to_tuple
 
 
-class IAAAffine2(DualIAATransform):
+class IAAAffine2(DualTransform):
     """Place a regular grid of points on the input and randomly move the neighbourhood of these point around
     via affine transformations.
@@ -39,7 +39,7 @@ def __init__(
 
     @property
     def processor(self):
-        return iaa.Affine(
+        return Affine(
             self.scale,
             self.translate_percent,
             self.translate_px,
@@ -54,7 +54,7 @@ def get_transform_init_args_names(self):
         return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
 
 
-class IAAPerspective2(DualIAATransform):
+class IAAPerspective2(DualTransform):
     """Perform a random four point perspective transform of the input.
 
     Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
@@ -78,7 +78,7 @@ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,
 
     @property
     def processor(self):
-        return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
+        return Perspective(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
 
     def get_transform_init_args_names(self):
         return ("scale", "keep_size")
diff --git a/saicinpainting/training/trainers/__init__.py b/saicinpainting/training/trainers/__init__.py
index c59241f5..5f604f49 100644
--- a/saicinpainting/training/trainers/__init__.py
+++ b/saicinpainting/training/trainers/__init__.py
@@ -24,7 +24,7 @@ def make_training_model(config):
 
 def load_checkpoint(train_config, path, map_location='cuda', strict=True):
     model: torch.nn.Module = make_training_model(train_config)
-    state = torch.load(path, map_location=map_location)
+    state = torch.load(path, map_location=map_location, weights_only=False)
     model.load_state_dict(state['state_dict'], strict=strict)
     model.on_load_checkpoint(state)
     return model
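
Note on the patch (not part of the diff): the first file swaps the removed imgaug-backed DualIAATransform base for albumentations' native DualTransform, with Affine and Perspective replacing iaa.Affine and iaa.PerspectiveTransform. As a quick sanity check of that migration, the sketch below exercises the two albumentations transforms directly on an image/mask pair; it assumes an albumentations version that ships Affine and Perspective (>= 1.0) and is illustrative only.

    # Minimal sketch: albumentations-native Affine and Perspective, the classes
    # that now back IAAAffine2/IAAPerspective2, applied to an image+mask pair.
    import numpy as np
    import albumentations as A

    image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    mask = (np.random.rand(64, 64) > 0.5).astype(np.uint8)

    transform = A.Compose([
        A.Affine(scale=(0.7, 1.3), rotate=(-40, 40), shear=(-0.1, 0.1), p=1.0),
        A.Perspective(scale=(0.05, 0.1), keep_size=True, p=1.0),
    ])
    result = transform(image=image, mask=mask)  # same geometry applied to both targets

In the second file, weights_only=False restores the pre-2.6 torch.load behaviour: PyTorch 2.6 flipped the default to weights_only=True, which refuses to unpickle arbitrary Python objects and therefore likely fails on full training checkpoints of this kind; passing it explicitly keeps checkpoint loading working, at the usual cost of trusting the checkpoint source.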