@@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
 // `_flatten_nd_linear` flattens all but the last dimension of the input tensor
 // before passing it to linear operation
 static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
-  const auto input_sizes = input.sym_sizes();
-
-  const auto result_flattened = [&]() -> Tensor {
-    const auto input_ncols = input_sizes.back();
-    const auto input_flattened_nrows = [&]() -> c10::SymInt {
-      // can't use -1 in reshape because it errors when a dimension is 0
-      auto flattened_nrows = c10::SymInt{1};
-      for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
-        flattened_nrows *= size;
-      }
-      return flattened_nrows;
-    }();
-
-    const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
-    if (weight.layout() == c10::kStrided) {
-      return at::addmm(bias, input_flattened, weight.t());
-    } else {
-      // weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
-      // so we transpose the problem.
-      // NOTE: at::matmul handles (dense @ sparse) similarly.
-      const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
-      return at::addmm(bias_t, weight, input_flattened.t()).t();
+    const auto input_sizes = input.sym_sizes();
+    // can't use -1 in reshape because it errors when a dimension is 0
+    c10::SymInt flattened_dim = 1;
+    for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
+      flattened_dim = flattened_dim * input_sizes[i];
     }
-  }();
-
-  // Unflatten flattened row dims
-  auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
-  result_sizes.back() = result_flattened.sym_size(1);
-  return result_flattened.view_symint(result_sizes);
+    auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() - 1)});
+    const auto result = at::addmm(bias, inp_reshape, weight.t());
+    auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
+    c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
+    sizes_vec.push_back(result.sym_size(1));
+    return result.view_symint(sizes_vec);
 }
 
 
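For reference, the flatten → addmm → unflatten sequence that the restored `_flatten_nd_linear` performs can be reproduced against the public ATen API. The following is a minimal standalone sketch, not part of the patch; the shapes and variable names are made up for illustration, and the row count is spelled out explicitly rather than using `reshape(-1, ...)`, mirroring the comment about zero-sized dimensions.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // 3D input [batch, seq, in_features] with an ordinary linear layer.
  const auto input  = at::randn({2, 3, 4});
  const auto weight = at::randn({5, 4});   // [out_features, in_features]
  const auto bias   = at::randn({5});      // [out_features]

  // Flatten every dimension except the last: [2, 3, 4] -> [6, 4].
  const auto flattened = input.reshape({2 * 3, 4});

  // addmm(bias, mat1, mat2) computes bias + mat1 @ mat2, so the weight is transposed.
  const auto flat_out = at::addmm(bias, flattened, weight.t());  // [6, 5]

  // Restore the leading dimensions; only the last size changes (4 -> 5).
  const auto out = flat_out.view({2, 3, 5});

  // Should agree with the generic entry point.
  std::cout << std::boolalpha
            << at::allclose(out, at::linear(input, weight, bias)) << "\n";  // true
}
```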
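The removed sparse-weight branch relies on a transpose identity: `addmm(bias_col, weight, x.t()).t()` yields the same values as `x @ weight.t() + bias`, which puts the sparse operand on the left of the matmul as sparse `addmm` expects. A quick dense-tensor check of that identity (a sketch only; it does not construct an actual sparse weight):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  const auto x      = at::randn({6, 4});  // already-flattened input
  const auto weight = at::randn({5, 4});
  const auto bias   = at::randn({5});

  // Transposed formulation used by the removed branch:
  //   (bias_col + weight @ x^T)^T == x @ weight^T + bias
  const auto lhs = at::addmm(bias.unsqueeze(-1), weight, x.t()).t();
  const auto rhs = at::linear(x, weight, bias);

  std::cout << std::boolalpha << at::allclose(lhs, rhs) << "\n";  // true
}
```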
@@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
     // Fused op is marginally faster.
     return at::addmm(*bias, input, weight.t());
   }
-
-  const auto is_bias_likely_fusable = (
-      bias->defined() &&
-      // cuBLASLt: will fuse in the epilogue without copies
-      // when input/weight/bias are all strided.
-      // When weight is not strided, bias will not be fused,
-      // but we can still dispatch here to avoid at::matmul
-      // path which will probably use a very similar
-      // flattening optimization.
-      (bias->dim() == 1 && bias->is_contiguous_or_false())
-  );
-  if (is_bias_likely_fusable && !input.is_xla()) {
-    // Also hit the fused path for contiguous nD input, if not using xla
+  if (bias->defined() && !input.is_xla()) {
+    // Also hit the fused path for contiguous 3D input, if not using xla
     // backend. Reshaping/flattening has some performance implications on xla.
-    if (input.is_contiguous_or_false()) {
+    bool is_contiguous = input.is_contiguous_or_false();
+    if (is_contiguous && input_dim == 3) {
+      return _flatten_nd_linear(input, weight, *bias);
+    } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
       return _flatten_nd_linear(input, weight, *bias);
-    } else if (parseLinearFlatten3d()) {
+    } else if (parseLinearFlatten3d() && input_dim == 3) {
       // If user forces flattening via env var
       const Tensor input_cont = input.contiguous();
       return _flatten_nd_linear(input_cont, weight, *bias);
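To make the restored dispatch conditions in `linear` concrete, the sketch below reimplements just the gating logic visible in this hunk as a hypothetical `would_take_flattened_path` helper; it is not part of Linear.cpp, uses `is_contiguous()` in place of the internal `is_contiguous_or_false()`, and omits the `parseLinearFlatten3d()` environment-variable escape hatch.

```cpp
#include <ATen/ATen.h>
#include <iostream>

// Hypothetical helper: reproduces the restored gating conditions from the
// hunk above for a dense weight and an optional-but-defined bias.
static bool would_take_flattened_path(const at::Tensor& input,
                                      const at::Tensor& weight,
                                      const at::Tensor& bias) {
  if (!bias.defined() || input.is_xla()) {
    return false;
  }
  const bool is_contiguous = input.is_contiguous();
  if (is_contiguous && input.dim() == 3) {
    return true;  // contiguous 3D input always takes the fused path
  }
  // Otherwise require dense (strided) input/weight layouts and a 1-D bias.
  return is_contiguous &&
         input.layout() == c10::kStrided &&
         weight.layout() == c10::kStrided &&
         bias.dim() == 1;
}

int main() {
  const auto weight = at::randn({8, 16});
  const auto bias   = at::randn({8});
  std::cout << std::boolalpha
            << would_take_flattened_path(at::randn({4, 10, 16}), weight, bias) << "\n"     // true: contiguous 3D
            << would_take_flattened_path(at::randn({2, 3, 10, 16}), weight, bias) << "\n"  // true: strided 4D, 1-D bias
            << would_take_flattened_path(at::randn({4, 16, 10}).transpose(1, 2),
                                         weight, bias) << "\n";                            // false: non-contiguous
}
```

In short, a contiguous 3D input always hits the fused `_flatten_nd_linear` path; other contiguous inputs additionally need strided input/weight layouts and a 1-D bias, and anything else falls through toward the `at::matmul` path unless the flattening env var forces a contiguous copy for 3D inputs.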