@@ -31,6 +31,14 @@ class OpLinearOutTest : public OperatorTest {
3131 return torch::executor::aten::linear_outf (context_, self, mat2, {}, out);
3232 }
3333
34+ Tensor& op_linear_out (
35+ const Tensor& self,
36+ const Tensor& mat2,
37+ const Tensor& bias,
38+ Tensor& out) {
39+ return torch::executor::aten::linear_outf (context_, self, mat2, bias, out);
40+ }
41+
3442 template <class CTYPE , executorch::aten::ScalarType DTYPE>
3543 void test_dtype () {
3644 TensorFactory<DTYPE> tf;
@@ -88,6 +96,70 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
8896 // for those types.
8997}
9098
99+ TEST_F (OpLinearOutTest, BiasTest) {
100+ TensorFactory<ScalarType::Int> tf;
101+
102+ // Initialize input tensors.
103+ constexpr int kReduceDim = 4 ;
104+ constexpr int kDimX = 3 , kDimY = 5 ;
105+ constexpr int kValueX = 1 ;
106+ constexpr int kValueY = 2 ;
107+ constexpr int kValueBias = 4 ;
108+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
109+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
110+ Tensor b = tf.full ({kDimY }, kValueBias );
111+ // Output matrix is also empty
112+ Tensor out = tf.zeros ({kDimX , kDimY });
113+ // Initialize expected tensor.
114+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
115+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
116+
117+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
118+ }
119+
120+ TEST_F (OpLinearOutTest, BiasBroadcastTest) {
121+ TensorFactory<ScalarType::Int> tf;
122+
123+ // Initialize input tensors.
124+ constexpr int kReduceDim = 4 ;
125+ constexpr int kDimX = 3 , kDimY = 5 ;
126+ constexpr int kValueX = 1 ;
127+ constexpr int kValueY = 2 ;
128+ constexpr int kValueBias = 4 ;
129+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
130+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
131+ Tensor b = tf.full ({1 }, kValueBias );
132+ // Output matrix is also empty
133+ Tensor out = tf.zeros ({kDimX , kDimY });
134+ // Initialize expected tensor.
135+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
136+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
137+
138+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
139+ }
140+
141+ TEST_F (OpLinearOutTest, Bias2DTest) {
142+ TensorFactory<ScalarType::Int> tf;
143+
144+ // Initialize input tensors.
145+ constexpr int kReduceDim = 4 ;
146+ constexpr int kDimX = 3 , kDimY = 5 ;
147+ constexpr int kValueX = 1 ;
148+ constexpr int kValueY = 2 ;
149+ constexpr int kValueBias = 4 ;
150+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
151+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
152+ // Same size as output.
153+ Tensor b = tf.full ({kDimX , kDimY }, kValueBias );
154+ // Output matrix is also empty
155+ Tensor out = tf.zeros ({kDimX , kDimY });
156+ // Initialize expected tensor.
157+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
158+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
159+
160+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
161+ }
162+
91163TEST_F (OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
92164 TensorFactory<ScalarType::Float> tf;
93165
0 commit comments