diff --git a/rtdetr_pytorch/src/data/transforms.py b/rtdetr_pytorch/src/data/transforms.py index aab827541..13f469e0d 100644 --- a/rtdetr_pytorch/src/data/transforms.py +++ b/rtdetr_pytorch/src/data/transforms.py @@ -81,6 +81,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: self.padding = [0, 0, w, h] return dict(padding=self.padding) + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return self._get_params(flat_inputs) + def __init__(self, spatial_size, fill=0, padding_mode='constant') -> None: if isinstance(spatial_size, int): spatial_size = (spatial_size, spatial_size) @@ -93,6 +96,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: padding = params['padding'] return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._transform(inpt, params) + def __call__(self, *inputs: Any) -> Any: outputs = super().forward(*inputs) if len(outputs) > 1 and isinstance(outputs[1], dict): diff --git a/rtdetrv2_pytorch/src/data/transforms/_transforms.py b/rtdetrv2_pytorch/src/data/transforms/_transforms.py index 143d3f3b7..5758c9146 100644 --- a/rtdetrv2_pytorch/src/data/transforms/_transforms.py +++ b/rtdetrv2_pytorch/src/data/transforms/_transforms.py @@ -59,6 +59,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: self.padding = [0, 0, w, h] return dict(padding=self.padding) + def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + return self._get_params(flat_inputs) + def __init__(self, size, fill=0, padding_mode='constant') -> None: if isinstance(size, int): size = (size, size) @@ -70,6 +73,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: padding = params['padding'] return F.pad(inpt, padding=padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return self._transform(inpt, params) + def __call__(self, *inputs: Any) -> Any: outputs = super().forward(*inputs) if len(outputs) > 1 and isinstance(outputs[1], dict): @@ -139,4 +145,4 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt def transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return self._transform(inpt, params) \ No newline at end of file + return self._transform(inpt, params)