@@ -1100,7 +1100,7 @@ class RingBuffer {
11001100} // anonymous namespace
11011101} // namespace Native
11021102
1103- static std::string reportProcessMemoryInfo (c10::DeviceIndex device ) {
1103+ static std::string reportProcessMemoryInfo (const cudaDeviceProp& prop ) {
11041104#ifdef PYTORCH_C10_DRIVER_API_SUPPORTED
11051105 void * nvml_handle = DriverAPI::get_nvml_handle ();
11061106 if (!nvml_handle) {
@@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) {
11111111 return true ;
11121112 }();
11131113
1114- cudaDeviceProp prop{};
1115- C10_CUDA_CHECK (cudaGetDeviceProperties (&prop, device));
1116-
11171114 // NOLINTNEXTLINE(*-c-arrays)
11181115 char pci_id[80 ];
11191116 snprintf (
@@ -1215,14 +1212,16 @@ class DeviceCachingAllocator {
12151212 // record used memory.
12161213 size_t total_allocated_memory = 0 ;
12171214
1218- size_t allowed_memory_maximum = 0 ;
1215+ cudaDeviceProp device_prop;
1216+
1217+ // Maximum amount of memory that this device is allowed to
1218+ // allocate. This is set iff the memory fraction is less than 1.
1219+ std::optional<size_t > allowed_memory_maximum{std::nullopt };
12191220
12201221 // all live expandable segments
12211222 std::vector<ExpandableSegment*> expandable_segments_;
12221223 std::vector<c10::DeviceIndex> devices_with_peer_access_;
12231224
1224- bool set_fraction = false ;
1225-
12261225 bool record_history = false ;
12271226
12281227 std::atomic<CreateContextFn> context_recorder_;
@@ -1264,6 +1263,9 @@ class DeviceCachingAllocator {
12641263 : device_id(id),
12651264 large_blocks(/* small=*/ false ),
12661265 small_blocks(/* small=*/ true ) {
1266+ C10_CUDA_CHECK (cudaGetDeviceProperties (&device_prop, id));
1267+
1268+ setMemoryFraction (CUDAAllocatorConfig::per_process_memory_fraction ());
12671269 stats.max_split_size =
12681270 static_cast <int64_t >(AcceleratorAllocatorConfig::max_split_size ());
12691271 context_recorder_.store (nullptr );
@@ -1399,7 +1401,7 @@ class DeviceCachingAllocator {
13991401 if (!block_found) {
14001402 // Do garbage collection if the flag is set.
14011403 if (C10_UNLIKELY (
1402- set_fraction &&
1404+ allowed_memory_maximum. has_value () &&
14031405 AcceleratorAllocatorConfig::garbage_collection_threshold () >
14041406 0.0 )) {
14051407 garbage_collect_cached_blocks (context);
@@ -1456,11 +1458,12 @@ class DeviceCachingAllocator {
14561458 C10_CUDA_CHECK (cudaMemGetInfo (&device_free, &device_total));
14571459 std::string allowed_info;
14581460
1459- if (set_fraction) {
1460- allowed_info = format_size (allowed_memory_maximum) + " allowed; " ;
1461+ if (allowed_memory_maximum.has_value ()) {
1462+ allowed_info =
1463+ format_size (allowed_memory_maximum.value ()) + " allowed; " ;
14611464 }
14621465
1463- std::string proc_info = reportProcessMemoryInfo (device_id );
1466+ std::string proc_info = reportProcessMemoryInfo (device_prop );
14641467
14651468 record_trace (
14661469 TraceEntry::OOM,
@@ -1518,7 +1521,7 @@ class DeviceCachingAllocator {
15181521 for (const auto & obs : observers_local) {
15191522 obs (device_id,
15201523 alloc_size,
1521- set_fraction ? allowed_memory_maximum : device_total,
1524+ allowed_memory_maximum. value_or ( device_total) ,
15221525 device_free);
15231526 }
15241527
@@ -2015,25 +2018,26 @@ class DeviceCachingAllocator {
20152018
20162019 /* * get memory fraction limiting maximum allocated memory **/
20172020 double getMemoryFraction () {
2018- if (!set_fraction ) {
2021+ if (!allowed_memory_maximum. has_value () ) {
20192022 return 1.0 ;
20202023 }
20212024
2022- size_t device_free = 0 ;
2023- size_t device_total = 0 ;
2024- C10_CUDA_CHECK (cudaMemGetInfo (&device_free, &device_total));
2025- return static_cast <double >(allowed_memory_maximum) /
2026- static_cast <double >(device_total);
2025+ return static_cast <double >(allowed_memory_maximum.value ()) /
2026+ static_cast <double >(device_prop.totalGlobalMem );
20272027 }
20282028
20292029 /* * set memory fraction to limit maximum allocated memory **/
20302030 void setMemoryFraction (double fraction) {
2031- size_t device_free = 0 ;
2032- size_t device_total = 0 ;
2033- C10_CUDA_CHECK (cudaMemGetInfo (&device_free, &device_total));
2034- allowed_memory_maximum =
2035- static_cast <size_t >(fraction * static_cast <double >(device_total));
2036- set_fraction = true ;
2031+ TORCH_CHECK (
2032+ 0 <= fraction && fraction <= 1 ,
2033+ " invalid fraction:" ,
2034+ fraction,
2035+ " . Please set within [0, 1]." );
2036+ allowed_memory_maximum = std::nullopt ;
2037+ if (fraction < 1.0 ) {
2038+ allowed_memory_maximum = static_cast <size_t >(
2039+ fraction * static_cast <double >(device_prop.totalGlobalMem ));
2040+ }
20372041 }
20382042
20392043 /* * get expandable segment size for all the streams on device **/
@@ -3010,7 +3014,7 @@ class DeviceCachingAllocator {
30103014 BlockPool& pool = *p.pool ;
30113015
30123016 if (C10_UNLIKELY (
3013- set_fraction &&
3017+ allowed_memory_maximum. has_value () &&
30143018 AcceleratorAllocatorConfig::garbage_collection_threshold () > 0.0 )) {
30153019 // Track block reuse interval only when garbage collection is enabled.
30163020 ++pool.get_free_blocks_call_count ;
@@ -3083,7 +3087,7 @@ class DeviceCachingAllocator {
30833087
30843088 size_t gc_threshold = static_cast <size_t >(
30853089 AcceleratorAllocatorConfig::garbage_collection_threshold () *
3086- static_cast <double >(allowed_memory_maximum));
3090+ static_cast <double >(allowed_memory_maximum. value () ));
30873091 // No need to trigger GC yet
30883092 if (total_allocated_memory <= gc_threshold) {
30893093 return ;
@@ -3161,8 +3165,8 @@ class DeviceCachingAllocator {
31613165
31623166 bool active_pool =
31633167 p.pool ->owner_PrivatePool && p.pool ->owner_PrivatePool ->allocator ();
3164- if (set_fraction &&
3165- total_allocated_memory + size > allowed_memory_maximum) {
3168+ if (allowed_memory_maximum. has_value () &&
3169+ total_allocated_memory + size > allowed_memory_maximum. value () ) {
31663170 p.err = cudaErrorMemoryAllocation;
31673171 return false ;
31683172 // Temporarily disable checkpointing & cudagraphs internally
@@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator {
38593863 " Allocator not initialized for device " ,
38603864 device,
38613865 " : did you call init?" );
3862- C10_CUDA_CHECK (c10::cuda::SetDevice (device));
38633866 return device_allocator[device]->getMemoryFraction ();
38643867 }
38653868
@@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator {
38693872 " Allocator not initialized for device " ,
38703873 device,
38713874 " : did you call init?" );
3872- TORCH_CHECK (
3873- 0 <= fraction && fraction <= 1 ,
3874- " invalid fraction:" ,
3875- fraction,
3876- " . Please set within [0, 1]." );
3877- C10_CUDA_CHECK (c10::cuda::SetDevice (device));
38783875 device_allocator[device]->setMemoryFraction (fraction);
38793876 }
38803877
0 commit comments