@@ -3261,6 +3261,15 @@ class DeviceCachingAllocator {
   }
 };
 
+static bool zeroAllocations() {
+  static auto has_cuda_env =
+      c10::utils::check_env("PYTORCH_CUDA_MEMORY_CACHING_MEMSET_ZEROS") == true;
+  static auto has_rocm_env =
+      c10::utils::check_env("PYTORCH_HIP_MEMORY_CACHING_MEMSET_ZEROS") == true;
+  static bool zeros = has_cuda_env || has_rocm_env;
+  return zeros;
+}
+
 // Returns whether to force all allocations to bypass the caching allocator and
 // go straight to cudaMalloc. This setting is useful when debugging GPU memory
 // errors, since the caching allocator foils cuda-memcheck.
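
Note that zeroAllocations() latches the environment lookup in function-local statics: the variables are read once, on first use, and toggling them later in the process has no effect. Below is a minimal standalone sketch of that pattern, where env_flag is a hypothetical stand-in for c10::utils::check_env (assumed here to treat "1" as true; the real helper's parsing may accept other spellings):

    #include <cstdio>
    #include <cstdlib>
    #include <cstring>

    // Simplified stand-in for c10::utils::check_env; the real helper's
    // handling of truthy values may differ.
    static bool env_flag(const char* name) {
      const char* v = std::getenv(name);
      return v != nullptr && std::strcmp(v, "1") == 0;
    }

    static bool zeroAllocations() {
      // Function-local statics are initialized exactly once (thread-safely
      // since C++11), so the environment is consulted only on the first call
      // and the result is latched for the lifetime of the process.
      static bool zeros = env_flag("PYTORCH_CUDA_MEMORY_CACHING_MEMSET_ZEROS") ||
          env_flag("PYTORCH_HIP_MEMORY_CACHING_MEMSET_ZEROS");
      return zeros;
    }

    int main() {
      std::printf("zeroAllocations() = %d\n", zeroAllocations());
    }
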
@@ -3652,6 +3661,10 @@ class NativeCachingAllocator : public CUDAAllocator {
       TORCH_SDT_WITH_SEMAPHORE(malloc, devPtr, device, size, stream.id());
     }
 
+    if (zeroAllocations()) {
+      C10_CUDA_CHECK(cudaMemsetAsync(devPtr, 0, size, stream));
+    }
+
     return {devPtr, devPtr, deleteFunc, Device(DeviceType::CUDA, device)};
   }
   DeleterFnPtr raw_deleter() const override {
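
In the allocate() path above, the memset is enqueued on the same stream the allocation is associated with, so no explicit synchronization is needed: work submitted to a single CUDA stream executes in issue order, and any kernel launched on that stream afterwards observes the zeroed bytes. A self-contained illustration of that ordering guarantee (illustrative only, not part of this patch):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void read_first(const int* buf) {
      // Launched after the memset on the same stream, so buf[0] is 0 here.
      printf("buf[0] = %d\n", buf[0]);
    }

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);
      int* buf = nullptr;
      cudaMalloc(&buf, 1024 * sizeof(int));
      // Same-stream ordering: the async memset completes before the kernel
      // that follows it on this stream begins.
      cudaMemsetAsync(buf, 0, 1024 * sizeof(int), stream);
      read_first<<<1, 1, 0, stream>>>(buf);
      cudaStreamSynchronize(stream);
      cudaFree(buf);
      cudaStreamDestroy(stream);
      return 0;
    }
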
@@ -3734,6 +3747,12 @@ class NativeCachingAllocator : public CUDAAllocator {
       C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
       malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
     }
+    if (zeroAllocations()) {
+      c10::DeviceIndex device = 0;
+      C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
+      C10_CUDA_CHECK(
+          cudaMemsetAsync(r, 0, nbytes, cuda::getCurrentCUDAStream(device)));
+    }
     return r;
   }
 
@@ -3749,6 +3768,9 @@ class NativeCachingAllocator : public CUDAAllocator {
       C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
       malloc(&r, device, nbytes, stream);
     }
+    if (zeroAllocations()) {
+      C10_CUDA_CHECK(cudaMemsetAsync(r, 0, nbytes, stream));
+    }
     return r;
   }
 
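
Since the flag is latched on first use, it has to be present in the environment before the first CUDA allocation. A hypothetical usage sketch for a libtorch program, assuming this patch is applied and a POSIX setenv is available:

    #include <cstdlib>
    #include <torch/torch.h>

    int main() {
      // Must run before the first CUDA allocation; after that the value is
      // cached by zeroAllocations() and further changes are ignored.
      setenv("PYTORCH_CUDA_MEMORY_CACHING_MEMSET_ZEROS", "1", /*overwrite=*/1);
      // With the flag active, the caching allocator memsets this freshly
      // allocated buffer to zero on the allocation's stream.
      auto t = torch::empty({4}, torch::kCUDA);
    }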