Skip to content

Commit 42a5bb2

Browse files
haozha111copybara-github
authored andcommitted
test improvement: Add feed forward unit tests
PiperOrigin-RevId: 752950242
1 parent e5dcfa9 commit 42a5bb2

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2025 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
from ai_edge_torch.generative.layers import feed_forward
17+
import torch
18+
import torch.nn.functional as F
19+
from absl.testing import absltest as googletest
20+
21+
22+
class FeedForwardTest(googletest.TestCase):
23+
24+
def test_sequential_feed_forward(self):
25+
ff = feed_forward.SequentialFeedForward(
26+
dim=10,
27+
hidden_dim=10,
28+
activation=F.silu,
29+
use_bias=True,
30+
use_glu=False,
31+
pre_ff_norm=torch.nn.Identity(),
32+
post_ff_norm=torch.nn.Identity(),
33+
)
34+
x = torch.ones((1, 10))
35+
out = ff(x)
36+
self.assertEqual(out.shape, (1, 10))
37+
38+
def test_gated_feed_forward(self):
39+
ff = feed_forward.GatedFeedForward(
40+
dim=10,
41+
hidden_dim=10,
42+
activation=F.silu,
43+
use_bias=True,
44+
use_glu=False,
45+
pre_ff_norm=torch.nn.Identity(),
46+
post_ff_norm=torch.nn.Identity(),
47+
)
48+
x = torch.ones((1, 10))
49+
out = ff(x)
50+
self.assertEqual(out.shape, (1, 10))
51+
52+
53+
if __name__ == "__main__":
54+
googletest.main()

0 commit comments

Comments
 (0)