1818#include  < gtest/gtest.h> 
1919#include  < limits> 
2020
21- using  namespace  ::testing; 
21+ namespace  {
22+ 
2223using  executorch::aten::ArrayRef;
2324using  executorch::aten::Scalar;
2425using  executorch::aten::ScalarType;
@@ -31,7 +32,15 @@ class OpLinearOutTest : public OperatorTest {
3132    return  torch::executor::aten::linear_outf (context_, self, mat2, {}, out);
3233  }
3334
34-   template  <class  CTYPE , executorch::aten::ScalarType DTYPE>
35+   Tensor& op_linear_out (
36+       const  Tensor& self,
37+       const  Tensor& mat2,
38+       const  Tensor& bias,
39+       Tensor& out) {
40+     return  torch::executor::aten::linear_outf (context_, self, mat2, bias, out);
41+   }
42+ 
43+   template  <class  CTYPE , ScalarType DTYPE>
3544  void  test_dtype () {
3645    TensorFactory<DTYPE> tf;
3746
@@ -43,16 +52,16 @@ class OpLinearOutTest : public OperatorTest {
4352      }
4453    }
4554
46-     //  matmul gives 32  * 2 * 3 = 192 
47-     Tensor x = tf.full ({3 , 32 }, 2 );
48-     Tensor y = tf.full ({5 , 32 }, 3 );
55+     //  matmul gives 19  * 2 * 3 = 114 
56+     Tensor x = tf.full ({3 , 19 }, 2 );
57+     Tensor y = tf.full ({5 , 19 }, 3 );
4958
5059    //  Output shape should be (3, 5)
5160    Tensor out = tf.zeros ({3 , 5 });
5261
5362    op_linear_out (x, y, out);
5463
55-     Tensor expected = tf.full ({3 , 5 }, 192 );
64+     Tensor expected = tf.full ({3 , 5 }, 114 );
5665
5766    EXPECT_TENSOR_EQ (out, expected);
5867  }
@@ -88,6 +97,80 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
8897  //  for those types.
8998}
9099
100+ TEST_F (OpLinearOutTest, BiasTest) {
101+   TensorFactory<ScalarType::Int> tf;
102+ 
103+   //  Initialize input tensors.
104+   constexpr  int  kReduceDim  = 4 ;
105+   constexpr  int  kDimX  = 3 , kDimY  = 2 ;
106+   constexpr  int  kValueX  = 1 ;
107+   constexpr  int  kValueY  = 2 ;
108+   constexpr  int  kValueBias0  = 4 , kValueBias1  = 7 ;
109+   const  Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
110+   const  Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
111+   const  Tensor b = tf.make ({kDimY }, {kValueBias0 , kValueBias1 });
112+   //  Output matrix is also empty
113+   Tensor out = tf.zeros ({kDimX , kDimY });
114+   //  Initialize expected tensor.
115+   constexpr  int  kValueExpected0  = kValueX  * kValueY  * kReduceDim  + kValueBias0 ;
116+   constexpr  int  kValueExpected1  = kValueX  * kValueY  * kReduceDim  + kValueBias1 ;
117+   //  Check that the bias is added to the correct position in the output matrix.
118+   const  Tensor expected = tf.make (
119+       {kDimX , kDimY },
120+       {kValueExpected0 ,
121+        kValueExpected1 ,
122+        kValueExpected0 ,
123+        kValueExpected1 ,
124+        kValueExpected0 ,
125+        kValueExpected1 });
126+ 
127+   EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
128+ }
129+ 
130+ TEST_F (OpLinearOutTest, BiasBroadcastTest) {
131+   TensorFactory<ScalarType::Int> tf;
132+ 
133+   //  Initialize input tensors.
134+   constexpr  int  kReduceDim  = 4 ;
135+   constexpr  int  kDimX  = 3 , kDimY  = 5 ;
136+   constexpr  int  kValueX  = 1 ;
137+   constexpr  int  kValueY  = 2 ;
138+   constexpr  int  kValueBias  = 4 ;
139+   const  Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
140+   const  Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
141+   const  Tensor b = tf.full ({1 }, kValueBias );
142+   //  Output matrix is also empty
143+   Tensor out = tf.zeros ({kDimX , kDimY });
144+   //  Initialize expected tensor.
145+   constexpr  int  kValueExpected  = kValueX  * kValueY  * kReduceDim  + kValueBias ;
146+   const  Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
147+ 
148+   EXPECT_TENSOR_EQ (op_linear_out (x, y, b, out), expected);
149+ }
150+ 
151+ TEST_F (OpLinearOutTest, BiasDtypeMismatch) {
152+   TensorFactory<ScalarType::Int> tf;
153+   TensorFactory<ScalarType::Short> tf_bias;
154+ 
155+   //  Initialize input tensors.
156+   constexpr  int  kReduceDim  = 4 ;
157+   constexpr  int  kDimX  = 3 , kDimY  = 5 ;
158+   constexpr  int  kValueX  = 1 ;
159+   constexpr  int  kValueY  = 2 ;
160+   constexpr  int  kValueBias  = 4 ;
161+   Tensor x = tf.full ({kDimX , kReduceDim }, kValueX );
162+   Tensor y = tf.full ({kDimY , kReduceDim }, kValueY );
163+   //  Same size as output.
164+   Tensor b = tf_bias.full ({kDimY }, kValueBias );
165+   //  Output matrix is also empty
166+   Tensor out = tf.zeros ({kDimX , kDimY });
167+   //  Initialize expected tensor.
168+   constexpr  int  kValueExpected  = kValueX  * kValueY  * kReduceDim  + kValueBias ;
169+   Tensor expected = tf.full ({kDimX , kDimY }, kValueExpected );
170+ 
171+   ET_EXPECT_KERNEL_FAILURE (context_, op_linear_out (x, y, b, out));
172+ }
173+ 
91174TEST_F (OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
92175  TensorFactory<ScalarType::Float> tf;
93176
@@ -297,5 +380,4 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
297380  Tensor ret = op_linear_out (x, y, out);
298381  EXPECT_TENSOR_CLOSE (out, expected_result);
299382}
300- 
301- //  TODO: support and test bias
383+ } //  namespace
0 commit comments