Skip to content

Commit 709f39b

Browse files
committed
Update
[ghstack-poisoned]
1 parent 2adbf61 commit 709f39b

File tree

6 files changed

+1178
-0
lines changed

6 files changed

+1178
-0
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List, Optional, Tuple, Union
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class AmaxModel(torch.nn.Module):
16+
def __init__(
17+
self,
18+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
19+
keepdim: bool = False
20+
):
21+
super().__init__()
22+
self.dim = dim
23+
self.keepdim = keepdim
24+
25+
def forward(self, x):
26+
return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
27+
28+
@operator_test
29+
class TestAmax(OperatorTest):
30+
@dtype_test
31+
def test_amax_dtype(self, dtype, tester_factory: Callable) -> None:
32+
# Test with different dtypes
33+
model = AmaxModel().to(dtype)
34+
self._test_op(model, (torch.rand(10, 10).to(dtype),), tester_factory)
35+
36+
def test_amax_basic(self, tester_factory: Callable) -> None:
37+
# Basic test with default parameters (global reduction)
38+
self._test_op(AmaxModel(), (torch.randn(10, 10),), tester_factory)
39+
40+
def test_amax_dim(self, tester_factory: Callable) -> None:
41+
# Test with different dimensions
42+
43+
# 2D tensor, dim=0
44+
self._test_op(AmaxModel(dim=0), (torch.randn(5, 10),), tester_factory)
45+
46+
# 2D tensor, dim=1
47+
self._test_op(AmaxModel(dim=1), (torch.randn(5, 10),), tester_factory)
48+
49+
# 3D tensor, dim=0
50+
self._test_op(AmaxModel(dim=0), (torch.randn(3, 4, 5),), tester_factory)
51+
52+
# 3D tensor, dim=1
53+
self._test_op(AmaxModel(dim=1), (torch.randn(3, 4, 5),), tester_factory)
54+
55+
# 3D tensor, dim=2
56+
self._test_op(AmaxModel(dim=2), (torch.randn(3, 4, 5),), tester_factory)
57+
58+
# 4D tensor, dim=1
59+
self._test_op(AmaxModel(dim=1), (torch.randn(2, 3, 4, 5),), tester_factory)
60+
61+
# Negative dim (last dimension)
62+
self._test_op(AmaxModel(dim=-1), (torch.randn(3, 4, 5),), tester_factory)
63+
64+
# Negative dim (second-to-last dimension)
65+
self._test_op(AmaxModel(dim=-2), (torch.randn(3, 4, 5),), tester_factory)
66+
67+
def test_amax_multi_dim(self, tester_factory: Callable) -> None:
68+
# Test with multiple dimensions
69+
70+
# 3D tensor, dim=(0, 1)
71+
self._test_op(AmaxModel(dim=(0, 1)), (torch.randn(3, 4, 5),), tester_factory)
72+
73+
# 3D tensor, dim=(0, 2)
74+
self._test_op(AmaxModel(dim=(0, 2)), (torch.randn(3, 4, 5),), tester_factory)
75+
76+
# 3D tensor, dim=(1, 2)
77+
self._test_op(AmaxModel(dim=(1, 2)), (torch.randn(3, 4, 5),), tester_factory)
78+
79+
# 4D tensor, dim=(1, 3)
80+
self._test_op(AmaxModel(dim=(1, 3)), (torch.randn(2, 3, 4, 5),), tester_factory)
81+
82+
# 4D tensor, dim=(0, 2)
83+
self._test_op(AmaxModel(dim=(0, 2)), (torch.randn(2, 3, 4, 5),), tester_factory)
84+
85+
# 4D tensor, dim=(-1, -3)
86+
self._test_op(AmaxModel(dim=(-1, -3)), (torch.randn(2, 3, 4, 5),), tester_factory)
87+
88+
# 4D tensor, all dimensions
89+
self._test_op(AmaxModel(dim=(0, 1, 2, 3)), (torch.randn(2, 3, 4, 5),), tester_factory)
90+
91+
def test_amax_keepdim(self, tester_factory: Callable) -> None:
92+
# Test with keepdim=True
93+
94+
# 2D tensor, dim=0, keepdim=True
95+
self._test_op(AmaxModel(dim=0, keepdim=True), (torch.randn(5, 10),), tester_factory)
96+
97+
# 2D tensor, dim=1, keepdim=True
98+
self._test_op(AmaxModel(dim=1, keepdim=True), (torch.randn(5, 10),), tester_factory)
99+
100+
# 3D tensor, dim=1, keepdim=True
101+
self._test_op(AmaxModel(dim=1, keepdim=True), (torch.randn(3, 4, 5),), tester_factory)
102+
103+
# 4D tensor, dim=2, keepdim=True
104+
self._test_op(AmaxModel(dim=2, keepdim=True), (torch.randn(2, 3, 4, 5),), tester_factory)
105+
106+
# Multiple dimensions with keepdim=True
107+
self._test_op(AmaxModel(dim=(1, 2), keepdim=True), (torch.randn(3, 4, 5),), tester_factory)
108+
109+
def test_amax_shapes(self, tester_factory: Callable) -> None:
110+
# Test with different tensor shapes
111+
112+
# 1D tensor
113+
self._test_op(AmaxModel(), (torch.randn(20),), tester_factory)
114+
self._test_op(AmaxModel(dim=0), (torch.randn(20),), tester_factory)
115+
116+
# 2D tensor
117+
self._test_op(AmaxModel(), (torch.randn(5, 10),), tester_factory)
118+
119+
# 3D tensor
120+
self._test_op(AmaxModel(), (torch.randn(3, 4, 5),), tester_factory)
121+
122+
# 4D tensor
123+
self._test_op(AmaxModel(), (torch.randn(2, 3, 4, 5),), tester_factory)
124+
125+
# 5D tensor
126+
self._test_op(AmaxModel(), (torch.randn(2, 2, 3, 4, 5),), tester_factory)
127+
128+
def test_amax_values(self, tester_factory: Callable) -> None:
129+
# Test with different value patterns
130+
131+
# Tensor with clear maximum
132+
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
133+
self._test_op(AmaxModel(), (x,), tester_factory)
134+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
135+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
136+
137+
# Tensor with duplicate maximum values
138+
x = torch.tensor([[3.0, 2.0, 3.0], [6.0, 6.0, 5.0]])
139+
self._test_op(AmaxModel(), (x,), tester_factory)
140+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
141+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
142+
143+
# Tensor with negative values
144+
x = torch.tensor([[-3.0, -2.0, -1.0], [-6.0, -5.0, -4.0]])
145+
self._test_op(AmaxModel(), (x,), tester_factory)
146+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
147+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
148+
149+
# Tensor with mixed positive and negative values
150+
x = torch.tensor([[-3.0, 2.0, -1.0], [6.0, -5.0, 4.0]])
151+
self._test_op(AmaxModel(), (x,), tester_factory)
152+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
153+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
154+
155+
def test_amax_edge_cases(self, tester_factory: Callable) -> None:
156+
# Test edge cases
157+
158+
# Tensor with all same values
159+
x = torch.ones(3, 4)
160+
self._test_op(AmaxModel(), (x,), tester_factory)
161+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
162+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
163+
164+
# Zero tensor
165+
x = torch.zeros(3, 4)
166+
self._test_op(AmaxModel(), (x,), tester_factory)
167+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
168+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
169+
170+
# Tensor with infinity
171+
x = torch.tensor([[1.0, float('inf'), 3.0], [4.0, 5.0, float('inf')]])
172+
self._test_op(AmaxModel(), (x,), tester_factory)
173+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
174+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
175+
176+
# Tensor with negative infinity
177+
x = torch.tensor([[1.0, float('-inf'), 3.0], [4.0, 5.0, float('-inf')]])
178+
self._test_op(AmaxModel(), (x,), tester_factory)
179+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
180+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
181+
182+
# Tensor with NaN (NaN should be propagated)
183+
x = torch.tensor([[1.0, float('nan'), 3.0], [4.0, 5.0, float('nan')]])
184+
self._test_op(AmaxModel(), (x,), tester_factory)
185+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
186+
self._test_op(AmaxModel(dim=1), (x,), tester_factory)
187+
188+
# Single element tensor
189+
x = torch.tensor([5.0])
190+
self._test_op(AmaxModel(), (x,), tester_factory)
191+
self._test_op(AmaxModel(dim=0), (x,), tester_factory)
192+
193+
def test_amax_scalar(self, tester_factory: Callable) -> None:
194+
# Test with scalar input (1-element tensor)
195+
self._test_op(AmaxModel(), (torch.tensor([5.0]),), tester_factory)
196+
self._test_op(AmaxModel(dim=0), (torch.tensor([5.0]),), tester_factory)
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List, Optional, Tuple, Union
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class AminModel(torch.nn.Module):
16+
def __init__(
17+
self,
18+
dim: Optional[Union[int, Tuple[int, ...], List[int]]] = None,
19+
keepdim: bool = False
20+
):
21+
super().__init__()
22+
self.dim = dim
23+
self.keepdim = keepdim
24+
25+
def forward(self, x):
26+
return torch.amin(x, dim=self.dim, keepdim=self.keepdim)
27+
28+
@operator_test
29+
class TestAmin(OperatorTest):
30+
@dtype_test
31+
def test_amin_dtype(self, dtype, tester_factory: Callable) -> None:
32+
# Test with different dtypes
33+
model = AminModel().to(dtype)
34+
self._test_op(model, (torch.rand(10, 10).to(dtype),), tester_factory)
35+
36+
def test_amin_basic(self, tester_factory: Callable) -> None:
37+
# Basic test with default parameters (global reduction)
38+
self._test_op(AminModel(), (torch.randn(10, 10),), tester_factory)
39+
40+
def test_amin_dim(self, tester_factory: Callable) -> None:
41+
# Test with different dimensions
42+
43+
# 2D tensor, dim=0
44+
self._test_op(AminModel(dim=0), (torch.randn(5, 10),), tester_factory)
45+
46+
# 2D tensor, dim=1
47+
self._test_op(AminModel(dim=1), (torch.randn(5, 10),), tester_factory)
48+
49+
# 3D tensor, dim=0
50+
self._test_op(AminModel(dim=0), (torch.randn(3, 4, 5),), tester_factory)
51+
52+
# 3D tensor, dim=1
53+
self._test_op(AminModel(dim=1), (torch.randn(3, 4, 5),), tester_factory)
54+
55+
# 3D tensor, dim=2
56+
self._test_op(AminModel(dim=2), (torch.randn(3, 4, 5),), tester_factory)
57+
58+
# 4D tensor, dim=1
59+
self._test_op(AminModel(dim=1), (torch.randn(2, 3, 4, 5),), tester_factory)
60+
61+
# Negative dim (last dimension)
62+
self._test_op(AminModel(dim=-1), (torch.randn(3, 4, 5),), tester_factory)
63+
64+
# Negative dim (second-to-last dimension)
65+
self._test_op(AminModel(dim=-2), (torch.randn(3, 4, 5),), tester_factory)
66+
67+
def test_amin_multi_dim(self, tester_factory: Callable) -> None:
68+
# Test with multiple dimensions
69+
70+
# 3D tensor, dim=(0, 1)
71+
self._test_op(AminModel(dim=(0, 1)), (torch.randn(3, 4, 5),), tester_factory)
72+
73+
# 3D tensor, dim=(0, 2)
74+
self._test_op(AminModel(dim=(0, 2)), (torch.randn(3, 4, 5),), tester_factory)
75+
76+
# 3D tensor, dim=(1, 2)
77+
self._test_op(AminModel(dim=(1, 2)), (torch.randn(3, 4, 5),), tester_factory)
78+
79+
# 4D tensor, dim=(1, 3)
80+
self._test_op(AminModel(dim=(1, 3)), (torch.randn(2, 3, 4, 5),), tester_factory)
81+
82+
# 4D tensor, dim=(0, 2)
83+
self._test_op(AminModel(dim=(0, 2)), (torch.randn(2, 3, 4, 5),), tester_factory)
84+
85+
# 4D tensor, dim=(-1, -3)
86+
self._test_op(AminModel(dim=(-1, -3)), (torch.randn(2, 3, 4, 5),), tester_factory)
87+
88+
# 4D tensor, all dimensions
89+
self._test_op(AminModel(dim=(0, 1, 2, 3)), (torch.randn(2, 3, 4, 5),), tester_factory)
90+
91+
def test_amin_keepdim(self, tester_factory: Callable) -> None:
92+
# Test with keepdim=True
93+
94+
# 2D tensor, dim=0, keepdim=True
95+
self._test_op(AminModel(dim=0, keepdim=True), (torch.randn(5, 10),), tester_factory)
96+
97+
# 2D tensor, dim=1, keepdim=True
98+
self._test_op(AminModel(dim=1, keepdim=True), (torch.randn(5, 10),), tester_factory)
99+
100+
# 3D tensor, dim=1, keepdim=True
101+
self._test_op(AminModel(dim=1, keepdim=True), (torch.randn(3, 4, 5),), tester_factory)
102+
103+
# 4D tensor, dim=2, keepdim=True
104+
self._test_op(AminModel(dim=2, keepdim=True), (torch.randn(2, 3, 4, 5),), tester_factory)
105+
106+
# Multiple dimensions with keepdim=True
107+
self._test_op(AminModel(dim=(1, 2), keepdim=True), (torch.randn(3, 4, 5),), tester_factory)
108+
109+
def test_amin_shapes(self, tester_factory: Callable) -> None:
110+
# Test with different tensor shapes
111+
112+
# 1D tensor
113+
self._test_op(AminModel(), (torch.randn(20),), tester_factory)
114+
self._test_op(AminModel(dim=0), (torch.randn(20),), tester_factory)
115+
116+
# 2D tensor
117+
self._test_op(AminModel(), (torch.randn(5, 10),), tester_factory)
118+
119+
# 3D tensor
120+
self._test_op(AminModel(), (torch.randn(3, 4, 5),), tester_factory)
121+
122+
# 4D tensor
123+
self._test_op(AminModel(), (torch.randn(2, 3, 4, 5),), tester_factory)
124+
125+
# 5D tensor
126+
self._test_op(AminModel(), (torch.randn(2, 2, 3, 4, 5),), tester_factory)
127+
128+
def test_amin_values(self, tester_factory: Callable) -> None:
129+
# Test with different value patterns
130+
131+
# Tensor with clear minimum
132+
x = torch.tensor([[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]])
133+
self._test_op(AminModel(), (x,), tester_factory)
134+
self._test_op(AminModel(dim=0), (x,), tester_factory)
135+
self._test_op(AminModel(dim=1), (x,), tester_factory)
136+
137+
# Tensor with duplicate minimum values
138+
x = torch.tensor([[3.0, 2.0, 2.0], [1.0, 1.0, 5.0]])
139+
self._test_op(AminModel(), (x,), tester_factory)
140+
self._test_op(AminModel(dim=0), (x,), tester_factory)
141+
self._test_op(AminModel(dim=1), (x,), tester_factory)
142+
143+
# Tensor with negative values
144+
x = torch.tensor([[-3.0, -2.0, -1.0], [-6.0, -5.0, -4.0]])
145+
self._test_op(AminModel(), (x,), tester_factory)
146+
self._test_op(AminModel(dim=0), (x,), tester_factory)
147+
self._test_op(AminModel(dim=1), (x,), tester_factory)
148+
149+
# Tensor with mixed positive and negative values
150+
x = torch.tensor([[-3.0, 2.0, -1.0], [6.0, -5.0, 4.0]])
151+
self._test_op(AminModel(), (x,), tester_factory)
152+
self._test_op(AminModel(dim=0), (x,), tester_factory)
153+
self._test_op(AminModel(dim=1), (x,), tester_factory)
154+
155+
def test_amin_edge_cases(self, tester_factory: Callable) -> None:
156+
# Test edge cases
157+
158+
# Tensor with all same values
159+
x = torch.ones(3, 4)
160+
self._test_op(AminModel(), (x,), tester_factory)
161+
self._test_op(AminModel(dim=0), (x,), tester_factory)
162+
self._test_op(AminModel(dim=1), (x,), tester_factory)
163+
164+
# Zero tensor
165+
x = torch.zeros(3, 4)
166+
self._test_op(AminModel(), (x,), tester_factory)
167+
self._test_op(AminModel(dim=0), (x,), tester_factory)
168+
self._test_op(AminModel(dim=1), (x,), tester_factory)
169+
170+
# Tensor with infinity
171+
x = torch.tensor([[1.0, float('inf'), 3.0], [4.0, 5.0, float('inf')]])
172+
self._test_op(AminModel(), (x,), tester_factory)
173+
self._test_op(AminModel(dim=0), (x,), tester_factory)
174+
self._test_op(AminModel(dim=1), (x,), tester_factory)
175+
176+
# Tensor with negative infinity
177+
x = torch.tensor([[1.0, float('-inf'), 3.0], [4.0, 5.0, float('-inf')]])
178+
self._test_op(AminModel(), (x,), tester_factory)
179+
self._test_op(AminModel(dim=0), (x,), tester_factory)
180+
self._test_op(AminModel(dim=1), (x,), tester_factory)
181+
182+
# Tensor with NaN (NaN should be propagated)
183+
x = torch.tensor([[1.0, float('nan'), 3.0], [4.0, 5.0, float('nan')]])
184+
self._test_op(AminModel(), (x,), tester_factory)
185+
self._test_op(AminModel(dim=0), (x,), tester_factory)
186+
self._test_op(AminModel(dim=1), (x,), tester_factory)
187+
188+
# Single element tensor
189+
x = torch.tensor([5.0])
190+
self._test_op(AminModel(), (x,), tester_factory)
191+
self._test_op(AminModel(dim=0), (x,), tester_factory)
192+
193+
def test_amin_scalar(self, tester_factory: Callable) -> None:
194+
# Test with scalar input (1-element tensor)
195+
self._test_op(AminModel(), (torch.tensor([5.0]),), tester_factory)
196+
self._test_op(AminModel(dim=0), (torch.tensor([5.0]),), tester_factory)

0 commit comments

Comments
 (0)