Skip to content

Commit f5425ab

Browse files
pallgeuerzhouzaida
andauthored
Add torch_meshgrid wrapper due to PyTorch change (#2044)
* Add torch_meshgrid_ij wrapper due to PyTorch change * Update torch_meshgrid name/doc/version implementation * Make imports local * add ut * ignore ut when torch is not available Co-authored-by: zhouzaida <[email protected]>
1 parent b062468 commit f5425ab

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

.github/workflows/build.yml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,22 @@ jobs:
4545
- name: Run unittests and generate coverage report
4646
run: |
4747
pip install -r requirements/test.txt
48-
pytest tests/ --ignore=tests/test_runner --ignore=tests/test_device/test_ipu --ignore=tests/test_optimizer.py --ignore=tests/test_cnn --ignore=tests/test_parallel.py --ignore=tests/test_ops --ignore=tests/test_load_model_zoo.py --ignore=tests/test_utils/test_logging.py --ignore=tests/test_image/test_io.py --ignore=tests/test_utils/test_registry.py --ignore=tests/test_utils/test_parrots_jit.py --ignore=tests/test_utils/test_trace.py --ignore=tests/test_utils/test_hub.py --ignore=tests/test_device/test_mlu/test_mlu_parallel.py
48+
pytest tests/ \
49+
--ignore=tests/test_runner \
50+
--ignore=tests/test_device/test_ipu \
51+
--ignore=tests/test_optimizer.py \
52+
--ignore=tests/test_cnn \
53+
--ignore=tests/test_parallel.py \
54+
--ignore=tests/test_ops \
55+
--ignore=tests/test_load_model_zoo.py \
56+
--ignore=tests/test_utils/test_logging.py \
57+
--ignore=tests/test_image/test_io.py \
58+
--ignore=tests/test_utils/test_registry.py \
59+
--ignore=tests/test_utils/test_parrots_jit.py \
60+
--ignore=tests/test_utils/test_trace.py \
61+
--ignore=tests/test_utils/test_hub.py \
62+
--ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
63+
--ignore=tests/test_utils/test_torch_ops.py
4964
5065
build_without_ops:
5166
runs-on: ubuntu-18.04

mmcv/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
# yapf: enable
5454
from .registry import Registry, build_from_cfg
5555
from .seed import worker_init_fn
56+
from .torch_ops import torch_meshgrid
5657
from .trace import is_jit_tracing
5758
__all__ = [
5859
'Config', 'ConfigDict', 'DictAction', 'collect_env', 'get_logger',
@@ -74,5 +75,6 @@
7475
'assert_params_all_zeros', 'check_python_script',
7576
'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
7677
'_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
77-
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE'
78+
'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
79+
'torch_meshgrid'
7880
]

mmcv/utils/torch_ops.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
from .parrots_wrapper import TORCH_VERSION
5+
from .version_utils import digit_version
6+
7+
_torch_version_meshgrid_indexing = (
8+
'parrots' not in TORCH_VERSION
9+
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))
10+
11+
12+
def torch_meshgrid(*tensors):
13+
"""A wrapper of torch.meshgrid to compat different PyTorch versions.
14+
15+
Since PyTorch 1.10.0a0, torch.meshgrid supports the arguments ``indexing``.
16+
So we implement a wrapper here to avoid warning when using high-version
17+
PyTorch and avoid compatibility issues when using previous versions of
18+
PyTorch.
19+
20+
Args:
21+
tensors (List[Tensor]): List of scalars or 1 dimensional tensors.
22+
23+
Returns:
24+
Sequence[Tensor]: Sequence of meshgrid tensors.
25+
"""
26+
if _torch_version_meshgrid_indexing:
27+
return torch.meshgrid(*tensors, indexing='ij')
28+
else:
29+
return torch.meshgrid(*tensors) # Uses indexing='ij' by default

tests/test_utils/test_torch_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
import torch
4+
5+
from mmcv.utils import torch_meshgrid
6+
7+
8+
def test_torch_meshgrid():
9+
# torch_meshgrid should not throw warning
10+
with pytest.warns(None) as record:
11+
x = torch.tensor([1, 2, 3])
12+
y = torch.tensor([4, 5, 6])
13+
grid_x, grid_y = torch_meshgrid(x, y)
14+
15+
assert len(record) == 0

0 commit comments

Comments
 (0)