
Commit bfc0ba4

nikitaved authored and pytorchmergebot committed
nn.Linear: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (pytorch#166071)
As per the title. It seems safe to generalize to arbitrary contiguous inputs, since `at::matmul` is likely to do the flattening anyway to avoid `baddbmm`. Additionally, we guard for the bias being 1D and contiguous, which is guaranteed to be fused with no copies.

Pull Request resolved: pytorch#166071
Approved by: https://github.com/ngimel
1 parent 3fdc5db commit bfc0ba4
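
For context, a minimal sketch of the call pattern this change targets: a contiguous nD input with a 1D contiguous bias, for both a strided and a sparse weight. The shapes, the CSR layout, and the closeness check are illustrative assumptions, not part of the commit; sparse layout support can vary by build and device.

    import torch
    import torch.nn.functional as F

    # Contiguous 3D input, 1D contiguous bias -- the case routed through the
    # flattened addmm path by this commit.
    x = torch.randn(4, 7, 16)
    w = torch.randn(32, 16)
    b = torch.randn(32)

    dense_out = F.linear(x, w, b)

    # Same entry point with a sparse weight; internally the problem is
    # transposed so that the sparse operand is the addmm lhs (see the
    # Linear.cpp diff below). CSR is assumed here purely for illustration.
    w_sparse = w.to_sparse_csr()
    sparse_out = F.linear(x, w_sparse, b)

    print(torch.allclose(dense_out, sparse_out, atol=1e-5))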

File tree

2 files changed: +45, -22 lines

aten/src/ATen/native/Linear.cpp
test/profiler/test_profiler_tree.py

aten/src/ATen/native/Linear.cpp

Lines changed: 43 additions & 18 deletions
@@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() {
 // `_flatten_nd_linear` flattens all but the last dimension of the input tensor
 // before passing it to linear operation
 static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
-    const auto input_sizes = input.sym_sizes();
-    // can't use -1 in reshape because it errors when a dimension is 0
-    c10::SymInt flattened_dim = 1;
-    for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
-      flattened_dim = flattened_dim * input_sizes[i];
+  const auto input_sizes = input.sym_sizes();
+
+  const auto result_flattened = [&]() -> Tensor {
+    const auto input_ncols = input_sizes.back();
+    const auto input_flattened_nrows = [&]() -> c10::SymInt {
+      // can't use -1 in reshape because it errors when a dimension is 0
+      auto flattened_nrows = c10::SymInt{1};
+      for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
+        flattened_nrows *= size;
+      }
+      return flattened_nrows;
+    }();
+
+    const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
+    if (weight.layout() == c10::kStrided) {
+      return at::addmm(bias, input_flattened, weight.t());
+    } else {
+      // weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
+      // so we transpose the problem.
+      // NOTE: at::matmul handles (dense @ sparse) similarly.
+      const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
+      return at::addmm(bias_t, weight, input_flattened.t()).t();
     }
-    auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
-    const auto result = at::addmm(bias, inp_reshape, weight.t());
-    auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
-    c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
-    sizes_vec.push_back(result.sym_size(1));
-    return result.view_symint(sizes_vec);
+  }();
+
+  // Unflatten flattened row dims
+  auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
+  result_sizes.back() = result_flattened.sym_size(1);
+  return result_flattened.view_symint(result_sizes);
 }

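The sparse branch above relies on the identity x @ W^T + b == (W @ x^T + b_col)^T, where the 1D bias is broadcast as a column on the transposed problem. A quick check of that identity with plain dense tensors (shapes are arbitrary; dense tensors are used only to verify the math, not to reproduce the sparse kernel):

    import torch

    x_flat = torch.randn(6, 16)   # flattened input: (rows, in_features)
    W = torch.randn(32, 16)       # weight: (out_features, in_features)
    b = torch.randn(32)           # 1D bias

    lhs = torch.addmm(b, x_flat, W.t())                     # straightforward addmm
    rhs = torch.addmm(b.unsqueeze(-1), W, x_flat.t()).t()   # transposed problem
    print(torch.allclose(lhs, rhs, atol=1e-5))              # True
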
@@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
     // Fused op is marginally faster.
     return at::addmm(*bias, input, weight.t());
   }
-  if (bias->defined() && !input.is_xla()) {
-    // Also hit the fused path for contiguous 3D input, if not using xla
+
+  const auto is_bias_likely_fusable = (
+      bias->defined() &&
+      // cuBLASLt: will fuse in the epilogue without copies
+      // when input/weight/bias are all strided.
+      // When weight is not strided, bias will not be fused,
+      // but we can still dispatch here to avoid at::matmul
+      // path which will probably use a very similar
+      // flattening optimization.
+      ((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false())
+  );
+  if (is_bias_likely_fusable && !input.is_xla()) {
+    // Also hit the fused path for contiguous nD input, if not using xla
     // backend. Reshaping/flattening has some performance implications on xla.
-    bool is_contiguous = input.is_contiguous_or_false();
-    if (is_contiguous && input_dim == 3) {
-      return _flatten_nd_linear(input, weight, *bias);
-    } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
+    if (input.is_contiguous_or_false()) {
       return _flatten_nd_linear(input, weight, *bias);
-    } else if (parseLinearFlatten3d() && input_dim == 3) {
+    } else if (parseLinearFlatten3d()) {
       // If user forces flattening via env var
       const Tensor input_cont = input.contiguous();
       return _flatten_nd_linear(input_cont, weight, *bias);
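
The `is_bias_likely_fusable` guard above accepts a bias that is 1D (or squeezes to 1D) and contiguous, which is what the cuBLASLt epilogue can fuse without copies when input, weight, and bias are all strided; with a sparse weight the bias is not fused, but the addmm dispatch is still preferred over the `at::matmul` fallback. A rough Python mirror of the shape check, assuming plain `is_contiguous()` in place of the symbolic-shape-aware `is_contiguous_or_false()`; the helper name is hypothetical:

    import torch

    def bias_likely_fusable(bias: torch.Tensor) -> bool:
        # 1D, or squeezable to 1D (e.g. shape (1, out_features)), and contiguous
        return (bias.dim() == 1 or bias.squeeze().dim() == 1) and bias.is_contiguous()

    print(bias_likely_fusable(torch.randn(32)))        # True:  plain 1D bias
    print(bias_likely_fusable(torch.randn(1, 32)))     # True:  squeezes to 1D
    print(bias_likely_fusable(torch.randn(4, 32)))     # False: genuinely 2D
    print(bias_likely_fusable(torch.randn(64)[::2]))   # False: 1D but not contiguous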

test/profiler/test_profiler_tree.py

Lines changed: 2 additions & 4 deletions
@@ -624,8 +624,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         torch/nn/modules/module.py(...): __getattr__
         <built-in function linear>
           aten::linear
-            aten::reshape
-              aten::view
+            aten::view
             aten::t
               aten::transpose
                 aten::as_strided

@@ -671,8 +670,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         torch/nn/modules/module.py(...): __getattr__
         <built-in function linear>
           aten::linear
-            aten::reshape
-              aten::view
+            aten::view
             aten::t
               aten::transpose
                 aten::as_strided
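
The updated expectations above reflect that the contiguous nD path now records a single aten::view under aten::linear instead of aten::reshape followed by aten::view. The test relies on its own tree-building harness; as a rough standalone way to see which ops get recorded (a flat table rather than a tree, and the exact op list may vary by backend), one can profile a linear over a contiguous 3D input:

    import torch
    from torch.profiler import profile, ProfilerActivity

    lin = torch.nn.Linear(16, 32)
    x = torch.randn(4, 7, 16)   # contiguous 3D input

    with profile(activities=[ProfilerActivity.CPU]) as prof:
        lin(x)

    # The table should list aten::view (and no aten::reshape) alongside
    # aten::linear and the underlying matmul/addmm op.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))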
