@@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() {
5050// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
5151// before passing it to linear operation
5252static inline Tensor _flatten_nd_linear (const Tensor& input, const Tensor& weight, const Tensor& bias) {
53- const auto input_sizes = input.sym_sizes ();
54- // can't use -1 in reshape because it errors when a dimension is 0
55- c10::SymInt flattened_dim = 1 ;
56- for (int64_t i = 0 , ndim = input_sizes.size (); i < ndim - 1 ; ++i) {
57- flattened_dim = flattened_dim * input_sizes[i];
53+ const auto input_sizes = input.sym_sizes ();
54+
55+ const auto result_flattened = [&]() -> Tensor {
56+ const auto input_ncols = input_sizes.back ();
57+ const auto input_flattened_nrows = [&]() -> c10::SymInt {
58+ // can't use -1 in reshape because it errors when a dimension is 0
59+ auto flattened_nrows = c10::SymInt{1 };
60+ for (const auto & size : input_sizes.slice (0 , input_sizes.size () - 1 )) {
61+ flattened_nrows *= size;
62+ }
63+ return flattened_nrows;
64+ }();
65+
66+ const auto input_flattened = input.view_symint ({input_flattened_nrows, input_ncols});
67+ if (weight.layout () == c10::kStrided ) {
68+ return at::addmm (bias, input_flattened, weight.t ());
69+ } else {
70+ // weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
71+ // so we transpose the problem.
72+ // NOTE: at::matmul handles (dense @ sparse) similarly.
73+ const auto bias_t = (bias.dim () >= 2 ) ? bias.mT () : bias.unsqueeze (-1 );
74+ return at::addmm (bias_t , weight, input_flattened.t ()).t ();
5875 }
59- auto inp_reshape = input. reshape_symint ({flattened_dim, input_sizes. at (input_sizes. size () - 1 )} );
60- const auto result = at::addmm (bias, inp_reshape, weight. t ());
61- auto new_size = input_sizes. slice ( 0 , input_sizes. size () - 1 );
62- c10::SymDimVector sizes_vec (new_size .begin (), new_size .end ()) ;
63- sizes_vec. push_back (result. sym_size (1 ) );
64- return result .view_symint (sizes_vec );
76+ }( );
77+
78+ // Unflatten flattened row dims
79+ auto result_sizes = c10::SymDimVector{input_sizes .begin (), input_sizes .end ()} ;
80+ result_sizes. back () = result_flattened. sym_size (1 );
81+ return result_flattened .view_symint (result_sizes );
6582}
6683
6784
@@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
90107 // Fused op is marginally faster.
91108 return at::addmm (*bias, input, weight.t ());
92109 }
93- if (bias->defined () && !input.is_xla ()) {
94- // Also hit the fused path for contiguous 3D input, if not using xla
110+
111+ const auto is_bias_likely_fusable = (
112+ bias->defined () &&
113+ // cuBLASLt: will fuse in the epilogue without copies
114+ // when input/weight/bias are all strided.
115+ // When weight is not strided, bias will not be fused,
116+ // but we can still dispatch here to avoid at::matmul
117+ // path which will probably use a very similar
118+ // flattening optimization.
119+ ((bias->dim () == 1 || bias->squeeze ().dim () == 1 ) && bias->is_contiguous_or_false ())
120+ );
121+ if (is_bias_likely_fusable && !input.is_xla ()) {
122+ // Also hit the fused path for contiguous nD input, if not using xla
95123 // backend. Reshaping/flattening has some performance implications on xla.
96- bool is_contiguous = input.is_contiguous_or_false ();
97- if (is_contiguous && input_dim == 3 ) {
98- return _flatten_nd_linear (input, weight, *bias);
99- } else if (is_contiguous && input.layout () == c10::kStrided && weight.layout () == c10::kStrided && bias->dim () == 1 ) {
124+ if (input.is_contiguous_or_false ()) {
100125 return _flatten_nd_linear (input, weight, *bias);
101- } else if (parseLinearFlatten3d () && input_dim == 3 ) {
126+ } else if (parseLinearFlatten3d ()) {
102127 // If user forces flattening via env var
103128 const Tensor input_cont = input.contiguous ();
104129 return _flatten_nd_linear (input_cont, weight, *bias);
0 commit comments