Skip to content

Commit 2f6387e

Browse files
Lucaskabela and bdhirsh authored
[CherryPick][2.9] Cherry pick request for Reapply "Make functionalization ViewMeta serializable with pickle" pytorch#163769 (pytorch#163873)
Reapply "Make functionalization `ViewMeta` serializable with pickle. (pytorch#143712)" (pytorch#163769) NOTE: This is a re-export of pytorch#161994 ; the changes between these two PRs is exclusively to the buck/build files (Summary from pytorch#161994 ) Attempted rebase of pytorch#143712. This reverts commit 6c713cc. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela imported-using-ghimport Test Plan: Imported from OSS Differential Revision: D81524507 Pulled By: Lucaskabela Pull Request resolved: pytorch#163769 Approved by: https://github.com/dolpm (cherry picked from commit 7d71040) Co-authored-by: Brian Hirsh <[email protected]>
1 parent 017d857 commit 2f6387e

38 files changed

+979
-423
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ torch/return_types.pyi
8282
torch/nn/functional.pyi
8383
torch/utils/data/datapipes/datapipe.pyi
8484
torch/csrc/autograd/generated/*
85+
torch/csrc/functionalization/generated/*
8586
torch/csrc/lazy/generated/*.[!m]*
8687
torch_compile_debug/
8788
# Listed manually because some files in this directory are not generated

BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ generated_cpu_cpp = [
9191
"aten/src/ATen/NativeMetaFunctions.h",
9292
"aten/src/ATen/RegistrationDeclarations.h",
9393
"aten/src/ATen/VmapGeneratedPlumbing.h",
94+
"aten/src/ATen/ViewMetaClasses.h",
95+
"aten/src/ATen/ViewMetaClasses.cpp",
9496
"aten/src/ATen/core/aten_interned_strings.h",
9597
"aten/src/ATen/core/enum_tag.h",
9698
"aten/src/ATen/core/TensorBody.h",
@@ -1106,6 +1108,7 @@ test_suite(
11061108
"aten/src/ATen/templates/LazyNonNativeIr.h",
11071109
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
11081110
"aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
1111+
"aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
11091112
"aten/src/ATen/native/native_functions.yaml",
11101113
"aten/src/ATen/native/tags.yaml",
11111114
"aten/src/ATen/native/ts_native_functions.yaml",

aten/src/ATen/FunctionalStorageImpl.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,6 @@
99

1010
namespace at::functionalization {
1111

12-
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
13-
if (out_idx == this->out_index) return *this;
14-
return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
15-
}
16-
1712
// Note [Functionalization: Alias Removal Part 2]
1813
// See Note [Functionalization: Alias Removal] for more details.
1914
// This function applies a single update from one of the views to the StorageImpl.
@@ -42,22 +37,21 @@ ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
4237
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
4338
at::Tensor t = update.new_val;
4439
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
45-
if (update.view_metas.empty()) return t;
40+
if (update.view_metas.empty()) { return t; }
4641

4742
std::vector<at::Tensor> tmp_values({base});
4843
tmp_values.reserve(update.view_metas.size());
4944
for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
50-
at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
45+
at::Tensor next_view = update.view_metas[i]->forward(tmp_values.back());
5146
// NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
5247
// All of these ops require additional information to recover the sizes of the original tensor.
5348
// If need to, we could probably apply this optimization and only bother computing tmp_values
5449
// for those necessary view ops.
5550
tmp_values.push_back(std::move(next_view));
5651
}
5752
for(int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
58-
int64_t out_idx = update.view_metas[i].out_index;
5953
// Each view inverse is implemented in ViewInverses.cpp.
60-
t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
54+
t = update.view_metas[i]->reverse(tmp_values[i], t);
6155
}
6256
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
6357
return t;
@@ -111,13 +105,13 @@ FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
111105
TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
112106
}
113107

114-
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
108+
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<std::shared_ptr<ViewMeta>>& metas) {
115109
TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
116110

117111
if (metas.size() > 1) {
118112
for (size_t i = 1; i < metas.size(); ++i) {
119113
// Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI
120-
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i].is_as_strided,
114+
TORCH_CHECK(updated_val.device().type() == c10::DeviceType::XLA || !metas[i]->is_as_strided,
121115
"During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
122116
" was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today,"
123117
"so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "

aten/src/ATen/FunctionalStorageImpl.h

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,44 +8,89 @@ namespace at::functionalization {
88

99
// See Note [Functionalization Pass In Core]
1010

11+
enum class InverseReturnMode {
12+
/// Specifies that functional inverses should always return a view.
13+
AlwaysView,
14+
/// Specifies that functional inverses should always return a non-view / copy.
15+
NeverView,
16+
/// Specifies that functional inverses should return a view unless a (copying)
17+
/// scatter
18+
/// inverse exists, in which case that will be used instead.
19+
/// This avoids as_strided() calls that can be difficult for subclasses to
20+
/// handle.
21+
ViewOrScatterInverse,
22+
};
23+
24+
#define FUNCTIONALIZATION_VIEWMETA_NAME(TYPE) \
25+
static const char* name() { \
26+
return #TYPE; \
27+
}
28+
29+
#define FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(...) \
30+
using SerializableTuple = std::tuple<__VA_ARGS__>
31+
1132
// ViewMeta is a class used by the functionalization pass to navigate between
1233
// a base tensor and a view tensor.
1334
// For example, if I call `b = a.view1(...)`
14-
// the functionalization pass will generate and store a ViewMeta on b that looks
15-
// like:
35+
// the functionalization pass will generate and store a ViewMeta specialization
36+
// for `view1` operation on b that looks like:
1637
//
17-
// ViewMeta(
18-
// [<captures>](const Tensor& base, int64_t mutated_view_idx) {
19-
// return base.view1(...);
20-
// },
21-
// [<captures>](const at::Tensor& base, const at::Tensor& mutated_view,
22-
// int64_t mutated_view_idx) -> at::Tensor {
23-
// return at::functionalization::impl::view1_inverse(base, mutated_view,
24-
// ...);
38+
// struct TORCH_API view1_ViewMeta : public ViewMeta {
39+
// FUNCTIONALIZATION_VIEWMETA_NAME(view1_ViewMeta);
40+
// FUNCTIONALIZATION_VIEWMETA_SERIALIZABLE_TUPLE(
41+
// bool /* reapply_views */,
42+
// const std::vector<int64_t>&);
43+
//
44+
// view1_ViewMeta(const SerializableTuple& tpl)
45+
// : view1_ViewMeta(std::get<0>(tpl), std::get<1>(tpl)) {}
46+
//
47+
// view1_ViewMeta(bool reapply_views, const std::vector<int64_t>& size)
48+
// : ViewMeta(/*has_symbolic_inputs=*/false),
49+
// reapply_views(reapply_views),
50+
// size(size) {}
51+
//
52+
// Tensor forward(const Tensor& base) override {
53+
// return base.view1(...);
2554
// }
2655
//
27-
// The forward_fn lambda describes how to replay view1 on a tensor.
56+
// Tensor reverse(const Tensor& base, const Tensor& mutated_view) override {
57+
// return at::functionalization::impl::view1_inverse(base, mutated_view,
58+
// ...);
59+
// }
2860
//
29-
// The reverse_fn lambda describes how, given a tensor that is already a view,
61+
// SerializableTuple to_serializable_tuple() {
62+
// return std::make_tuple(reapply_views, size);
63+
// }
64+
//
65+
// bool reapply_views;
66+
// std::vector<int64_t> size;
67+
// };
68+
//
69+
// The forward function describes how to replay view1 on a tensor.
70+
//
71+
// The reverse function describes how, given a tensor that is already a view,
3072
// how to get the corresponding base tensor. See Note [Functionalization Pass:
3173
// View Inverses] for details.
74+
//
75+
// `SerializedTuple` is a typedef that defines an `std::tuple<...>` type
76+
// representing the `ViewMeta` instance state. Methods that take in/return such
77+
// a type are used for supporting pickle serialization.
3278
struct ViewMeta {
3379
ViewMeta(
34-
std::function<Tensor(const Tensor&, int64_t)> forward,
35-
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
3680
bool has_symbolic_inputs,
3781
bool is_multi_output = false,
3882
bool is_as_strided = false,
3983
int64_t out_idx = 0)
40-
: forward_fn(std::move(forward)),
41-
reverse_fn(std::move(reverse)),
42-
out_index(out_idx),
84+
: out_index(out_idx),
4385
is_multi_output(is_multi_output),
4486
is_as_strided(is_as_strided),
4587
has_symbolic_inputs(has_symbolic_inputs) {}
4688

47-
std::function<Tensor(const Tensor&, int64_t)> forward_fn;
48-
std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
89+
virtual ~ViewMeta() = default;
90+
91+
virtual Tensor forward(const Tensor& base) = 0;
92+
virtual Tensor reverse(const Tensor& base, const Tensor& mutated_view) = 0;
93+
4994
// See Note [out_idx in ViewMeta]
5095
int64_t out_index;
5196

@@ -57,10 +102,17 @@ struct ViewMeta {
57102
// Tells us if this view operation has any symbolic inputs
58103
bool has_symbolic_inputs;
59104

60-
// Returns a copy of the current ViewMeta, if out_idx matches the current
61-
// out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
105+
// Returns a new ViewMeta with the same forward/reverse
62106
// functions, but a new out index.
63-
ViewMeta to_out_idx(int64_t out_idx);
107+
//
108+
// This method should be implemented by those `ViewMeta` that have more than
109+
// one output.
110+
virtual std::shared_ptr<ViewMeta> to_out_index(int64_t out_index) {
111+
TORCH_CHECK_NOT_IMPLEMENTED(
112+
false,
113+
"ViewMeta::to_out_index not implemented. ",
114+
"Likely because there's only one output.");
115+
}
64116
};
65117

66118
// FunctionalStorageImpl is a subclass of StorageImpl used by the
@@ -93,14 +145,14 @@ struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl {
93145
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
94146
const at::Tensor new_val;
95147
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
96-
const std::vector<ViewMeta> view_metas;
148+
const std::vector<std::shared_ptr<ViewMeta>> view_metas;
97149
};
98150

99151
explicit FunctionalStorageImpl(const Tensor& value);
100152

101153
void add_update(
102154
const Tensor& updated_val,
103-
const std::vector<ViewMeta>& view_metas);
155+
const std::vector<std::shared_ptr<ViewMeta>>& view_metas);
104156
bool apply_updates();
105157
const Tensor& base() {
106158
return base_;

0 commit comments

Comments (0)