  * LICENSE file in the root directory of this source tree.
  */
 
+#include <array>
+// std::memcpy (used by the bias-broadcast helper below) lives in <cstring>.
+#include <cstring>
+
+#include <c10/util/irange.h>
+
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/optimized/vec/functional_base.h>
+#include <executorch/kernels/optimized/vec/vec.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
-#include <array>
-
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = executorch::aten::Tensor;
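+// Helpers below live in an anonymous namespace so they stay local to this
+// kernel's translation unit.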
+namespace {
+using ::executorch::aten::Tensor;
+using ::executorch::cpublas::gemm;
+using ::executorch::cpublas::TransposeType;
+using ::executorch::vec::map;
+using ::executorch::vec::Vectorized;
+
+// Use vector stores to fill the output with a scalar bias value.
+template <typename scalar_t>
+void initialize_scalar(
+    const ssize_t out_numel,
+    const scalar_t init,
+    scalar_t* out) {
+  using Vec = Vectorized<scalar_t>;
+
+  // Broadcast the scalar initial value across one vector register.
+  Vec init_vec(init);
+
+  ssize_t d = 0;
+  for (; d < out_numel - (out_numel % Vec::size()); d += Vec::size()) {
+    // Full vector-length store.
+    init_vec.store(out + d);
+  }
+  if (out_numel - d > 0) {
+    // Partial store covering only the remaining tail elements.
+    init_vec.store(out + d, out_numel - d);
+  }
+}
+
+// Use std::memcpy to broadcast a 1-D bias vector into every row of the output.
+template <typename scalar_t>
+void initialize_to_vector(
+    const ssize_t n,
+    const ssize_t m,
+    const scalar_t* bias,
+    scalar_t* out) {
+  // Output holds n x m scalar_t elements, while bias holds m scalar_t
+  // elements; copy the bias into each of the n rows.
+  const size_t row_size = static_cast<size_t>(m) * sizeof(scalar_t);
+  for (const auto row : c10::irange(n)) {
+    std::memcpy(
+        // Point to row `row` of the output tensor.
+        out + row * m,
+        bias,
+        row_size);
+  }
+}
+
+} // namespace
 
 Tensor& opt_linear_out(
     RuntimeContext& ctx,
     const Tensor& in,
     const Tensor& mat2,
     const optional<Tensor>& bias,
     Tensor& out) {
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !bias.has_value(),
-      InvalidArgument,
-      out,
-      "bias not supported yet in linear");
   ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
 
   size_t output_ndim = 0;
@@ -46,28 +91,63 @@ Tensor& opt_linear_out(
     return out;
   }
 
-  int flattened_input_dim = 1;
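+  // Collapse every leading dim of `in` into a single batch dim n, so the
+  // whole linear op reduces to one GEMM call.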
+  ssize_t n = 1;
   for (int ii = 0; ii < in.dim() - 1; ++ii) {
-    flattened_input_dim *= in.sizes()[ii];
+    n *= in.sizes()[ii];
   }
+  const ssize_t k = in.sizes()[in.dim() - 1];
+  const ssize_t m = mat2.size(0);
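+  // Shapes after flattening: `in` is (n x k), `mat2` is (m x k), and
+  // `out` is (n x m); the kernel computes out = in * mat2^T (+ bias).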
+
+  if (bias.has_value()) {
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        // Bias must be a 1-D tensor holding either m elements (one per
+        // output channel) or a single element broadcast to all channels.
+        bias->dim() == 1 && (bias->numel() == m || bias->numel() == 1),
+        InvalidArgument,
+        out,
+        "Bias has wrong dimensionality! Expected 1-D tensor of size %zd or empty,"
+        " but got %zd-D tensor with %zd elements",
+        m,
+        bias->dim(),
+        bias->numel());
+  }
+
   ET_SWITCH_REAL_TYPES_AND2(
-      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
-        size_t n = flattened_input_dim;
-        size_t k = in.sizes()[in.dim() - 1];
-        size_t m = mat2.size(0);
-
-        executorch::cpublas::gemm(
-            executorch::cpublas::TransposeType::Transpose,
-            executorch::cpublas::TransposeType::NoTranspose,
+      Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] {
+        // Pre-fill the output with the bias, if one is provided.
+        if (bias.has_value() && bias->numel() == 1) {
+          // Scalar bias: splat the single value across the whole output.
+          initialize_scalar<CTYPE>(
+              out.numel(),
+              *bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        } else if (bias.has_value()) {
+          // Vector bias: validated above to be a 1-D tensor of size m.
+          initialize_to_vector<CTYPE>(
+              n,
+              m,
+              bias->const_data_ptr<CTYPE>(),
+              out.mutable_data_ptr<CTYPE>());
+        }
+
+        // Set beta to 1 when a bias was applied so that GEMM accumulates
+        // into the pre-filled output; otherwise use 0 so GEMM fully
+        // overwrites the output.
+        const CTYPE beta =
+            bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
+
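+        // cpublas::gemm follows the BLAS column-major convention. Read
+        // column-major, row-major `mat2` (m x k) is a (k x m) matrix, so
+        // transa=Transpose yields the (m x k) left operand; row-major `in`
+        // (n x k) read column-major with ldb=k is the (k x n) right operand.
+        // Their (m x n) column-major product is exactly row-major (n x m)
+        // `out`.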
+        gemm(
+            /*transa=*/TransposeType::Transpose,
+            /*transb=*/TransposeType::NoTranspose,
             m,
             n,
             k,
-            static_cast<CTYPE>(1),
+            /*alpha=*/static_cast<CTYPE>(1),
             mat2.const_data_ptr<CTYPE>(),
             k,
             in.const_data_ptr<CTYPE>(),
             k,
-            static_cast<CTYPE>(0),
+            beta,
             out.mutable_data_ptr<CTYPE>(),
             m);
       });