Commit b7f3e02

hsharma35 authored and facebook-github-bot committed
Add support for bias in optimized op_linear.cpp. (#11210)
Summary: This diff initializes the output tensor with the bias values and then calls gemm with beta = 1 when bias is non-nullopt, so the matrix product is accumulated on top of the bias.

Reviewed By: larryliu0820, zonglinpeng

Differential Revision: D75491158
1 parent 851f5fc commit b7f3e02
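
For context, the kernel relies on the standard BLAS contract C := alpha * op(A) * op(B) + beta * C. Below is a minimal scalar sketch of that contract showing why pre-filling C with the bias and passing beta = 1 amounts to a fused bias add; gemm_model is a hypothetical illustration (row-major, no transposes), not the actual CPUBlas signature.

#include <cstddef>

// Hypothetical scalar model of C = alpha * A * B + beta * C.
// With C pre-filled with the bias and beta = 1, the product is
// accumulated on top of the bias; with beta = 0, C is overwritten.
template <typename T>
void gemm_model(
    std::size_t m, std::size_t n, std::size_t k,
    T alpha, const T* a, const T* b, T beta, T* c) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      T acc = 0;
      for (std::size_t p = 0; p < k; ++p) {
        acc += a[i * k + p] * b[p * n + j];
      }
      // BLAS treats beta == 0 as "ignore C entirely", so uninitialized
      // output memory is never read in that case.
      c[i * n + j] =
          (beta == T(0)) ? alpha * acc : alpha * acc + beta * c[i * n + j];
    }
  }
}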

File tree

2 files changed: +173 -21 lines changed


kernels/optimized/cpu/op_linear.cpp

Lines changed: 101 additions & 21 deletions
@@ -6,30 +6,75 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <array>
+#include <cstring>
+
+#include <c10/util/irange.h>
+
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/optimized/vec/functional_base.h>
+#include <executorch/kernels/optimized/vec/vec_base.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
-#include <array>
-
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = executorch::aten::Tensor;
+namespace {
+using ::executorch::aten::Tensor;
+using ::executorch::cpublas::gemm;
+using ::executorch::cpublas::TransposeType;
+using ::executorch::vec::map;
+using ::executorch::vec::Vectorized;
+
+// Use vector stores to initialize the output with a scalar bias.
+template <typename scalar_t>
+void initialize_scalar(
+    const ssize_t out_numel,
+    const scalar_t init,
+    scalar_t* out) {
+  using Vec = Vectorized<scalar_t>;
+
+  // Initialize a vector with the scalar initial value.
+  Vec init_vec(init);
+
+  ssize_t d = 0;
+  for (; d < out_numel - (out_numel % Vec::size()); d += Vec::size()) {
+    // Vector-length store.
+    init_vec.store(out + d);
+  }
+  if (out_numel - d > 0) {
+    // Sub-vector-length store for the remaining tail.
+    init_vec.store(out + d, out_numel - d);
+  }
+}
+
+// Use std::memcpy to initialize the output with a 1-D bias vector.
+template <typename scalar_t>
+void initialize_to_vector(
+    const ssize_t n,
+    const ssize_t m,
+    const scalar_t* bias,
+    scalar_t* out) {
+  // Output is an n x m buffer of scalar_t, while bias holds m scalar_t.
+  const size_t row_size = static_cast<size_t>(m) * sizeof(scalar_t);
+  for (const auto col : c10::irange(n)) {
+    std::memcpy(
+        // Point to row `col` of the output tensor.
+        out + col * m,
+        bias,
+        row_size);
+  }
+}
+
+} // namespace
 
 Tensor& opt_linear_out(
     RuntimeContext& ctx,
     const Tensor& in,
     const Tensor& mat2,
     const optional<Tensor>& bias,
     Tensor& out) {
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !bias.has_value(),
-      InvalidArgument,
-      out,
-      "bias not supported yet in linear");
   ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
 
   size_t output_ndim = 0;
@@ -46,28 +91,63 @@ Tensor& opt_linear_out(
     return out;
   }
 
-  int flattened_input_dim = 1;
+  ssize_t n = 1;
   for (int ii = 0; ii < in.dim() - 1; ++ii) {
-    flattened_input_dim *= in.sizes()[ii];
+    n *= in.sizes()[ii];
   }
+  const ssize_t k = in.sizes()[in.dim() - 1];
+  const ssize_t m = mat2.size(0);
+
+  if (bias.has_value()) {
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        // Bias must match the output dtype (and is expected to be a 1-D
+        // tensor of size m or 1).
+        bias->dtype() == out.dtype(),
+        InvalidArgument,
+        out,
+        "Bias has wrong dimensionality! Expected 1-D tensor of size %d or empty,"
+        " but got %d-D tensor with %d elements",
+        static_cast<int>(m),
+        static_cast<int>(bias->dim()),
+        static_cast<int>(bias->numel()));
+  }
+
   ET_SWITCH_REAL_TYPES_AND2(
-      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
-        size_t n = flattened_input_dim;
-        size_t k = in.sizes()[in.dim() - 1];
-        size_t m = mat2.size(0);
-
-        executorch::cpublas::gemm(
-            executorch::cpublas::TransposeType::Transpose,
-            executorch::cpublas::TransposeType::NoTranspose,
+      Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] {
+        // Fill the output with the bias if one is provided.
+        if (bias.has_value() && bias->numel() == 1) {
+          // Scalar version of initialization.
+          initialize_scalar<CTYPE>(
+              out.numel(),
+              *bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        } else if (bias.has_value()) {
+          // Assume bias is a 1-D tensor of size m.
+          initialize_to_vector<CTYPE>(
+              n,
+              m,
+              bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        }
+
+        // Set beta to 1 if a bias was applied so that GEMM adds to the
+        // pre-filled bias; otherwise beta stays 0 and the output is fully
+        // overwritten by GEMM.
+        const CTYPE beta =
+            bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
+
+        gemm(
+            /*transa=*/TransposeType::Transpose,
+            /*transb=*/TransposeType::NoTranspose,
             m,
             n,
             k,
-            static_cast<CTYPE>(1),
+            /*alpha=*/static_cast<CTYPE>(1),
             mat2.const_data_ptr<CTYPE>(),
             k,
             in.const_data_ptr<CTYPE>(),
             k,
-            static_cast<CTYPE>(0),
+            beta,
             out.mutable_data_ptr<CTYPE>(),
             m);
       });
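
Putting the pieces together, the new flow is: prefill the output with the bias (a broadcast scalar fill or a per-row copy), then let GEMM accumulate on top of it. The following is a plain scalar reference of that flow, assuming a row-major n x k input, m x k weight, and n x m output; linear_reference is an illustrative stand-in, not part of the kernel.

#include <cstddef>
#include <cstring>

// Hypothetical scalar reference mirroring opt_linear_out's flow:
// prefill the n x m output with the bias, then accumulate in @ mat2^T.
template <typename T>
void linear_reference(
    std::size_t n, std::size_t k, std::size_t m,
    const T* in,              // n x k input
    const T* mat2,            // m x k weight (multiplied transposed)
    const T* bias,            // nullptr, 1 element, or m elements
    std::size_t bias_numel,
    T* out) {                 // n x m output
  if (bias != nullptr && bias_numel == 1) {
    // Scalar bias: broadcast one value over the whole output.
    for (std::size_t i = 0; i < n * m; ++i) {
      out[i] = bias[0];
    }
  } else if (bias != nullptr) {
    // Vector bias: copy the m-element bias into each of the n output rows.
    for (std::size_t row = 0; row < n; ++row) {
      std::memcpy(out + row * m, bias, m * sizeof(T));
    }
  }
  const bool accumulate = (bias != nullptr);  // beta = 1 vs. beta = 0
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      T acc = 0;
      for (std::size_t p = 0; p < k; ++p) {
        acc += in[i * k + p] * mat2[j * k + p];
      }
      out[i * m + j] = accumulate ? out[i * m + j] + acc : acc;
    }
  }
}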

kernels/test/op_linear_test.cpp

Lines changed: 72 additions & 0 deletions
@@ -31,6 +31,14 @@ class OpLinearOutTest : public OperatorTest {
     return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
   }
 
+  Tensor& op_linear_out(
+      const Tensor& self,
+      const Tensor& mat2,
+      const Tensor& bias,
+      Tensor& out) {
+    return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
+  }
+
   template <class CTYPE, executorch::aten::ScalarType DTYPE>
   void test_dtype() {
     TensorFactory<DTYPE> tf;
@@ -88,6 +96,70 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
   // for those types.
 }
 
+TEST_F(OpLinearOutTest, BiasTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  Tensor b = tf.full({kDimY}, kValueBias);
+  // Output matrix starts out zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
+TEST_F(OpLinearOutTest, BiasBroadcastTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  // Single-element bias, broadcast over the whole output.
+  Tensor b = tf.full({1}, kValueBias);
+  // Output matrix starts out zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
+TEST_F(OpLinearOutTest, Bias2DTest) {
+  TensorFactory<ScalarType::Int> tf;
+
+  // Initialize input tensors.
+  constexpr int kReduceDim = 4;
+  constexpr int kDimX = 3, kDimY = 5;
+  constexpr int kValueX = 1;
+  constexpr int kValueY = 2;
+  constexpr int kValueBias = 4;
+  Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
+  Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
+  // Bias with the same shape as the output.
+  Tensor b = tf.full({kDimX, kDimY}, kValueBias);
+  // Output matrix starts out zero-filled.
+  Tensor out = tf.zeros({kDimX, kDimY});
+  // Initialize expected tensor.
+  constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
+  Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
+
+  EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
+}
+
 TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
   TensorFactory<ScalarType::Float> tf;
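
As a quick sanity check on the tests' expected value: x is filled with 1, y with 2, the reduce dimension is 4, and the bias is 4, so every output element is the same dot product plus the bias. Worked by hand:

// Each output element: sum over kReduceDim of (1 * 2), plus the bias of 4.
static_assert(1 * 2 * 4 + 4 == 12, "kValueExpected in the bias tests");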
