diff --git a/kornia/augmentation/_2d/geometric/thin_plate_spline.py b/kornia/augmentation/_2d/geometric/thin_plate_spline.py index d21ee534d9f..ce63e59cc45 100644 --- a/kornia/augmentation/_2d/geometric/thin_plate_spline.py +++ b/kornia/augmentation/_2d/geometric/thin_plate_spline.py @@ -70,6 +70,11 @@ def __init__( "padding_mode": SamplePadding.get(padding_mode), } self.dist = torch.distributions.Uniform(-scale, scale) + # Pre-create the source control points template + self._src_template = torch.tensor( + [[[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.0, 0.0]]], + dtype=torch.float32, + ) def generate_parameters(self, shape: Tuple[int, ...]) -> Dict[str, torch.Tensor]: B, _, _, _ = shape @@ -78,11 +83,7 @@ def generate_parameters(self, shape: Tuple[int, ...]) -> Dict[str, torch.Tensor] dtype = self.dtype # 5 TPS control points in normalized coordinates - src = torch.tensor( - [[[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0], [0.0, 0.0]]], - device=device, - dtype=dtype, - ).expand(B, 5, 2) + src = self._src_template.to(device=device, dtype=dtype).expand(B, 5, 2) if self.same_on_batch: noise = self.dist.rsample((1, 5, 2)).to(device=device, dtype=dtype)