
Commit b694973

dolpm authored and pytorchmergebot committed
[nativert] force resize to zero. (pytorch#159683)
Summary: This was quite a miserable bug: there are a few kernels that don't explicitly resize their outputs to zero, which led to some weird UB.

Rollback Plan:

Differential Revision: D79476454

Pull Request resolved: pytorch#159683
Approved by: https://github.com/SherlockNoMad, https://github.com/henryoier
1 parent 482f069 commit b694973
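To make the failure mode concrete, here is a minimal sketch of the bug class the summary describes. It is not code from this patch; `bad_copy_out` is a hypothetical out-variant kernel invented for illustration. If a planned output tensor keeps sizes from a previous iteration while its managed storage has been swapped for a smaller arena slot, writing through its data pointer without resizing first is undefined behavior:

```cpp
#include <ATen/ATen.h>

// Hypothetical out-variant kernel, illustrating the bug class only.
void bad_copy_out(const at::Tensor& src, at::Tensor& out) {
  // A well-behaved kernel would first call out.resize_(src.sizes())
  // (or an equivalent resize-output helper). Skipping that step leaves
  // out's metadata describing whatever shape it had in a previous
  // iteration, independent of the storage it now points at.
  float* dst = out.data_ptr<float>();
  const float* s = src.data_ptr<float>();
  for (int64_t i = 0; i < src.numel(); ++i) {
    dst[i] = s[i]; // UB once i runs past the end of out's managed storage
  }
}
```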

File tree

1 file changed: +7 -2 lines changed


torch/nativert/executor/memory/LayoutManager.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -105,6 +105,7 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
     auto* tensor = planned_tensors_[i];
 
     at::StorageImpl& storage = *tensor->storage().unsafeGetStorageImpl();
+    at::TensorImpl& tensor_impl = *tensor->unsafeGetTensorImpl();
 
     if (C10_UNLIKELY(allocate)) {
       // from: https://fburl.com/code/4it00yph
@@ -120,7 +121,7 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
       //
       // For more information, see the doc comment for
       // intrusive_ptr::unsafe_adapt_non_heap_allocated.
-      tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
+      tensor_impl.set_storage_keep_dtype(at::Storage(
           c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
               &storage_impl_buffer_.to_managed(storage), 1)));
     } else if (
@@ -130,12 +131,16 @@ void LayoutManager::ensure_managed_storages(bool allocate) {
             &storage_buf
                 [i]) /* managed storage was replaced for some reason */) {
       storage.reset();
-      tensor->unsafeGetTensorImpl()->set_storage_keep_dtype(at::Storage(
+      tensor_impl.set_storage_keep_dtype(at::Storage(
           c10::intrusive_ptr<at::StorageImpl>::unsafe_adapt_non_heap_allocated(
               // NOLINTNEXTLINE(bugprone-pointer-arithmetic-on-polymorphic-object)
               &storage_buf[i],
               1)));
     }
+
+    // resize to zero so that we ensure that we don't access out-of-bounds
+    // addr's in the next iteration
+    tensor_impl.set_sizes_contiguous({0});
   }
 }
```
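For context, the effect of the added `set_sizes_contiguous({0})` call can be sketched in isolation. This is a minimal illustration mirroring the patch, not code from it; `reset_planned_sizes` is a hypothetical helper, and `t` is assumed to play the role of one of the `planned_tensors_`:

```cpp
#include <ATen/ATen.h>
#include <cassert>

// Hypothetical helper showing what the patch's added line accomplishes.
void reset_planned_sizes(at::Tensor& t) {
  // Same call the patch adds: collapse the (possibly stale) sizes to {0}.
  t.unsafeGetTensorImpl()->set_sizes_contiguous({0});
  assert(t.numel() == 0); // no element address can be formed past the storage
  // Assumption: whichever kernel produces t next resizes it to the real
  // output shape before writing, re-establishing consistent sizes/strides.
}
```

With the sizes forced to zero after the managed storage is (re)attached, a kernel that forgets to resize its output can no longer inherit stale sizes from the previous iteration, so it cannot form an out-of-bounds element address.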
