#include <functorch/csrc/Constants.h>
#include <torch/library.h>
#include <ATen/ATen.h>
+#include <functorch/csrc/TensorWrapper.h>
+#include <functorch/csrc/BatchedTensorImpl.h>

namespace at { namespace functorch {

@@ -34,9 +36,98 @@ Tensor index_select_backward_hack(const Tensor& grad, IntArrayRef self_sizes, in
  return at::zeros(self_sizes, grad.options()).index_add(dim, index, grad);
}

+// TODO: https://github.com/pytorch/pytorch/issues/69991
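+// Frobenius norm over `dim`: sqrt(sum(x * conj(x), dim)) for complex inputs,
+// sqrt(sum(x * x, dim)) otherwise; single-dim and no-dim cases fall back to at::norm.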
+Tensor frobenius_norm_dim_hack(const Tensor& self, IntArrayRef dim, bool keepdim) {
+  if (dim.size() == 1 || dim.size() == 0) {
+    return at::norm(self, 2, dim, keepdim);
+  } else {
+    auto dim_ = dim.vec();
+    maybe_wrap_dims(dim_, self.dim());
+    TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead");
+    if (self.is_complex()) {
+      return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim));
+    } else {
+      return at::sqrt(at::sum((self * self), dim_, keepdim));
+    }
+  }
+}
+
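+// Peels functorch wrappers off `tensor`: returns the value and level of the
+// first TensorWrapper with a live level (grad transform) or of the
+// BatchedTensorImpl (vmap), or nullopt if the tensor carries no wrapper.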
+static optional<std::tuple<Tensor,int64_t>> unwrap(const Tensor& tensor) {
+  auto* wrapped = maybeGetTensorWrapper(tensor);
+  if (wrapped) {
+    if (wrapped->level().has_value()) {
+      return std::make_tuple(wrapped->value(), *wrapped->level());
+    }
+    return unwrap(wrapped->value());
+  }
+  auto* batched = maybeGetBatchedImpl(tensor);
+  if (batched) {
+    return std::make_tuple(batched->value(), batched->level());
+  }
+  return nullopt;
+}
+
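+// Walks the wrapper stacks of `a` and `b` from the outermost level inwards,
+// checking that every functorch level present in b is also present in a.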
+static bool can_perform_inplace(const Tensor& a, const Tensor& b) {
+  // TODO: generalize this to more transforms
+  auto a_ = unwrap(a);
+  auto b_ = unwrap(b);
+  if (!a_.has_value() && b_.has_value()) {
+    return false;
+  }
+  if (!a_.has_value() && !b_.has_value()) {
+    return true;
+  }
+  if (a_.has_value() && !b_.has_value()) {
+    return true;
+  }
+  TORCH_INTERNAL_ASSERT(a_.has_value() && b_.has_value());
+
+  // If b has any wrapper that a does not, then we cannot do a.inplace_(b)
+  if (std::get<1>(*a_) < std::get<1>(*b_)) {
+    return false;
+  }
+  if (std::get<1>(*a_) > std::get<1>(*b_)) {
+    return can_perform_inplace(std::get<0>(*a_), b);
+  }
+  return can_perform_inplace(std::get<0>(*a_), std::get<0>(*b_));
+}
+
+// TODO: linear is pretty important for performance, but I'm not sure how to work
+// around the in-place.
+Tensor linear_hack(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
+  // See [Note: hacky wrapper removal for optional tensor]
+  auto bias = bias_opt.has_value()
+    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
+    : c10::MaybeOwned<Tensor>::owned(c10::in_place);
+
+  if (input.is_mkldnn()) {
+    return at::mkldnn_linear(input, weight, *bias);
+  }
+#if defined(C10_MOBILE)
+  if (xnnpack::use_linear(input, weight, *bias)) {
+    return xnnpack::linear(input, weight, *bias);
+  }
+#endif
+  if (input.dim() == 2 && bias->defined()) {
+    // Fused op is marginally faster.
+    return at::addmm(*bias, input, weight.t());
+  }
+  auto output = at::matmul(input, weight.t());
+  if (bias->defined()) {
+    // TODO(rzou): I'm a little uncomfortable with this
+    if (can_perform_inplace(output, *bias)) {
+      return output.add_(*bias);
+    }
+    return output.add(*bias);
+  }
+  return output;
+}
+
TORCH_LIBRARY_IMPL(aten, FT_DYNAMIC_LAYER_FRONT_MODE_KEY, m) {
  m.impl("value_selecting_reduction_backward", value_selecting_reduction_backward_hack);
  m.impl("index_select_backward", index_select_backward_hack);
+  m.impl("frobenius_norm.dim", frobenius_norm_dim_hack);
+  m.impl("linear", linear_hack);
}

}}