Skip to content

Commit 4a0bfea

Browse files
Copilotwxyucs
andcommitted
fix(datacell): fix memory leak in FlattenDataCell query on exception (#1680)
* Initial plan * Fix memory leak in FlattenDataCell query and ComputePairVectors on exception Co-authored-by: wxyucs <12595343+wxyucs@users.noreply.github.com> * fix(datacell): fix memory leak in FlattenDataCell query on exception Co-authored-by: wxyucs <12595343+wxyucs@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: wxyucs <12595343+wxyucs@users.noreply.github.com>
1 parent 54c6f9f commit 4a0bfea

File tree

1 file changed

+67
-42
lines changed

1 file changed

+67
-42
lines changed

src/data_cell/flatten_datacell.h

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -290,41 +290,57 @@ FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
290290
this->prefetch_depth_code_ * 64);
291291
}
292292
}
293-
bool release1 = false;
294-
const auto* codes1 = this->GetCodesById(idx[i], release1);
295-
bool release2 = false;
296-
const auto* codes2 = this->GetCodesById(idx[i + 1], release2);
297-
bool release3 = false;
298-
const auto* codes3 = this->GetCodesById(idx[i + 2], release3);
299-
bool release4 = false;
300-
const auto* codes4 = this->GetCodesById(idx[i + 3], release4);
301-
computer->ComputeDistsBatch4(codes1,
302-
codes2,
303-
codes3,
304-
codes4,
305-
result_dists[i],
306-
result_dists[i + 1],
307-
result_dists[i + 2],
308-
result_dists[i + 3]);
309-
310-
if (release1) {
311-
this->io_->Release(codes1);
312-
}
313-
if (release2) {
314-
this->io_->Release(codes2);
315-
}
316-
if (release3) {
317-
this->io_->Release(codes3);
318-
}
319-
if (release4) {
320-
this->io_->Release(codes4);
293+
bool release1 = false, release2 = false, release3 = false, release4 = false;
294+
const uint8_t* codes1 = nullptr;
295+
const uint8_t* codes2 = nullptr;
296+
const uint8_t* codes3 = nullptr;
297+
const uint8_t* codes4 = nullptr;
298+
auto release_batch = [&]() {
299+
if (release1 && codes1) {
300+
this->io_->Release(codes1);
301+
}
302+
if (release2 && codes2) {
303+
this->io_->Release(codes2);
304+
}
305+
if (release3 && codes3) {
306+
this->io_->Release(codes3);
307+
}
308+
if (release4 && codes4) {
309+
this->io_->Release(codes4);
310+
}
311+
};
312+
try {
313+
codes1 = this->GetCodesById(idx[i], release1);
314+
codes2 = this->GetCodesById(idx[i + 1], release2);
315+
codes3 = this->GetCodesById(idx[i + 2], release3);
316+
codes4 = this->GetCodesById(idx[i + 3], release4);
317+
computer->ComputeDistsBatch4(codes1,
318+
codes2,
319+
codes3,
320+
codes4,
321+
result_dists[i],
322+
result_dists[i + 1],
323+
result_dists[i + 2],
324+
result_dists[i + 3]);
325+
} catch (...) {
326+
release_batch();
327+
throw;
321328
}
329+
release_batch();
322330
}
323331
for (; i < id_count; ++i) {
324332
bool release = false;
325-
const auto* codes = this->GetCodesById(idx[i], release);
326-
computer->ComputeDist(codes, result_dists + i);
327-
if (release) {
333+
const uint8_t* codes = nullptr;
334+
try {
335+
codes = this->GetCodesById(idx[i], release);
336+
computer->ComputeDist(codes, result_dists + i);
337+
} catch (...) {
338+
if (release && codes) {
339+
this->io_->Release(codes);
340+
}
341+
throw;
342+
}
343+
if (release && codes) {
328344
this->io_->Release(codes);
329345
}
330346
}
@@ -333,18 +349,27 @@ FlattenDataCell<QuantTmpl, IOTmpl>::query(float* result_dists,
333349
template <typename QuantTmpl, typename IOTmpl>
334350
float
335351
FlattenDataCell<QuantTmpl, IOTmpl>::ComputePairVectors(InnerIdType id1, InnerIdType id2) {
336-
bool release1, release2;
337-
const auto* codes1 = this->GetCodesById(id1, release1);
338-
const auto* codes2 = this->GetCodesById(id2, release2);
339-
auto result = this->quantizer_->Compute(codes1, codes2);
340-
if (release1) {
341-
this->io_->Release(codes1);
342-
}
343-
if (release2) {
344-
this->io_->Release(codes2);
352+
bool release1 = false, release2 = false;
353+
const uint8_t* codes1 = nullptr;
354+
const uint8_t* codes2 = nullptr;
355+
auto release_pair = [&]() {
356+
if (release1 && codes1) {
357+
this->io_->Release(codes1);
358+
}
359+
if (release2 && codes2) {
360+
this->io_->Release(codes2);
361+
}
362+
};
363+
try {
364+
codes1 = this->GetCodesById(id1, release1);
365+
codes2 = this->GetCodesById(id2, release2);
366+
auto result = this->quantizer_->Compute(codes1, codes2);
367+
release_pair();
368+
return result;
369+
} catch (...) {
370+
release_pair();
371+
throw;
345372
}
346-
347-
return result;
348373
}
349374

350375
template <typename QuantTmpl, typename IOTmpl>

0 commit comments

Comments
 (0)