Skip to content

Commit f9417e3

Browse files
authored
Fix PadToSize impl to follow Transform API after torchvision 0.21 (#629)
1 parent abe5441 commit f9417e3

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

rtdetr_pytorch/src/data/transforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
8181
self.padding = [0, 0, w, h]
8282
return dict(padding=self.padding)
8383

84+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
85+
return self._get_params(flat_inputs)
86+
8487
def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None:
8588
if isinstance(spatial_size, int):
8689
spatial_size = (spatial_size, spatial_size)
@@ -93,6 +96,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
9396
padding = params['padding']
9497
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
9598

99+
def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
100+
return self._transform(inpt, params)
101+
96102
def __call__(self, *inputs: Any) -> Any:
97103
outputs = super().forward(*inputs)
98104
if len(outputs) > 1 and isinstance(outputs[1], dict):

rtdetrv2_pytorch/src/data/transforms/_transforms.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
5959
self.padding = [0, 0, w, h]
6060
return dict(padding=self.padding)
6161

62+
def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
63+
return self._get_params(flat_inputs)
64+
6265
def __init__(self, size, fill=0, padding_mode='constant') -> None:
6366
if isinstance(size, int):
6467
size = (size, size)
@@ -70,6 +73,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
7073
padding = params['padding']
7174
return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type]
7275

76+
def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
77+
return self._transform(inpt, params)
78+
7379
def __call__(self, *inputs: Any) -> Any:
7480
outputs = super().forward(*inputs)
7581
if len(outputs) > 1 and isinstance(outputs[1], dict):
@@ -139,4 +145,4 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
139145
return inpt
140146

141147
def transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
142-
return self._transform(inpt, params)
148+
return self._transform(inpt, params)

0 commit comments

Comments
 (0)