
Commit c7be080

guangyey and gujinghui authored
bugfix torch.xpu.empty_cache (#3550) (#3552)
Co-authored-by: Jinghui <[email protected]> (cherry picked from commit 51465f8)
1 parent 44cb7ce commit c7be080

File tree

csrc/gpu/runtime/CachingDeviceAllocator.cpp
intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/Functions.py
intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/BaseAttention.py
intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Decoderblock.py
tests/gpu/regression/test_fill.py

5 files changed: +28 -4 lines changed

csrc/gpu/runtime/CachingDeviceAllocator.cpp

Lines changed: 28 additions & 0 deletions
@@ -380,6 +380,20 @@ void CachingDeviceAllocator::recordQueue(void* buffer, sycl::queue* queue) {
 void CachingDeviceAllocator::emptyCache() {
   std::lock_guard<std::recursive_mutex> lock(mutex);
   synchronize_and_free_events(std::nullopt);
+
+  /*
+   * See Note [Safe to Free Blocks on BlockPool]
+   *
+   * torch.xpu.empty_cache will release all unoccupied cached memory currently
+   * held on all the GPUs. So we have to do a device-level synchronization on
+   * all GPUs.
+   */
+  int count = 0;
+  AT_DPCPP_CHECK(dpcppGetDeviceCount(&count));
+  for (auto i = 0; i < count; i++) {
+    xpu::dpcpp::deviceSynchronize(i);
+  }
+
   free_blocks(large_blocks, large_blocks.begin(), large_blocks.end());
   free_blocks(small_blocks, small_blocks.begin(), small_blocks.end());
 }

@@ -545,6 +559,15 @@ size_t CachingDeviceAllocator::try_merge_blocks(
   return subsumed_size;
 }
 
+/**
+ * Note [Safe to Free Blocks on BlockPool]
+ *
+ * Callers must ensure that all accesses to the block, whose raw pointer is
+ * allocated by SYCL APIs, have been completed before invoking sycl::free.
+ *
+ * We have to do a device-level synchronization before freeing these blocks to
+ * guarantee that all kernels that access the blocks have finished.
+ */
 void CachingDeviceAllocator::free_blocks(
     BlockPool& blocks,
     BlockPool::iterator it,

@@ -579,6 +602,11 @@ void CachingDeviceAllocator::free_cached_blocks(DeviceId di) {
   Block lower_bound(di, nullptr, 0);
   Block upper_bound(di + 1, nullptr, 0);
 
+  /*
+   * See Note [Safe to Free Blocks on BlockPool]
+   */
+  xpu::dpcpp::deviceSynchronize(di);
+
   free_blocks(
       large_blocks,
       large_blocks.lower_bound(&lower_bound),
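The Python-side changes below follow directly from this allocator change: because emptyCache() now performs the device-level synchronization itself, callers no longer need an explicit torch.xpu.synchronize() before torch.xpu.empty_cache(). A minimal usage sketch of the simplified pattern, assuming Intel Extension for PyTorch is installed and an XPU device is available (the tensor shapes here are illustrative only and are not part of this commit):

import torch
import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu backend)

# Illustrative scratch workload on the XPU device.
x = torch.empty(1024, 1024, device="xpu")
y = x * 2.0
del x, y

# After this commit, emptyCache() synchronizes every device internally,
# so no explicit torch.xpu.synchronize() is required before this call.
torch.xpu.empty_cache()

# Reserved memory in GiB, mirroring the readout used in Functions.py below.
print(round(torch.xpu.memory_reserved() / 1024**3, 3))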

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/Functions.py

Lines changed: 0 additions & 1 deletion
@@ -655,7 +655,6 @@ def _ipex_beam_search(
     # IPEXTransformerAtten.release_all_static_cached_resources()
     reserved_mem = round(torch.xpu.memory_reserved() / 1024**3, 3)
     if reserved_mem > 50:
-        torch.xpu.synchronize()
         torch.xpu.empty_cache()
     if hasattr(self, "token_latency") and self.token_latency:
         return out, latency_list

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/BaseAttention.py

Lines changed: 0 additions & 1 deletion
@@ -114,7 +114,6 @@ def end_of_attention(self):
         not self.is_beam_search()
         and IPEXTransformerAttn.timestamp % self.runtime_cache_size == 0
     ):
-        torch.xpu.synchronize()
         torch.xpu.empty_cache()
     IPEXTransformerAttn.timestamp += 1

intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/transformer_modules/Decoderblock.py

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ def port_all_parameters_to_new_module(self):
     self.port_mlp_parameter()
     self.port_norm_parameter()
     self.port_module_specific_parameter()
-    torch.xpu.synchronize()
     torch.xpu.empty_cache()
     # for debug
     # self.print_all_paramter_with_name

tests/gpu/regression/test_fill.py

Lines changed: 0 additions & 1 deletion
@@ -9,7 +9,6 @@ def test_fill(self):
     Regression desc:
     fill_ may set values to part of large-size tensor.
     """
-    torch.xpu.synchronize()
     torch.xpu.empty_cache()
 
     output_cpu = torch.zeros([2, 8, 256, 512, 224])
