Skip to content

Commit 753d9bd

Browse files
guangyey authored
pytorchmergebot committed
Introduce a new API torch.xpu.set_per_process_memory_fraction (pytorch#165510)
# Motivation

Aligned with other backends, this PR introduces a new API `torch.xpu.set_per_process_memory_fraction` to allow users to customize the allowed memory for a single process.

Pull Request resolved: pytorch#165510
Approved by: https://github.com/EikanWang, https://github.com/ezyang
ghstack dependencies: pytorch#165508, pytorch#165509
1 parent dd1fe7c commit 753d9bd

File tree

8 files changed

+101
-2
lines changed

8 files changed

+101
-2
lines changed

c10/xpu/XPUCachingAllocator.cpp

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ class DeviceCachingAllocator {
123123
ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>>
124124
xpu_events;
125125
DeviceIndex device_index;
126+
size_t allowed_memory_maximum = 0;
127+
bool set_fraction = false;
126128

127129
size_t try_merge_blocks(Block* dst, Block* src, BlockPool& pool) {
128130
if (!src || src->allocated || src->event_count > 0 ||
@@ -245,6 +247,12 @@ class DeviceCachingAllocator {
245247
if (isRetry) {
246248
stats.num_alloc_retries += 1;
247249
}
250+
if (set_fraction &&
251+
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)].current +
252+
size >
253+
allowed_memory_maximum) {
254+
return false;
255+
}
248256
void* ptr = sycl::aligned_alloc_device(
249257
kDeviceAlignment,
250258
size,
@@ -435,6 +443,11 @@ class DeviceCachingAllocator {
435443
device_free =
436444
raw_device.get_info<sycl::ext::intel::info::device::free_memory>();
437445
}
446+
std::string allowed_info;
447+
if (set_fraction) {
448+
allowed_info = format_size(allowed_memory_maximum) + " allowed; ";
449+
}
450+
438451
auto allocated_bytes =
439452
stats.allocated_bytes[static_cast<size_t>(StatType::AGGREGATE)]
440453
.current;
@@ -459,7 +472,9 @@ class DeviceCachingAllocator {
459472
format_size(device_total),
460473
" of which ",
461474
format_size(device_free),
462-
" is free. Of the allocated memory ",
475+
" is free. ",
476+
allowed_info,
477+
"Of the allocated memory ",
463478
format_size(allocated_bytes),
464479
" is allocated by PyTorch, and ",
465480
format_size(reserved_bytes - allocated_bytes),
@@ -538,6 +553,14 @@ class DeviceCachingAllocator {
538553
stats.requested_bytes[statType].reset_peak();
539554
}
540555
}
556+
557+
void setMemoryFraction(double fraction) {
558+
c10::xpu::DeviceProp device_prop;
559+
c10::xpu::get_device_properties(&device_prop, device_index);
560+
auto device_total = device_prop.global_mem_size;
561+
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
562+
set_fraction = true;
563+
}
541564
};
542565

543566
static void local_raw_delete(void* ptr);
@@ -700,6 +723,16 @@ class XPUAllocator : public DeviceAllocator {
700723
assertValidDevice(device);
701724
device_allocators[device]->resetAccumulatedStats();
702725
}
726+
727+
void setMemoryFraction(double fraction, DeviceIndex device) {
728+
assertValidDevice(device);
729+
TORCH_CHECK_VALUE(
730+
0 < fraction && fraction <= 1,
731+
"invalid fraction:",
732+
fraction,
733+
". Please set within (0, 1].");
734+
device_allocators[device]->setMemoryFraction(fraction);
735+
}
703736
};
704737

705738
static XPUAllocator allocator;
@@ -744,6 +777,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
744777
return allocator.recordStream(dataPtr, stream);
745778
}
746779

780+
void setMemoryFraction(double fraction, DeviceIndex device) {
781+
return allocator.setMemoryFraction(fraction, device);
782+
}
783+
747784
REGISTER_ALLOCATOR(kXPU, &allocator)
748785

749786
} // namespace c10::xpu::XPUCachingAllocator

c10/xpu/XPUCachingAllocator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,6 @@ C10_XPU_API void raw_delete(void* ptr);
2525

2626
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
2727

28+
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
29+
2830
} // namespace c10::xpu::XPUCachingAllocator

docs/source/xpu.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
memory_stats_as_nested_dict
8686
reset_accumulated_memory_stats
8787
reset_peak_memory_stats
88+
set_per_process_memory_fraction
8889
```
8990

9091
```{eval-rst}

test/test_xpu.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
IS_LINUX,
2626
IS_WINDOWS,
2727
run_tests,
28+
serialTest,
2829
suppress_warnings,
2930
TEST_XPU,
3031
TestCase,
@@ -482,6 +483,32 @@ def test_raises_oom(self):
482483
with self.assertRaises(torch.OutOfMemoryError):
483484
torch.empty(1024 * 1024 * 1024 * 1024, device="xpu")
484485

486+
@serialTest()
487+
def test_set_per_process_memory_fraction(self):
488+
gc.collect()
489+
torch.xpu.empty_cache()
490+
total_memory = torch.xpu.get_device_properties().total_memory
491+
fraction = 0.5
492+
with self.assertRaisesRegex(ValueError, "invalid fraction:"):
493+
torch.xpu.set_per_process_memory_fraction(-0.1)
494+
with self.assertRaisesRegex(ValueError, "invalid fraction:"):
495+
torch.xpu.set_per_process_memory_fraction(1.1)
496+
497+
torch.xpu.set_per_process_memory_fraction(fraction)
498+
allowed_memory = int(total_memory * 0.49)
499+
reserved_memory = torch.xpu.memory_reserved()
500+
application_memory = allowed_memory - reserved_memory
501+
tensor = torch.empty(application_memory, dtype=torch.int8, device="xpu")
502+
del tensor
503+
gc.collect()
504+
torch.xpu.empty_cache()
505+
506+
application_memory = int(total_memory * 0.51)
507+
with self.assertRaises(torch.OutOfMemoryError):
508+
_ = torch.empty(application_memory, dtype=torch.int8, device="xpu")
509+
510+
torch.xpu.set_per_process_memory_fraction(1.0)
511+
485512
def test_memory_allocation(self):
486513
torch.xpu.empty_cache()
487514
prev_allocated = torch.xpu.memory_allocated()

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2391,6 +2391,7 @@ def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ...
23912391
def _xpu_resetPeakMemoryStats(device: _int) -> None: ...
23922392
def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ...
23932393
def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ...
2394+
def _xpu_setMemoryFraction(fraction: _float, device: _int) -> None: ...
23942395

23952396
class _XpuDeviceProperties:
23962397
name: str

torch/csrc/xpu/Module.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,9 @@ static void initXpuMethodBindings(PyObject* module) {
420420
[](c10::DeviceIndex device, c10::DeviceIndex peer) {
421421
return at::xpu::canDeviceAccessPeer(device, peer);
422422
});
423+
m.def("_xpu_setMemoryFraction", [](double fraction, c10::DeviceIndex device) {
424+
c10::xpu::XPUCachingAllocator::setMemoryFraction(fraction, device);
425+
});
423426
}
424427

425428
// Callback for python part. Used for additional initialization of python

torch/xpu/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
530530
memory_stats_as_nested_dict,
531531
reset_accumulated_memory_stats,
532532
reset_peak_memory_stats,
533+
set_per_process_memory_fraction,
533534
)
534535
from .random import (
535536
get_rng_state,
@@ -584,6 +585,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
584585
"seed",
585586
"seed_all",
586587
"set_device",
588+
"set_per_process_memory_fraction",
587589
"set_rng_state",
588590
"set_rng_state_all",
589591
"set_stream",

torch/xpu/memory.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
from torch.types import Device
66

7-
from . import _get_device_index, is_initialized
7+
from . import _get_device_index, _lazy_init, is_initialized
88

99

1010
_device_t = Union[Device, str, int, None]
@@ -194,6 +194,31 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
194194
return torch._C._xpu_getMemoryInfo(device)
195195

196196

197+
def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None:
198+
r"""
199+
Set the memory fraction for a single process on XPU device.
200+
This function limits the amount of memory that the caching allocator can allocate
201+
on the specified XPU device. The allowed memory is computed as:
202+
203+
.. math:: \text{allowed\_memory} = \text{total\_memory} \times \text{fraction}
204+
205+
If the process attempts to allocate more than this allowed memory,
206+
an out-of-memory error will be raised by the allocator.
207+
208+
Arguments:
209+
fraction(float): Range: 0~1. Allowed memory equals total_memory * fraction.
210+
device (torch.device or int or str, optional): selected device. It uses the current device,
211+
given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default).
212+
213+
.. note:: In general, the total available free memory is less than the total capacity.
214+
"""
215+
_lazy_init()
216+
device = _get_device_index(device, optional=True)
217+
if not isinstance(fraction, float):
218+
raise TypeError("Invalid type for fraction argument, must be `float`")
219+
torch._C._xpu_setMemoryFraction(fraction, device)
220+
221+
197222
__all__ = [
198223
"empty_cache",
199224
"max_memory_allocated",
@@ -205,4 +230,5 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]:
205230
"memory_stats_as_nested_dict",
206231
"reset_accumulated_memory_stats",
207232
"reset_peak_memory_stats",
233+
"set_per_process_memory_fraction",
208234
]

0 commit comments

Comments
 (0)