@@ -1470,6 +1470,98 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14701470 yield opinfo_core .SampleInput (make_inp (shape ), args = (pad ,))
14711471
14721472
1473+ def sample_inputs_roi_align (op_info , device , dtype , requires_grad , ** kwargs ):
1474+ del op_info
1475+ del kwargs
1476+ # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
1477+
1478+ # Test 1: spatial_scale=1, sampling_ratio=2
1479+ x1 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1480+ roi1 = torch .tensor ([[0 , 1.5 , 1.5 , 3 , 3 ]], dtype = dtype , device = device )
1481+ yield opinfo_core .SampleInput (
1482+ x1 ,
1483+ args = (roi1 , (5 , 5 )),
1484+ kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : True },
1485+ )
1486+
1487+ # Test 2: spatial_scale=0.5, sampling_ratio=3
1488+ x2 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1489+ roi2 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1490+ yield opinfo_core .SampleInput (
1491+ x2 ,
1492+ args = (roi2 , (5 , 5 )),
1493+ kwargs = {"spatial_scale" : 0.5 , "sampling_ratio" : 3 , "aligned" : True },
1494+ )
1495+
1496+ # Test 3: spatial_scale=1.8, sampling_ratio=2
1497+ x3 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1498+ roi3 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1499+ yield opinfo_core .SampleInput (
1500+ x3 ,
1501+ args = (roi3 , (5 , 5 )),
1502+ kwargs = {"spatial_scale" : 1.8 , "sampling_ratio" : 2 , "aligned" : True },
1503+ )
1504+
1505+ # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
1506+ x4 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1507+ roi4 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1508+ yield opinfo_core .SampleInput (
1509+ x4 ,
1510+ args = (roi4 , (2 , 2 )),
1511+ kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : 0 , "aligned" : True },
1512+ )
1513+
1514+ # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
1515+ x5 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1516+ roi5 = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
1517+ yield opinfo_core .SampleInput (
1518+ x5 ,
1519+ args = (roi5 , (2 , 2 )),
1520+ kwargs = {"spatial_scale" : 2.5 , "sampling_ratio" : - 1 , "aligned" : True },
1521+ )
1522+
1523+ # Test 6: malformed boxes (test_roi_align_malformed_boxes)
1524+ x6 = torch .randn (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1525+ roi6 = torch .tensor ([[0 , 2 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device )
1526+ yield opinfo_core .SampleInput (
1527+ x6 ,
1528+ args = (roi6 , (5 , 5 )),
1529+ kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 1 , "aligned" : True },
1530+ )
1531+
1532+ # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
1533+ x7 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1534+ roi7 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1535+ yield opinfo_core .SampleInput (
1536+ x7 ,
1537+ args = (roi7 , (5 , 5 )),
1538+ kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : 2 , "aligned" : False },
1539+ )
1540+
1541+ # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
1542+ x8 = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1543+ roi8 = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1544+ yield opinfo_core .SampleInput (
1545+ x8 ,
1546+ args = (roi8 , (5 , 5 )),
1547+ kwargs = {"spatial_scale" : 1.0 , "sampling_ratio" : - 1 , "aligned" : False },
1548+ )
1549+
1550+
1551+ def sample_inputs_roi_pool (op_info , device , dtype , requires_grad , ** kwargs ):
1552+ del op_info
1553+ del kwargs
1554+ # roi_pool signature: (input, boxes, output_size, spatial_scale=1.0)
1555+
1556+ x = torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1557+ rois = torch .tensor ([[0 , 0 , 0 , 4 , 4 ]], dtype = dtype , device = device )
1558+ yield opinfo_core .SampleInput (
1559+ x ,
1560+ args = (rois , (5 , 5 )),
1561+ kwargs = {"spatial_scale" : 2.0 },
1562+ )
1563+
1564+
14731565def sample_inputs_slice_scatter (op_info , device , dtype , requires_grad , ** kwargs ):
14741566 del op_info
14751567 del kwargs
@@ -3038,4 +3130,18 @@ def __init__(self):
30383130 sample_inputs_func = sample_inputs_non_max_suppression ,
30393131 supports_out = False ,
30403132 ),
3133+ opinfo_core .OpInfo (
3134+ "torchvision.ops.roi_align" ,
3135+ op = torchvision .ops .roi_align ,
3136+ dtypes = common_dtype .floating_types (),
3137+ sample_inputs_func = sample_inputs_roi_align ,
3138+ supports_out = False ,
3139+ ),
3140+ opinfo_core .OpInfo (
3141+ "torchvision.ops.roi_pool" ,
3142+ op = torchvision .ops .roi_pool ,
3143+ dtypes = common_dtype .floating_types (),
3144+ sample_inputs_func = sample_inputs_roi_pool ,
3145+ supports_out = False ,
3146+ ),
30413147]
0 commit comments