@@ -123,11 +123,78 @@ AOTITorchError aoti_torch_empty_strided(
123
123
return Error::Ok;
124
124
}
125
125
126
- // TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destroy tensors
127
126
void clear_all_tensors () {
127
+ // Use aoti_torch_delete_tensor_object to properly delete each tensor
128
+ // Note: We need to collect tensor pointers first since deletion modifies the
129
+ // set
130
+ auto old_tensors =
131
+ std::move (tensors); // tensors is now empty and no need to copy
132
+ for (const auto & tensor_shared : old_tensors) {
133
+ aoti_torch_delete_tensor_object (tensor_shared.get ());
134
+ }
135
+
136
+ // tensors set should now be empty, but ensure it's cleared
128
137
tensors.clear ();
129
138
}
130
139
140
+ AOTITorchError aoti_torch_delete_tensor_object (Tensor* tensor) {
141
+ // Handle null tensor pointer
142
+ if (tensor == nullptr ) {
143
+ ET_LOG (Error, " Cannot delete null tensor" );
144
+ return Error::InvalidArgument;
145
+ }
146
+
147
+ // Check if tensor exists in our tracking
148
+ bool found_in_tensors = false ;
149
+ for (auto it = tensors.begin (); it != tensors.end (); ++it) {
150
+ if (it->get () == tensor) {
151
+ found_in_tensors = true ;
152
+ break ;
153
+ }
154
+ }
155
+
156
+ // If tensor not found in our tracking, it's invalid
157
+ if (!found_in_tensors) {
158
+ ET_LOG (Error, " Didn't find tensor %p" , tensor);
159
+ return Error::InvalidArgument;
160
+ }
161
+
162
+ // Find and delete the tensor
163
+ for (auto it = tensors.begin (); it != tensors.end (); ++it) {
164
+ if (it->get () == tensor) {
165
+ // Get the tensor before erasing
166
+ auto tensor_ptr = *it;
167
+
168
+ void * data_ptr = tensor_ptr->mutable_data_ptr ();
169
+
170
+ // Determine if it's GPU memory
171
+ cudaPointerAttributes attributes{};
172
+ ET_CUDA_CHECK_OR_RETURN_ERROR (
173
+ cudaPointerGetAttributes (&attributes, data_ptr));
174
+
175
+ // et tensor does not own data; need to free them manually.
176
+ if (attributes.type == cudaMemoryTypeManaged) {
177
+ // This is CUDA managed memory - free with proper synchronization
178
+ ET_CUDA_CHECK_OR_RETURN_ERROR (
179
+ cudaDeviceSynchronize ()); // Wait for all operations to complete
180
+ // BEFORE freeing
181
+ ET_CUDA_CHECK_OR_RETURN_ERROR (cudaFree (data_ptr));
182
+ } else {
183
+ // This is CPU memory - free immediately
184
+ free (data_ptr);
185
+ }
186
+ // Remove from set (this will call the destructor if it's the last
187
+ // reference)
188
+ tensors.erase (it);
189
+ return Error::Ok;
190
+ }
191
+ }
192
+
193
+ // This should never be reached since we found it above
194
+ ET_LOG (Error, " Internal error: tensor not found after validation" );
195
+ return Error::Internal;
196
+ }
197
+
131
198
} // extern "C"
132
199
133
200
} // namespace cuda
0 commit comments