Skip to content

Commit 21b5d51

Browse files
committed
Update
[ghstack-poisoned]
1 parent 3c6167a commit 21b5d51

File tree

3 files changed

+145
-0
lines changed

3 files changed

+145
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, dim=-1):
17+
super().__init__()
18+
self.dim = dim
19+
20+
def forward(self, x):
21+
return torch.nn.functional.log_softmax(x, dim=self.dim)
22+
23+
@operator_test
24+
class TestLogSoftmax(OperatorTest):
25+
@dtype_test
26+
def test_log_softmax_dtype(self, dtype, tester_factory: Callable) -> None:
27+
self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory)
28+
29+
def test_log_softmax_f32_dim_last(self, tester_factory: Callable) -> None:
30+
# Default dim is -1 (last dimension)
31+
self._test_op(Model(), (torch.randn(3, 4, 5),), tester_factory)
32+
33+
def test_log_softmax_f32_dim_first(self, tester_factory: Callable) -> None:
34+
# Test with dim=0 (first dimension)
35+
self._test_op(Model(dim=0), (torch.randn(3, 4, 5),), tester_factory)
36+
37+
def test_log_softmax_f32_dim_middle(self, tester_factory: Callable) -> None:
38+
# Test with dim=1 (middle dimension)
39+
self._test_op(Model(dim=1), (torch.randn(3, 4, 5),), tester_factory)
40+
41+
def test_log_softmax_f32_1d_tensor(self, tester_factory: Callable) -> None:
42+
# Test with 1D tensor
43+
self._test_op(Model(), (torch.randn(10),), tester_factory)
44+
45+
def test_log_softmax_f32_large_values(self, tester_factory: Callable) -> None:
46+
# Test with large values to check numerical stability
47+
x = torch.tensor([[1000.0, 0.0, -1000.0]])
48+
self._test_op(Model(), (x,), tester_factory)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def __init__(self, dim=-1):
17+
super().__init__()
18+
self.dim = dim
19+
20+
def forward(self, x):
21+
return torch.nn.functional.softmax(x, dim=self.dim)
22+
23+
@operator_test
24+
class TestSoftmax(OperatorTest):
25+
@dtype_test
26+
def test_softmax_dtype(self, dtype, tester_factory: Callable) -> None:
27+
self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory)
28+
29+
def test_softmax_f32_dim_last(self, tester_factory: Callable) -> None:
30+
# Default dim is -1 (last dimension)
31+
self._test_op(Model(), (torch.randn(3, 4, 5),), tester_factory)
32+
33+
def test_softmax_f32_dim_first(self, tester_factory: Callable) -> None:
34+
# Test with dim=0 (first dimension)
35+
self._test_op(Model(dim=0), (torch.randn(3, 4, 5),), tester_factory)
36+
37+
def test_softmax_f32_dim_middle(self, tester_factory: Callable) -> None:
38+
# Test with dim=1 (middle dimension)
39+
self._test_op(Model(dim=1), (torch.randn(3, 4, 5),), tester_factory)
40+
41+
def test_softmax_f32_1d_tensor(self, tester_factory: Callable) -> None:
42+
# Test with 1D tensor
43+
self._test_op(Model(), (torch.randn(10),), tester_factory)
44+
45+
def test_softmax_f32_large_values(self, tester_factory: Callable) -> None:
46+
# Test with large values to check numerical stability
47+
x = torch.tensor([[1000.0, 0.0, -1000.0]])
48+
self._test_op(Model(), (x,), tester_factory)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class Model(torch.nn.Module):
16+
def forward(self, x):
17+
# softmax2d is equivalent to softmax with dim=1 for 4D inputs
18+
return torch.nn.functional.softmax(x, dim=1)
19+
20+
@operator_test
21+
class TestSoftmax2d(OperatorTest):
22+
@dtype_test
23+
def test_softmax2d_dtype(self, dtype, tester_factory: Callable) -> None:
24+
# Input must be 4D (N, C, H, W)
25+
self._test_op(Model(), ((torch.rand(2, 3, 4, 5) * 100).to(dtype),), tester_factory)
26+
27+
def test_softmax2d_f32_various_shapes(self, tester_factory: Callable) -> None:
28+
# Test with different shapes
29+
self._test_op(Model(), (torch.randn(1, 3, 8, 8),), tester_factory)
30+
31+
def test_softmax2d_f32_single_channel(self, tester_factory: Callable) -> None:
32+
# Test with single channel (C=1)
33+
self._test_op(Model(), (torch.randn(2, 1, 4, 4),), tester_factory)
34+
35+
def test_softmax2d_f32_many_channels(self, tester_factory: Callable) -> None:
36+
# Test with many channels
37+
self._test_op(Model(), (torch.randn(2, 16, 4, 4),), tester_factory)
38+
39+
def test_softmax2d_f32_single_batch(self, tester_factory: Callable) -> None:
40+
# Test with single batch (N=1)
41+
self._test_op(Model(), (torch.randn(1, 3, 4, 4),), tester_factory)
42+
43+
def test_softmax2d_f32_large_values(self, tester_factory: Callable) -> None:
44+
# Test with large values to check numerical stability
45+
x = torch.zeros(2, 3, 2, 2)
46+
x[:, 0] = 1000.0 # First channel has large positive values
47+
x[:, 1] = 0.0 # Second channel has zeros
48+
x[:, 2] = -1000.0 # Third channel has large negative values
49+
self._test_op(Model(), (x,), tester_factory)

0 commit comments

Comments
 (0)