
Commit 32dd803

dreamgonfly authored and williamFalcon committed
Fix min_max gpu memory logging bug (#453)
* #452 Fix ValueError
* #452 Use subprocess.run
* #452 Simplify code for gpu_memory_map
* #452 Simplify code for min max memory
* #452 Add test for get_memory_profile
* #452 Use os.sep
* #452 Use os.linesep
1 parent 5a9afb1 commit 32dd803

File tree

2 files changed (+16, -23 lines)

pytorch_lightning/root_module/memory.py

Lines changed: 15 additions & 22 deletions
@@ -3,6 +3,7 @@
 '''

 import gc
+import os
 import subprocess

 import numpy as np
@@ -199,19 +200,10 @@ def get_memory_profile(mode):
     memory_map = get_gpu_memory_map()

     if mode == 'min_max':
-        min_mem = 1000000
-        min_k = None
-        max_mem = 0
-        max_k = None
-        for k, v in memory_map:
-            if v > max_mem:
-                max_mem = v
-                max_k = k
-            if v < min_mem:
-                min_mem = v
-                min_k = k
-
-        memory_map = {min_k: min_mem, max_k: max_mem}
+        min_index, min_memory = min(memory_map.items(), key=lambda item: item[1])
+        max_index, max_memory = max(memory_map.items(), key=lambda item: item[1])
+
+        memory_map = {min_index: min_memory, max_index: max_memory}

     return memory_map
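For context, a minimal standalone sketch of the change above, using made-up memory values rather than real nvidia-smi output. Iterating a dict directly yields only its keys, so the old "for k, v in memory_map:" tried to unpack a key string such as 'gpu_0' into two names and raised a ValueError; the rewrite instead picks the (key, value) pairs with the smallest and largest value.

# Made-up memory map in the same shape get_gpu_memory_map() returns.
memory_map = {'gpu_0': 1123, 'gpu_1': 4056, 'gpu_2': 980}

# The old loop iterated the dict itself; that yields keys only, so unpacking
# a key such as 'gpu_0' into (k, v) raises "ValueError: too many values to
# unpack". The new code orders the (key, value) pairs by the value instead.
min_index, min_memory = min(memory_map.items(), key=lambda item: item[1])
max_index, max_memory = max(memory_map.items(), key=lambda item: item[1])

print({min_index: min_memory, max_index: max_memory})
# {'gpu_2': 980, 'gpu_1': 4056}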

@@ -225,17 +217,18 @@ def get_gpu_memory_map():
     Keys are device ids as integers.
     Values are memory usage as integers in MB.
     """
-    result = subprocess.check_output(
+    result = subprocess.run(
         [
-            'nvidia-smi', '--query-gpu=memory.used',
-            '--format=csv,nounits,noheader'
-        ], encoding='utf-8')
+            'nvidia-smi',
+            '--query-gpu=memory.used',
+            '--format=csv,nounits,noheader',
+        ],
+        encoding='utf-8',
+        capture_output=True,
+        check=True)
     # Convert lines into a dictionary
-    gpu_memory = [int(x) for x in result.strip().split('\n')]
-    gpu_memory_map = {}
-    for k, v in zip(range(len(gpu_memory)), gpu_memory):
-        k = f'gpu_{k}'
-        gpu_memory_map[k] = v
+    gpu_memory = [int(x) for x in result.stdout.strip().split(os.linesep)]
+    gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)}
     return gpu_memory_map
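As a sanity check on the new parsing logic, here is a minimal sketch that mirrors the subprocess.run call above but substitutes printf for nvidia-smi, so it runs on a machine without a GPU. The values are made up and a Unix-like environment is assumed; note the patched code splits on os.linesep, which is '\n' on Linux and macOS.

import subprocess

# printf stands in for nvidia-smi here; the output format is the same:
# one integer of used memory (in MB) per line.
result = subprocess.run(
    ['printf', '1123\n4056\n980\n'],
    encoding='utf-8',
    capture_output=True,
    check=True)

gpu_memory = [int(x) for x in result.stdout.strip().split('\n')]
gpu_memory_map = {f'gpu_{index}': memory for index, memory in enumerate(gpu_memory)}

print(gpu_memory_map)  # {'gpu_0': 1123, 'gpu_1': 4056, 'gpu_2': 980}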

tests/test_gpu_models.py

Lines changed: 1 addition & 1 deletion
@@ -224,7 +224,7 @@ def test_multi_gpu_model_dp():
     testing_utils.run_gpu_model_test(trainer_options, model, hparams)

     # test memory helper functions
-    memory.get_gpu_memory_map()
+    memory.get_memory_profile('min_max')


 def test_ddp_sampler_error():
