Skip to content

Commit e80ebe8

Browse files
authored
[torchlib] Migrate torchvision implementations (#2569)
Adapted from https://github.com/pytorch/vision/blob/main/torchvision/ops/_register_onnx_ops.py --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 6e91205 commit e80ebe8

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

onnxscript/function_libs/torch_lib/ops/vision.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,87 @@
77

88
from __future__ import annotations
99

10+
import warnings
11+
from typing import Sequence
12+
1013
from onnxscript.function_libs.torch_lib.registration import torch_op
1114
from onnxscript.onnx_opset import opset18 as op
1215
from onnxscript.onnx_types import FLOAT, INT64
1316

1417
_INT64_MAX = 0x7FFFFFFFFFFFFFFF
1518

1619

17-
@torch_op("torchvision::nms", trace_only=True)
def torchvision_nms(boxes: FLOAT, scores: FLOAT, iou_threshold: float) -> INT64:
    """nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor"""
    # ONNX NonMaxSuppression expects batched, per-class inputs, so insert the
    # missing leading axes: boxes -> [1, num_boxes, 4], scores -> [1, 1, num_boxes].
    batched_boxes = op.Unsqueeze(boxes, [0])
    batched_scores = op.Unsqueeze(scores, [0, 1])
    # selected: [num_selected_indices, 3], rows of (batch_index, class_index, box_index)
    selected = op.NonMaxSuppression(batched_boxes, batched_scores, _INT64_MAX, iou_threshold)
    # Keep only the box_index column and flatten it to a 1-D INT64 tensor.
    box_indices = op.Slice(selected, axes=[1], starts=[2], ends=[3])
    return op.Reshape(box_indices, [-1])
30+
31+
32+
def _process_batch_indices_for_roi_align(rois):
    """Return the batch-index column (column 0) of *rois* as a 1-D INT64 tensor."""
    batch_column = op.Slice(rois, axes=[1], starts=[0], ends=[1])
    flattened = op.Squeeze(batch_column, axes=[1])
    # RoiAlign requires integer batch indices.
    return op.Cast(flattened, to=INT64.dtype)
37+
38+
39+
def _process_rois_for_roi_align(rois):
    """Drop the leading batch-index column of *rois*, keeping (x1, y1, x2, y2)."""
    return op.Slice(rois, axes=[1], starts=[1], ends=[5])
42+
43+
44+
def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
45+
if sampling_ratio < 0:
46+
warnings.warn(
47+
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
48+
"The model will be exported with a sampling_ratio of 0.",
49+
stacklevel=2,
50+
)
51+
sampling_ratio = 0
52+
return sampling_ratio
53+
54+
55+
@torch_op("torchvision::roi_align", trace_only=True)
def torchvision_roi_align(
    input,
    boxes,
    output_size: Sequence[int],
    spatial_scale: float = 1.0,
    sampling_ratio: int = -1,
    aligned: bool = False,
):
    """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
    out_height, out_width = output_size
    batch_indices = _process_batch_indices_for_roi_align(boxes)
    coords = _process_rois_for_roi_align(boxes)
    # torchvision's `aligned=True` corresponds to ONNX's "half_pixel" transform.
    transform_mode = "half_pixel" if aligned else "output_half_pixel"
    effective_sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)

    return op.RoiAlign(
        input,
        coords,
        batch_indices,
        coordinate_transformation_mode=transform_mode,
        spatial_scale=spatial_scale,
        output_height=out_height,
        output_width=out_width,
        sampling_ratio=effective_sampling_ratio,
    )
81+
82+
83+
@torch_op("torchvision::roi_pool", trace_only=True)
def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0):
    """roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor"""
    out_height, out_width = output_size
    # MaxRoiPool consumes box rows already formatted as [batch_index, x1, y1, x2, y2],
    # so `boxes` is passed through unchanged.
    return op.MaxRoiPool(
        input,
        boxes,
        pooled_shape=(out_height, out_width),
        spatial_scale=spatial_scale,
    )

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,6 +1470,98 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14701470
yield opinfo_core.SampleInput(make_inp(shape), args=(pad,))
14711471

14721472

1473+
def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs for torchvision.ops.roi_align.

    roi_align signature: (input, boxes, output_size, spatial_scale=1.0,
    sampling_ratio=-1, aligned=False). The cases cover assorted spatial scales,
    sampling ratios (including adaptive -1), both `aligned` modes, and one
    malformed box (x1 > x2) mirroring torchvision's
    test_roi_align_malformed_boxes.
    """
    del op_info
    del kwargs

    # (rois, output_size, spatial_scale, sampling_ratio, aligned, use_randn)
    cases = [
        ([[0, 1.5, 1.5, 3, 3]], (5, 5), 1.0, 2, True, False),
        ([[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 0.5, 3, True, False),
        ([[0, 0.2, 0.3, 4.5, 3.5]], (5, 5), 1.8, 2, True, False),
        ([[0, 0.2, 0.3, 4.5, 3.5]], (2, 2), 2.5, 0, True, False),
        ([[0, 0.2, 0.3, 4.5, 3.5]], (2, 2), 2.5, -1, True, False),
        # Malformed boxes (x1 > x2, y1 > y2); input drawn from randn like the
        # original test_roi_align_malformed_boxes case.
        ([[0, 2, 0.3, 1.5, 1.5]], (5, 5), 1.0, 1, True, True),
        ([[0, 0, 0, 4, 4]], (5, 5), 1.0, 2, False, False),
        ([[0, 0, 0, 4, 4]], (5, 5), 1.0, -1, False, False),
    ]
    for rois, output_size, spatial_scale, sampling_ratio, aligned, use_randn in cases:
        factory = torch.randn if use_randn else torch.rand
        feature_map = factory(
            1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
        )
        boxes = torch.tensor(rois, dtype=dtype, device=device)
        yield opinfo_core.SampleInput(
            feature_map,
            args=(boxes, output_size),
            kwargs={
                "spatial_scale": spatial_scale,
                "sampling_ratio": sampling_ratio,
                "aligned": aligned,
            },
        )
1549+
1550+
1551+
def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
    """Yield a SampleInput for torchvision.ops.roi_pool.

    roi_pool signature: (input, boxes, output_size, spatial_scale=1.0).
    """
    del op_info
    del kwargs

    feature_map = torch.rand(
        1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
    )
    boxes = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
    yield opinfo_core.SampleInput(
        feature_map,
        args=(boxes, (5, 5)),
        kwargs={"spatial_scale": 2.0},
    )
1563+
1564+
14731565
def sample_inputs_slice_scatter(op_info, device, dtype, requires_grad, **kwargs):
14741566
del op_info
14751567
del kwargs
@@ -3038,4 +3130,18 @@ def __init__(self):
30383130
sample_inputs_func=sample_inputs_non_max_suppression,
30393131
supports_out=False,
30403132
),
3133+
opinfo_core.OpInfo(
3134+
"torchvision.ops.roi_align",
3135+
op=torchvision.ops.roi_align,
3136+
dtypes=common_dtype.floating_types(),
3137+
sample_inputs_func=sample_inputs_roi_align,
3138+
supports_out=False,
3139+
),
3140+
opinfo_core.OpInfo(
3141+
"torchvision.ops.roi_pool",
3142+
op=torchvision.ops.roi_pool,
3143+
dtypes=common_dtype.floating_types(),
3144+
sample_inputs_func=sample_inputs_roi_pool,
3145+
supports_out=False,
3146+
),
30413147
]

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,8 @@ def _where_input_wrangler(
18721872
),
18731873
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
18741874
TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
1875+
TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align),
1876+
TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool),
18751877
)
18761878

18771879
ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims"))

0 commit comments

Comments
 (0)