
Commit c594950

Revert "nn.Linear: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (pytorch#166071)"
This reverts commit 467c21a. Reverted pytorch#166071 on behalf of https://github.com/atalman due to Multiple CI breakages: test/profiler/test_profiler_tree.py::TestProfilerTree::test_profiler_experimental_tree_with_stack_and_modules [GH job link](https://github.com/pytorch/pytorch/actions/runs/18909087335/job/53976915830) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/467c21ad9ae4133c20a3c098a0355e9ac20d48aa) ([comment](pytorch#166071 (comment)))
1 parent 14102fb, commit c594950


aten/src/ATen/native/Linear.cpp

Lines changed: 18 additions & 43 deletions
@@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
 // `_flatten_nd_linear` flattens all but the last dimension of the input tensor
 // before passing it to linear operation
 static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
-  const auto input_sizes = input.sym_sizes();
-
-  const auto result_flattened = [&]() -> Tensor {
-    const auto input_ncols = input_sizes.back();
-    const auto input_flattened_nrows = [&]() -> c10::SymInt {
-      // can't use -1 in reshape because it errors when a dimension is 0
-      auto flattened_nrows = c10::SymInt{1};
-      for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
-        flattened_nrows *= size;
-      }
-      return flattened_nrows;
-    }();
-
-    const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
-    if (weight.layout() == c10::kStrided) {
-      return at::addmm(bias, input_flattened, weight.t());
-    } else {
-      // weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
-      // so we transpose the problem.
-      // NOTE: at::matmul handles (dense @ sparse) similarly.
-      const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
-      return at::addmm(bias_t, weight, input_flattened.t()).t();
+    const auto input_sizes = input.sym_sizes();
+    // can't use -1 in reshape because it errors when a dimension is 0
+    c10::SymInt flattened_dim = 1;
+    for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
+      flattened_dim = flattened_dim * input_sizes[i];
     }
-  }();
-
-  // Unflatten flattened row dims
-  auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
-  result_sizes.back() = result_flattened.sym_size(1);
-  return result_flattened.view_symint(result_sizes);
+    auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
+    const auto result = at::addmm(bias, inp_reshape, weight.t());
+    auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
+    c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
+    sizes_vec.push_back(result.sym_size(1));
+    return result.view_symint(sizes_vec);
 }
 
 
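For context, `_flatten_nd_linear` collapses every leading dimension of the input into a single row dimension, runs the fused 2D addmm, and then restores the leading dimensions on the result; the code above does this with SymInt sizes (`sym_sizes`, `reshape_symint`, `view_symint`) so the shapes can stay symbolic under tracing. A minimal standalone sketch of the same trick, assuming libtorch is available and using made-up shapes (none of the names below come from the patch):

// Sketch only, not code from this commit: flatten leading dims, addmm, unflatten,
// checked against torch::linear.
#include <torch/torch.h>
#include <iostream>

int main() {
  const auto input  = torch::randn({4, 5, 8});  // nD input, last dim = in_features
  const auto weight = torch::randn({3, 8});     // [out_features, in_features]
  const auto bias   = torch::randn({3});        // [out_features]

  // Reference result from the regular linear op.
  const auto expected = torch::linear(input, weight, bias);

  // Collapse the leading dims, run the fused addmm, then restore the shape.
  const auto flattened = input.reshape({4 * 5, 8});
  const auto result = torch::addmm(bias, flattened, weight.t()).view({4, 5, 3});

  std::cout << std::boolalpha << torch::allclose(result, expected) << "\n";  // expect: true
  return 0;
}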
@@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
     // Fused op is marginally faster.
     return at::addmm(*bias, input, weight.t());
   }
-
-  const auto is_bias_likely_fusable = (
-      bias->defined() &&
-      // cuBLASLt: will fuse in the epilogue without copies
-      // when input/weight/bias are all strided.
-      // When weight is not strided, bias will not be fused,
-      // but we can still dispatch here to avoid at::matmul
-      // path which will probably use a very similar
-      // flattening optimization.
-      (bias->dim() == 1 && bias->is_contiguous_or_false())
-  );
-  if (is_bias_likely_fusable && !input.is_xla()) {
-    // Also hit the fused path for contiguous nD input, if not using xla
+  if (bias->defined() && !input.is_xla()) {
+    // Also hit the fused path for contiguous 3D input, if not using xla
     // backend. Reshaping/flattening has some performance implications on xla.
-    if (input.is_contiguous_or_false()) {
+    bool is_contiguous = input.is_contiguous_or_false();
+    if (is_contiguous && input_dim == 3) {
+      return _flatten_nd_linear(input, weight, *bias);
+    } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
       return _flatten_nd_linear(input, weight, *bias);
-    } else if (parseLinearFlatten3d()) {
+    } else if (parseLinearFlatten3d() && input_dim == 3) {
       // If user forces flattening via env var
       const Tensor input_cont = input.contiguous();
       return _flatten_nd_linear(input_cont, weight, *bias);
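The change being reverted had additionally routed sparse weights through this fused path. Because addmm with a sparse operand expects the sparse matrix on the left-hand side of the matmul (per the removed comment above), it transposed the problem: compute (W @ X^T + b)^T instead of X @ W^T + b. A minimal sketch of that transposed-addmm trick, again assuming libtorch and made-up shapes, with `to_sparse()` used only to build an example sparse COO weight:

// Sketch only, not code from this commit: transposed addmm so the sparse weight
// sits on the lhs of the matmul.
#include <torch/torch.h>
#include <iostream>

int main() {
  const auto input  = torch::randn({6, 8});              // input already flattened to 2D
  const auto weight = torch::randn({3, 8}).to_sparse();  // sparse COO [out_features, in_features]
  const auto bias   = torch::randn({3});                 // [out_features]

  // A 1D bias becomes a column so it broadcasts over (weight @ input^T).
  const auto bias_t = bias.unsqueeze(-1);                // [out_features, 1]
  const auto result = torch::addmm(bias_t, weight, input.t()).t();

  // Same answer as a dense linear, up to floating-point tolerance.
  const auto expected = torch::linear(input, weight.to_dense(), bias);
  std::cout << std::boolalpha << torch::allclose(result, expected) << "\n";  // expect: true
  return 0;
}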
