From 43395bc212efceb80cdee6b3dbc3b2240e7400a3 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 11:35:43 -0700 Subject: [PATCH 1/6] Update pytorch360convert.py --- pytorch360convert/pytorch360convert.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 3ac8b97..0d3cf26 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -554,14 +554,15 @@ def cube_h2list(cube_h: torch.Tensor) -> List[torch.Tensor]: Args: cube_h (torch.Tensor): Horizontal cube representation tensor in the - shape of: [w, w*6, C]. + shape of: [w, w*6, C] or [B, w, w*6, C]. Returns: List[torch.Tensor]: List of cube face tensors in the order of: ['Front', 'Right', 'Back', 'Left', 'Up', 'Down'] """ - w = cube_h.shape[0] - return [cube_h[:, i * w : (i + 1) * w, :] for i in range(6)] + assert cube_h.dim() == 3 or cube_h.dim() == 4 + w = cube_h.shape[0] if cube_h.dim() == 3 else cube_h.shape[1] + return [cube_h[..., i * w: (i + 1) * w, :] for i in range(6)] def cube_list2h(cube_list: List[torch.Tensor]) -> torch.Tensor: From 9dfbd936f0fe0e9b88002650e91ce61e7fec824f Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 11:42:54 -0700 Subject: [PATCH 2/6] Update pytorch360convert.py --- pytorch360convert/pytorch360convert.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 0d3cf26..39c5444 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -579,6 +579,12 @@ def cube_list2h(cube_list: List[torch.Tensor]) -> torch.Tensor: assert all( cube.shape == cube_list[0].shape for cube in cube_list ), "All cube faces should have the same shape." + assert all( + cube.device == cube_list[0].device for cube in cube_list + ), "All cube faces should have the same device." + assert all( + cube.dtype == cube_list[0].dtype for cube in cube_list + ), "All cube faces should have the same dtype." return torch.cat(cube_list, dim=1) From 1e3353560341a7b95d18faa22e14cea1abf8bfb1 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 12:03:16 -0700 Subject: [PATCH 3/6] Update pytorch360convert.py --- pytorch360convert/pytorch360convert.py | 30 +++++++++----------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 39c5444..573e4b5 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -38,42 +38,32 @@ def rotation_matrix(rad: torch.Tensor, ax: torch.Tensor) -> torch.Tensor: return R -def _nhwc2nchw(x: torch.Tensor, channels_first: bool = True) -> torch.Tensor: +def _nhwc2nchw(x: torch.Tensor) -> torch.Tensor: """ Convert NHWC to NCHW or HWC to CHW format. Args: x (torch.Tensor): Input tensor to be converted, either in NCHW or CHW format. - channels_first (bool, optional): The channel format of e_img. PyTorch - uses channels first. Default: 'True' Returns: torch.Tensor: The converted tensor in NCHW or CHW format. """ - if x.dim() == 3: - x = x.permute(2, 0, 1) if channels_first else x - else: - x = x.permute(0, 3, 1, 2) if channels_first else x - return x + assert x.dim() == 3 or x.dim() == 4 + return x.permute(2, 0, 1) if x.dim() == 3 else x.permute(0, 3, 1, 2) -def _nchw2nhwc(x: torch.Tensor, channels_first: bool = True) -> torch.Tensor: +def _nchw2nhwc(x: torch.Tensor) -> torch.Tensor: """ Convert NCHW to NHWC or CHW to HWC format. Args: x (torch.Tensor): Input tensor to be converted, either in NCHW or CHW format. - channels_first (bool, optional): The channel format of e_img. PyTorch - uses channels first. Default: 'True' Returns: torch.Tensor: The converted tensor in NHWC or HWC format. """ - if x.dim() == 3: - x = x.permute(1, 2, 0) if channels_first else x - else: - x = x.permute(0, 2, 3, 1) if channels_first else x - return x + assert x.dim() == 3 or x.dim() == 4 + return x.permute(1, 2, 0) if x.dim() == 3 else x.permute(0, 2, 3, 1) def _slice_chunk( @@ -962,7 +952,7 @@ def e2p( assert e_img.dim() == 3 or e_img.dim() == 4 # Ensure input is in HWC format for processing - e_img = _nchw2nhwc(e_img) + e_img = _nchw2nhwc(e_img) if channels_first else e_img if e_img.dim() == 3: h, w = e_img.shape[:2] else: @@ -997,7 +987,7 @@ def e2p( pers_img = sample_equirec(e_img, coor_xy, mode) # Convert back to CHW if required - pers_img = _nhwc2nchw(pers_img, channels_first) + pers_img = _nhwc2nchw(pers_img) if channels_first else pers_img return pers_img @@ -1043,7 +1033,7 @@ def e2e( assert e_img.dim() == 3 or e_img.dim() == 4 # Ensure input is in HWC format for processing - e_img = _nchw2nhwc(e_img) + e_img = _nchw2nhwc(e_img) if channels_first else e_img if e_img.dim() == 3: h, w = e_img.shape[:2] else: @@ -1095,5 +1085,5 @@ def e2e( rotated = sample_equirec(e_img, coor_xy, mode=mode) # Return to original channel format if needed - rotated = _nhwc2nchw(rotated, channels_first) + rotated = _nhwc2nchw(rotated) if channels_first else rotated return rotated From 674291efe1b1baf98017829180cc6fb45c74029b Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 12:11:18 -0700 Subject: [PATCH 4/6] Update pytorch360convert.py --- pytorch360convert/pytorch360convert.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 573e4b5..94c27a6 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -752,13 +752,13 @@ def c2e( # Ensure input is in HWC format for processing if channels_first: if cube_format == "list" and isinstance(cubemap, (list, tuple)): - cubemap = [r.permute(1, 2, 0) for r in cubemap] + cubemap = [_nchw2nhwc(r) for r in cubemap] elif cube_format == "dict" and torch.jit.isinstance( cubemap, Dict[str, torch.Tensor] ): - cubemap = {k: v.permute(1, 2, 0) for k, v in cubemap.items()} # type: ignore + cubemap = {k: _nchw2nhwc(v) for k, v in cubemap.items()} # type: ignore elif cube_format in ["horizon", "dice"] and isinstance(cubemap, torch.Tensor): - cubemap = cubemap.permute(1, 2, 0) + cubemap = _nchw2nhwc(cubemap) else: raise NotImplementedError("unknown cube_format and cubemap type") @@ -826,7 +826,7 @@ def c2e( equirec = sample_cubefaces(cube_faces, tp, coor_y, coor_x, mode) # Convert back to CHW if required - equirec = equirec.permute(2, 0, 1) if channels_first else equirec + equirec = _nhwc2nchw(equirec) if channels_first else equirec return equirec @@ -872,7 +872,7 @@ def e2c( NotImplementedError: If an unknown cube_format is provided. """ assert len(e_img.shape) == 3 - e_img = e_img.permute(1, 2, 0) if channels_first else e_img + e_img = _nchw2nhwc(e_img) if channels_first else e_img h, w = e_img.shape[:2] # returns [face_w, face_w*6, 3] in order @@ -902,13 +902,13 @@ def e2c( if channels_first: if cube_format == "list" or cube_format == "stack": assert isinstance(result, (list, tuple)) - result = [r.permute(2, 0, 1) for r in result] + result = [_nhwc2nchw(r) for r in result] elif cube_format == "dict": assert torch.jit.isinstance(result, Dict[str, torch.Tensor]) - result = {k: v.permute(2, 0, 1) for k, v in result.items()} # type: ignore[union-attr] + result = {k: _nhwc2nchw(v) for k, v in result.items()} # type: ignore[union-attr] elif cube_format in ["horizon", "dice"]: assert isinstance(result, torch.Tensor) - result = result.permute(2, 0, 1) + result = _nhwc2nchw(result) if cube_format == "stack" and isinstance(result, (list, tuple)): result = torch.stack(result) return result From 5038d25ff782f99f207e611bfddcfde0f173d8a9 Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 12:12:54 -0700 Subject: [PATCH 5/6] Update test_module.py --- tests/test_module.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/test_module.py b/tests/test_module.py index f068475..6bece09 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -165,24 +165,14 @@ def test_rotation_matrix_90deg(self) -> None: def test_nhwc_to_nchw_channels_first(self) -> None: input_tensor = torch.rand(2, 3, 4, 5) - converted_tensor = _nhwc2nchw(input_tensor, channels_first=True) + converted_tensor = _nhwc2nchw(input_tensor) self.assertEqual(converted_tensor.shape, (2, 5, 3, 4)) - def test_nhwc_to_nchw_channels_last(self) -> None: - input_tensor = torch.rand(2, 3, 4, 5) - converted_tensor = _nhwc2nchw(input_tensor, channels_first=False) - self.assertEqual(converted_tensor.shape, (2, 3, 4, 5)) - def test_nchw_to_nhwc_channels_first(self) -> None: input_tensor = torch.rand(2, 5, 3, 4) - converted_tensor = _nchw2nhwc(input_tensor, channels_first=True) + converted_tensor = _nchw2nhwc(input_tensor) self.assertEqual(converted_tensor.shape, (2, 3, 4, 5)) - def test_nchw_to_nhwc_channels_last(self) -> None: - input_tensor = torch.rand(2, 5, 3, 4) - converted_tensor = _nchw2nhwc(input_tensor, channels_first=False) - self.assertEqual(converted_tensor.shape, (2, 5, 3, 4)) - def test_slice_chunk_default(self) -> None: index = 2 width = 3 From 95631aedd528be37439eedb9e8488ad95911d44c Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Thu, 9 Jan 2025 12:31:05 -0700 Subject: [PATCH 6/6] Update pytorch360convert.py --- pytorch360convert/pytorch360convert.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/pytorch360convert/pytorch360convert.py b/pytorch360convert/pytorch360convert.py index 94c27a6..8b157e8 100644 --- a/pytorch360convert/pytorch360convert.py +++ b/pytorch360convert/pytorch360convert.py @@ -740,13 +740,21 @@ def c2e( NotImplementedError: If an unknown cube_format is provided. """ - if cube_format == "stack": - assert ( - isinstance(cubemap, torch.Tensor) - and len(cubemap.shape) == 4 - and cubemap.shape[0] == 6 - ) - cubemap = [cubemap[i] for i in range(cubemap.shape[0])] + if cubemap[0].dim() == 4 or cubemap[0].dim() == 5: + if cubemap[0].dim() == 4: + assert ( + isinstance(cubemap, torch.Tensor) + and len(cubemap.shape) == 4 + and cubemap.shape[0] == 6 + ) + cubemap = [cubemap[i] for i in range(cubemap.shape[0])] + else: + assert ( + isinstance(cubemap, torch.Tensor) + and len(cubemap.shape) == 5 + and cubemap.shape[1] == 6 + ) + cubemap = [cubemap[:, i] for i in range(cubemap.shape[1])] cube_format = "list" # Ensure input is in HWC format for processing