@@ -123,6 +123,8 @@ class DeviceCachingAllocator {
123123 ska::flat_hash_map<xpu::XPUStream, std::deque<std::pair<sycl::event, Block*>>>
124124 xpu_events;
125125 DeviceIndex device_index;
126+ size_t allowed_memory_maximum = 0 ;
127+ bool set_fraction = false ;
126128
127129 size_t try_merge_blocks (Block* dst, Block* src, BlockPool& pool) {
128130 if (!src || src->allocated || src->event_count > 0 ||
@@ -245,6 +247,12 @@ class DeviceCachingAllocator {
245247 if (isRetry) {
246248 stats.num_alloc_retries += 1 ;
247249 }
250+ if (set_fraction &&
251+ stats.reserved_bytes [static_cast <size_t >(StatType::AGGREGATE)].current +
252+ size >
253+ allowed_memory_maximum) {
254+ return false ;
255+ }
248256 void * ptr = sycl::aligned_alloc_device (
249257 kDeviceAlignment ,
250258 size,
@@ -435,6 +443,11 @@ class DeviceCachingAllocator {
435443 device_free =
436444 raw_device.get_info <sycl::ext::intel::info::device::free_memory>();
437445 }
446+ std::string allowed_info;
447+ if (set_fraction) {
448+ allowed_info = format_size (allowed_memory_maximum) + " allowed; " ;
449+ }
450+
438451 auto allocated_bytes =
439452 stats.allocated_bytes [static_cast <size_t >(StatType::AGGREGATE)]
440453 .current ;
@@ -459,7 +472,9 @@ class DeviceCachingAllocator {
459472 format_size (device_total),
460473 " of which " ,
461474 format_size (device_free),
462- " is free. Of the allocated memory " ,
475+ " is free. " ,
476+ allowed_info,
477+ " Of the allocated memory " ,
463478 format_size (allocated_bytes),
464479 " is allocated by PyTorch, and " ,
465480 format_size (reserved_bytes - allocated_bytes),
@@ -538,6 +553,14 @@ class DeviceCachingAllocator {
538553 stats.requested_bytes [statType].reset_peak ();
539554 }
540555 }
556+
557+ void setMemoryFraction (double fraction) {
558+ c10::xpu::DeviceProp device_prop;
559+ c10::xpu::get_device_properties (&device_prop, device_index);
560+ auto device_total = device_prop.global_mem_size ;
561+ allowed_memory_maximum = static_cast <size_t >(fraction * device_total);
562+ set_fraction = true ;
563+ }
541564};
542565
543566static void local_raw_delete (void * ptr);
@@ -700,6 +723,16 @@ class XPUAllocator : public DeviceAllocator {
700723 assertValidDevice (device);
701724 device_allocators[device]->resetAccumulatedStats ();
702725 }
726+
727+ void setMemoryFraction (double fraction, DeviceIndex device) {
728+ assertValidDevice (device);
729+ TORCH_CHECK_VALUE (
730+ 0 < fraction && fraction <= 1 ,
731+ " invalid fraction:" ,
732+ fraction,
733+ " . Please set within (0, 1]." );
734+ device_allocators[device]->setMemoryFraction (fraction);
735+ }
703736};
704737
705738static XPUAllocator allocator;
@@ -744,6 +777,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
744777 return allocator.recordStream (dataPtr, stream);
745778}
746779
780+ void setMemoryFraction (double fraction, DeviceIndex device) {
781+ return allocator.setMemoryFraction (fraction, device);
782+ }
783+
747784REGISTER_ALLOCATOR (kXPU , &allocator)
748785
749786} // namespace c10::xpu::XPUCachingAllocator
0 commit comments