1818#include < gtest/gtest.h>
1919#include < limits>
2020
21- using namespace ::testing;
21+ namespace {
22+
2223using executorch::aten::ArrayRef;
2324using executorch::aten::Scalar;
2425using executorch::aten::ScalarType;
@@ -31,7 +32,15 @@ class OpLinearOutTest : public OperatorTest {
3132 return torch::executor::aten::linear_outf (context_, self, mat2, {}, out);
3233 }
3334
34- template <class CTYPE , executorch::aten::ScalarType DTYPE>
35+ Tensor& op_linear_out (
36+ const Tensor& self,
37+ const Tensor& mat2,
38+ const Tensor& bias,
39+ Tensor& out) {
40+ return torch::executor::aten::linear_outf (context_, self, mat2, bias, out);
41+ }
42+
43+ template <class CTYPE , ScalarType DTYPE>
3544 void test_dtype () {
3645 TensorFactory<DTYPE> tf;
3746
@@ -43,16 +52,16 @@ class OpLinearOutTest : public OperatorTest {
4352 }
4453 }
4554
46- // matmul gives 32 * 2 * 3 = 192
47- Tensor x = tf.full ({3 , 32 }, 2 );
48- Tensor y = tf.full ({5 , 32 }, 3 );
55+ // matmul gives 19 * 2 * 3 = 114
56+ Tensor x = tf.full ({3 , 19 }, 2 );
57+ Tensor y = tf.full ({5 , 19 }, 3 );
4958
5059 // Output shape should be (3, 5)
5160 Tensor out = tf.zeros ({3 , 5 });
5261
5362 op_linear_out (x, y, out);
5463
55- Tensor expected = tf.full ({3 , 5 }, 192 );
64+ Tensor expected = tf.full ({3 , 5 }, 114 );
5665
5766 EXPECT_TENSOR_EQ (out, expected);
5867 }
@@ -88,6 +97,80 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
8897 // for those types.
8998}
9099
100+ TEST_F (OpLinearOutTest, BiasTest) {
101+ TensorFactory<ScalarType::Int> tf;
102+
103+ // Initialize input tensors.
104+ constexpr int kReduceDim = 4 ;
105+ constexpr int kDimX = 3 , kDimY = 2 ;
106+ constexpr int kValueX = 1 ;
107+ constexpr int kValueY = 2 ;
108+ constexpr int kValueBias0 = 4 , kValueBias1 = 7 ;
109+ const Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
110+ const Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
111+ const Tensor b = tf.make ({kDimY }, {kValueBias0 , kValueBias1 });
112+ // Output matrix is also empty
113+ Tensor out = tf.zeros ({kDimX , kDimY });
114+ // Initialize expected tensor.
115+ constexpr int kValueExpected0 = kValueX * kValueY * kReduceDim + kValueBias0 ;
116+ constexpr int kValueExpected1 = kValueX * kValueY * kReduceDim + kValueBias1 ;
117+ // Check that the bias is added to the correct position in the output matrix.
118+ const Tensor expected = tf.make (
119+ {kDimX , kDimY },
120+ {kValueExpected0 ,
121+ kValueExpected1 ,
122+ kValueExpected0 ,
123+ kValueExpected1 ,
124+ kValueExpected0 ,
125+ kValueExpected1 });
126+
127+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
128+ }
129+
130+ TEST_F (OpLinearOutTest, BiasBroadcastTest) {
131+ TensorFactory<ScalarType::Int> tf;
132+
133+ // Initialize input tensors.
134+ constexpr int kReduceDim = 4 ;
135+ constexpr int kDimX = 3 , kDimY = 5 ;
136+ constexpr int kValueX = 1 ;
137+ constexpr int kValueY = 2 ;
138+ constexpr int kValueBias = 4 ;
139+ const Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
140+ const Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
141+ const Tensor b = tf.full ({1 }, kValueBias );
142+ // Output matrix is also empty
143+ Tensor out = tf.zeros ({kDimX , kDimY });
144+ // Initialize expected tensor.
145+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
146+ const Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
147+
148+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
149+ }
150+
151+ TEST_F (OpLinearOutTest, BiasDtypeMismatch) {
152+ TensorFactory<ScalarType::Int> tf;
153+ TensorFactory<ScalarType::Short> tf_bias;
154+
155+ // Initialize input tensors.
156+ constexpr int kReduceDim = 4 ;
157+ constexpr int kDimX = 3 , kDimY = 5 ;
158+ constexpr int kValueX = 1 ;
159+ constexpr int kValueY = 2 ;
160+ constexpr int kValueBias = 4 ;
161+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
162+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
163+ // Same size as output.
164+ Tensor b = tf_bias.full ({kDimY }, kValueBias );
165+ // Output matrix is also empty
166+ Tensor out = tf.zeros ({kDimX , kDimY });
167+ // Initialize expected tensor.
168+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
169+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
170+
171+ ET_EXPECT_KERNEL_FAILURE (context_, op_linear_out (x, y, b, out));
172+ }
173+
91174TEST_F (OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
92175 TensorFactory<ScalarType::Float> tf;
93176
@@ -297,5 +380,4 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
297380 Tensor ret = op_linear_out (x, y, out);
298381 EXPECT_TENSOR_CLOSE (out, expected_result);
299382}
300-
301- // TODO: support and test bias
383+ } // namespace
0 commit comments