
Commit e69df21

hsharma35 authored and facebook-github-bot committed
Add support for bias in optimized op_linear.cpp. (#11210)
Summary:
Pull Request resolved: #11210

Diff uses `op_add_sub_impl` to add bias after the optimized gemm call.

Reviewed By: zonglinpeng

Differential Revision: D75491158
1 parent b308544 · commit e69df21

2 files changed: +79, -17 lines

kernels/optimized/cpu/op_linear.cpp

Lines changed: 29 additions & 17 deletions
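
In linear terms, the kernel computes out = in · mat2ᵀ + bias, with bias broadcast across the rows of the output. The hunks below drop the old "bias not supported" check and wire bias in by pre-filling the output with it, then letting GEMM accumulate on top (beta = 1).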
@@ -16,20 +16,16 @@ namespace torch {
 namespace executor {
 namespace native {

-using Tensor = executorch::aten::Tensor;
+using ::executorch::aten::Tensor;
+using ::executorch::cpublas::gemm;
+using ::executorch::cpublas::TransposeType;

 Tensor& opt_linear_out(
     RuntimeContext& ctx,
     const Tensor& in,
     const Tensor& mat2,
     const optional<Tensor>& bias,
     Tensor& out) {
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !bias.has_value(),
-      InvalidArgument,
-      out,
-      "bias not supported yet in linear");
   ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);

   size_t output_ndim = 0;
@@ -46,28 +42,44 @@ Tensor& opt_linear_out(
     return out;
   }

-  int flattened_input_dim = 1;
+  ssize_t flattened_input_dim = 1;
   for (int ii = 0; ii < in.dim() - 1; ++ii) {
     flattened_input_dim *= in.sizes()[ii];
   }
+
   ET_SWITCH_REAL_TYPES_AND2(
-      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
-        size_t n = flattened_input_dim;
-        size_t k = in.sizes()[in.dim() - 1];
-        size_t m = mat2.size(0);
+      Half, BFloat16, in.scalar_type(), ctx, "linear.out", CTYPE, [&] {
+        const ssize_t n = flattened_input_dim;
+        const ssize_t k = in.sizes()[in.dim() - 1];
+        const ssize_t m = mat2.size(0);
+
+        // Output is an n x m matrix of CTYPE, while bias is m x CTYPE.
+        const size_t row_size = static_cast<size_t>(m) * sizeof(CTYPE);
+        for (const auto col : c10::irange(n)) {
+          std::memcpy(
+              // Point to row `col` of the output tensor.
+              out.mutable_data_ptr<CTYPE>() + col * m,
+              bias->const_data_ptr<CTYPE>(),
+              row_size);
+        }
+        // Set beta to 1 if bias was applied so that GEMM adds to the pre-filled
+        // bias, otherwise beta remains 0 (i.e. the output is fully overwritten
+        // by GEMM).
+        const CTYPE beta =
+            bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);

-        executorch::cpublas::gemm(
-            executorch::cpublas::TransposeType::Transpose,
-            executorch::cpublas::TransposeType::NoTranspose,
+        gemm(
+            /*transa=*/TransposeType::Transpose,
+            /*transb=*/TransposeType::NoTranspose,
             m,
             n,
             k,
-            static_cast<CTYPE>(1),
+            /*alpha=*/static_cast<CTYPE>(1),
             mat2.const_data_ptr<CTYPE>(),
             k,
             in.const_data_ptr<CTYPE>(),
             k,
-            static_cast<CTYPE>(0),
+            beta,
             out.mutable_data_ptr<CTYPE>(),
             m);
       });
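
For readers skimming the hunk above: BLAS-style gemm computes C = alpha * op(A) * op(B) + beta * C, so pre-filling the output with the bias and passing beta = 1 folds the bias add into the matmul itself. Below is a minimal standalone sketch of that trick. It is plain C++ with a naive gemm_ref, a hypothetical stand-in for executorch::cpublas::gemm rather than its real API, and it reuses the all-ones/all-twos shapes from the BiasTest further down.

#include <cstdio>
#include <vector>

// Reference GEMM: C (n x m) = alpha * X (n x k) * W^T + beta * C,
// with W stored row-major as an m x k matrix.
void gemm_ref(
    int n, int m, int k, float alpha,
    const float* x, const float* w, float beta, float* c) {
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      float acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += x[i * k + p] * w[j * k + p];
      }
      c[i * m + j] = alpha * acc + beta * c[i * m + j];
    }
  }
}

int main() {
  const int n = 3, m = 5, k = 4;
  std::vector<float> x(n * k, 1.0f);  // input, all ones
  std::vector<float> w(m * k, 2.0f);  // weight, all twos
  std::vector<float> bias(m, 4.0f);   // bias, all fours

  // Pre-fill each output row with the bias, as the kernel's memcpy loop does.
  std::vector<float> out(n * m);
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      out[i * m + j] = bias[j];
    }
  }

  // beta = 1: GEMM accumulates into the pre-filled bias instead of
  // overwriting it.
  gemm_ref(n, m, k, /*alpha=*/1.0f, x.data(), w.data(), /*beta=*/1.0f,
           out.data());

  std::printf("out[0] = %g (expected 1*2*4 + 4 = 12)\n", out[0]);
  return 0;
}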

kernels/test/op_linear_test.cpp

Lines changed: 50 additions & 0 deletions
@@ -31,6 +31,14 @@ class OpLinearOutTest : public OperatorTest {
     return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
   }

+  Tensor& op_linear_out(
+      const Tensor& self,
+      const Tensor& mat2,
+      const Tensor& bias,
+      Tensor& out) {
+    return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
+  }
+
   template <class CTYPE, executorch::aten::ScalarType DTYPE>
   void test_dtype() {
     TensorFactory<DTYPE> tf;
@@ -88,6 +96,48 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
   // for those types.
 }

+TEST_F(OpLinearOutTest, BiasTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  Tensor b = tf.full({kDimY}, kValueBias);
+  // Output starts zeroed.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
+TEST_F(OpLinearOutTest, BiasBroadcastTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  Tensor b = tf.full({1}, kValueBias);
+  // Output starts zeroed.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
 TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
   TensorFactory<ScalarType::Float> tf;

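Both bias tests assert the same worked arithmetic: each output element is a dot product over kReduceDim entries plus the bias, i.e. kValueExpected = kValueX * kValueY * kReduceDim + kValueBias = 1 * 2 * 4 + 4 = 12. The broadcast variant differs only in passing a bias of shape {1}, which must be broadcast across all kDimY output columns.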