@@ -123,11 +123,78 @@ AOTITorchError aoti_torch_empty_strided(
123123 return Error::Ok;
124124}
125125
126- // TODO(gasoonjia): reuse aoti_torch_delete_tensor_object to destory tensors
127126void clear_all_tensors () {
127+ // Use aoti_torch_delete_tensor_object to properly delete each tensor
128+ // Note: We need to collect tensor pointers first since deletion modifies the
129+ // set
130+ auto old_tensors =
131+ std::move (tensors); // tensors is now empty and no need to copy
132+ for (const auto & tensor_shared : old_tensors) {
133+ aoti_torch_delete_tensor_object (tensor_shared.get ());
134+ }
135+
136+ // tensors set should now be empty, but ensure it's cleared
128137 tensors.clear ();
129138}
130139
140+ AOTITorchError aoti_torch_delete_tensor_object (Tensor* tensor) {
141+ // Handle null tensor pointer
142+ if (tensor == nullptr ) {
143+ ET_LOG (Error, " Cannot delete null tensor" );
144+ return Error::InvalidArgument;
145+ }
146+
147+ // Check if tensor exists in our tracking
148+ bool found_in_tensors = false ;
149+ for (auto it = tensors.begin (); it != tensors.end (); ++it) {
150+ if (it->get () == tensor) {
151+ found_in_tensors = true ;
152+ break ;
153+ }
154+ }
155+
156+ // If tensor not found in our tracking, it's invalid
157+ if (!found_in_tensors) {
158+ ET_LOG (Error, " Didn't find tensor %p" , tensor);
159+ return Error::InvalidArgument;
160+ }
161+
162+ // Find and delete the tensor
163+ for (auto it = tensors.begin (); it != tensors.end (); ++it) {
164+ if (it->get () == tensor) {
165+ // Get the tensor before erasing
166+ auto tensor_ptr = *it;
167+
168+ void * data_ptr = tensor_ptr->mutable_data_ptr ();
169+
170+ // Determine if it's GPU memory
171+ cudaPointerAttributes attributes{};
172+ ET_CUDA_CHECK_OR_RETURN_ERROR (
173+ cudaPointerGetAttributes (&attributes, data_ptr));
174+
175+ // et tensor does not own data; need to free them manually.
176+ if (attributes.type == cudaMemoryTypeManaged) {
177+ // This is CUDA managed memory - free with proper synchronization
178+ ET_CUDA_CHECK_OR_RETURN_ERROR (
179+ cudaDeviceSynchronize ()); // Wait for all operations to complete
180+ // BEFORE freeing
181+ ET_CUDA_CHECK_OR_RETURN_ERROR (cudaFree (data_ptr));
182+ } else {
183+ // This is CPU memory - free immediately
184+ free (data_ptr);
185+ }
186+ // Remove from set (this will call the destructor if it's the last
187+ // reference)
188+ tensors.erase (it);
189+ return Error::Ok;
190+ }
191+ }
192+
193+ // This should never be reached since we found it above
194+ ET_LOG (Error, " Internal error: tensor not found after validation" );
195+ return Error::Internal;
196+ }
197+
131198} // extern "C"
132199
133200} // namespace cuda
0 commit comments