@@ -31,6 +31,14 @@ class OpLinearOutTest : public OperatorTest {
3131 return torch::executor::aten::linear_outf (context_, self, mat2, {}, out);
3232 }
3333
34+ Tensor& op_linear_out (
35+ const Tensor& self,
36+ const Tensor& mat2,
37+ const Tensor& bias,
38+ Tensor& out) {
39+ return torch::executor::aten::linear_outf (context_, self, mat2, bias, out);
40+ }
41+
3442 template <class CTYPE , executorch::aten::ScalarType DTYPE>
3543 void test_dtype () {
3644 TensorFactory<DTYPE> tf;
@@ -88,6 +96,48 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
8896 // for those types.
8997}
9098
99+ TEST_F (OpLinearOutTest, BiasTest) {
100+ TensorFactory<ScalarType::Int> tf;
101+
102+ // Initialize input tensors.
103+ constexpr int kReduceDim = 4 ;
104+ constexpr int kDimX = 3 , kDimY = 5 ;
105+ constexpr int kValueX = 1 ;
106+ constexpr int kValueY = 2 ;
107+ constexpr int kValueBias = 4 ;
108+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
109+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
110+ Tensor b = tf.full ({kDimY }, kValueBias );
111+ // Output matrix is also empty
112+ Tensor out = tf.zeros ({kDimX , kDimY });
113+ // Initialize expected tensor.
114+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
115+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
116+
117+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
118+ }
119+
120+ TEST_F (OpLinearOutTest, BiasBroadcastTest) {
121+ TensorFactory<ScalarType::Int> tf;
122+
123+ // Initialize input tensors.
124+ constexpr int kReduceDim = 4 ;
125+ constexpr int kDimX = 3 , kDimY = 5 ;
126+ constexpr int kValueX = 1 ;
127+ constexpr int kValueY = 2 ;
128+ constexpr int kValueBias = 4 ;
129+ Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
130+ Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
131+ Tensor b = tf.full ({1 }, kValueBias );
132+ // Output matrix is also empty
133+ Tensor out = tf.zeros ({kDimX , kDimY });
134+ // Initialize expected tensor.
135+ constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias ;
136+ Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
137+
138+ EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
139+ }
140+
91141TEST_F (OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
92142 TensorFactory<ScalarType::Float> tf;
93143
0 commit comments