Commit f5a1e6d

hsharma35 authored and facebook-github-bot committed
Add support for bias in optimized op_linear.cpp. (#11210)
Summary:
Pull Request resolved: #11210

Diff uses `op_add_sub_impl` to add bias after optimized gemm call.

Reviewed By: zonglinpeng

Differential Revision: D75491158
1 parent f8a3fd8 commit f5a1e6d
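
For context, `linear` computes out = in x mat2^T + bias, where `in` is n x k, `mat2` is m x k, and `bias` (when present) holds m elements or a single broadcast scalar. The sketch below is a minimal plain-C++ restatement of those semantics, not the ExecuTorch API; `linear_reference` and its row-major layout are illustrative assumptions only.

    #include <cstddef>
    #include <vector>

    // Reference semantics only (hypothetical helper, not part of the kernel):
    // out[n x m] = in[n x k] * mat2[m x k]^T + bias. Row-major throughout;
    // the real kernel delegates the multiply to executorch::cpublas::gemm.
    std::vector<float> linear_reference(
        const std::vector<float>& in,   // n x k
        const std::vector<float>& mat2, // m x k
        const std::vector<float>& bias, // m elements, 1 (broadcast), or empty
        std::size_t n,
        std::size_t k,
        std::size_t m) {
      std::vector<float> out(n * m);
      for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < m; ++j) {
          // Start from the bias (broadcast if it is a single element).
          float acc =
              bias.empty() ? 0.0f : (bias.size() == 1 ? bias[0] : bias[j]);
          for (std::size_t p = 0; p < k; ++p) {
            acc += in[i * k + p] * mat2[j * k + p];
          }
          out[i * m + j] = acc;
        }
      }
      return out;
    }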

File tree: 2 files changed, +155 −20 lines changed

kernels/optimized/cpu/op_linear.cpp
Lines changed: 83 additions & 20 deletions
@@ -6,30 +6,58 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <array>
+
+#include <c10/util/irange.h>
+
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
-#include <array>
-
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = executorch::aten::Tensor;
+namespace {
+using ::executorch::aten::Tensor;
+using ::executorch::cpublas::gemm;
+using ::executorch::cpublas::TransposeType;
+
+template <typename scalar_t>
+void initialize_to_bias_scalar(
+    const int out_numel,
+    const scalar_t* bias,
+    scalar_t* out) {
+  for (const auto i : c10::irange(out_numel)) {
+    out[i] = *bias;
+  }
+}
+
+template <typename scalar_t>
+void initialize_to_bias_vector(
+    const int n,
+    const int m,
+    const scalar_t* bias,
+    scalar_t* out) {
+  // Output holds n x m scalar_t values, while bias holds m.
+  const size_t row_size = static_cast<size_t>(m) * sizeof(scalar_t);
+  for (const auto col : c10::irange(n)) {
+    std::memcpy(
+        // Point to column `col` of the output (gemm treats out as
+        // column-major; this is row `col` in the row-major view).
+        out + col * m,
+        bias,
+        row_size);
+  }
+}
+
+} // namespace
 
 Tensor& opt_linear_out(
     RuntimeContext& ctx,
     const Tensor& in,
     const Tensor& mat2,
     const optional<Tensor>& bias,
     Tensor& out) {
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !bias.has_value(),
-      InvalidArgument,
-      out,
-      "bias not supported yet in linear");
   ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
 
   size_t output_ndim = 0;
@@ -46,28 +74,63 @@ Tensor& opt_linear_out(
     return out;
   }
 
-  int flattened_input_dim = 1;
+  ssize_t flattened_input_dim = 1;
   for (int ii = 0; ii < in.dim() - 1; ++ii) {
     flattened_input_dim *= in.sizes()[ii];
   }
+  const ssize_t n = flattened_input_dim;
+  const ssize_t k = in.sizes()[in.dim() - 1];
+  const ssize_t m = mat2.size(0);
+
+  if (bias.has_value()) {
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        // Either no bias or bias is a 1D tensor of size m or 1.
+        bias->dtype() == out.dtype(),
+        InvalidArgument,
+        out,
+        "Bias has wrong dimensionality! Expected 1-D tensor of size %ld or empty,"
+        " but got %zd-D tensor with %ld elements",
+        m,
+        bias->dim(),
+        bias->numel());
+  }
+
   ET_SWITCH_REAL_TYPES_AND2(
-      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
-        size_t n = flattened_input_dim;
-        size_t k = in.sizes()[in.dim() - 1];
-        size_t m = mat2.size(0);
-
-        executorch::cpublas::gemm(
-            executorch::cpublas::TransposeType::Transpose,
-            executorch::cpublas::TransposeType::NoTranspose,
+      Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] {
+        if (bias.has_value() && bias->numel() == 1) {
+          // Scalar bias: fill the entire output with the single bias value.
+          initialize_to_bias_scalar<CTYPE>(
+              out.numel(),
+              bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        } else if (bias.has_value()) {
+          // Assume bias is a 1D tensor of size m; copy it into each row.
+          initialize_to_bias_vector<CTYPE>(
+              n,
+              m,
+              bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        }
+
+        // Set beta to 1 if bias was applied so that GEMM accumulates into
+        // the pre-filled bias; otherwise beta stays 0 and GEMM fully
+        // overwrites the output.
+        const CTYPE beta =
+            bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
+
+        gemm(
+            /*transa=*/TransposeType::Transpose,
+            /*transb=*/TransposeType::NoTranspose,
             m,
             n,
             k,
-            static_cast<CTYPE>(1),
+            /*alpha=*/static_cast<CTYPE>(1),
             mat2.const_data_ptr<CTYPE>(),
             k,
             in.const_data_ptr<CTYPE>(),
             k,
-            static_cast<CTYPE>(0),
+            beta,
             out.mutable_data_ptr<CTYPE>(),
             m);
       });
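
The change leans on the standard BLAS GEMM contract, C <- alpha * op(A) * op(B) + beta * C: the output is first filled with the bias, and passing beta = 1 makes the multiply accumulate on top of that pre-fill, while beta = 0 (the no-bias path) overwrites the output entirely. Below is a minimal sketch of that idiom under stated assumptions: `naive_gemm` is a hypothetical row-major stand-in for `executorch::cpublas::gemm`; the transpose flags and column-major leading dimensions of the real call are elided.

    #include <cstddef>

    // Sketch of the beta idiom (not the cpublas API):
    // C = alpha * A * B + beta * C, with A m x k, B k x n, C m x n, row-major.
    template <typename T>
    void naive_gemm(
        std::size_t m,
        std::size_t n,
        std::size_t k,
        T alpha,
        const T* a,
        const T* b,
        T beta,
        T* c) {
      for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
          T acc = 0;
          for (std::size_t p = 0; p < k; ++p) {
            acc += a[i * k + p] * b[p * n + j];
          }
          // beta == 0 discards the old contents of c; beta == 1 adds the
          // matrix product to whatever was pre-filled (here, the bias).
          c[i * n + j] = alpha * acc + beta * c[i * n + j];
        }
      }
    }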

kernels/test/op_linear_test.cpp
Lines changed: 72 additions & 0 deletions
@@ -31,6 +31,14 @@ class OpLinearOutTest : public OperatorTest {
     return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
   }
 
+  Tensor& op_linear_out(
+      const Tensor& self,
+      const Tensor& mat2,
+      const Tensor& bias,
+      Tensor& out) {
+    return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
+  }
+
   template <class CTYPE, executorch::aten::ScalarType DTYPE>
   void test_dtype() {
     TensorFactory<DTYPE> tf;
@@ -88,6 +96,70 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
   // for those types.
 }
 
+TEST_F(OpLinearOutTest, BiasTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  Tensor b = tf.full({kDimY}, kValueBias);
+  // Output starts zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
+TEST_F(OpLinearOutTest, BiasBroadcastTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  // Single-element bias, broadcast across the whole output.
+  Tensor b = tf.full({1}, kValueBias);
+  // Output starts zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
+TEST_F(OpLinearOutTest, Bias2DTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  // Bias with the same shape as the output.
+  Tensor b = tf.full({kDimX, kDimY}, kValueBias);
+  // Output starts zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
 TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
   TensorFactory<ScalarType::Float> tf;
 
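
As a sanity check on the expected value shared by all three new tests: each output element is the dot product of a length-kReduceDim row of x (all 1s) with a row of y (all 2s), plus the bias, i.e. kValueX * kValueY * kReduceDim + kValueBias = 1 * 2 * 4 + 4 = 12.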
