Skip to content

Commit 63142fc

Browse files
committed
fix: Always print CUDA memory summary
The CUDA memory summary logging was inadvertently moved inside the GPTQ-only block, so it was skipped for non-GPTQ models; this moves it back out so it always runs.
1 parent 9e3fb05 commit 63142fc

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

server/text_generation_server/server.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,17 @@ async def serve_inner(
292292
except ImportError:
293293
print("WARN: Error setting up GPTQ exllama buffers")
294294

295-
if local_rank == 0 and device.type == "cuda":
296-
# Log GPU memory stats at startup
297-
device = model.engine.get_device()
298-
print(f"Cuda process memory fraction: {cuda_process_memory_fraction}")
299-
print(torch.cuda.memory_summary(device=device))
300-
# Start a thread to log GPU usage if configured
301-
interval = float(os.getenv("LOG_GPU_USAGE_INTERVAL", "0"))
302-
if interval > 0.0:
303-
t = threading.Thread(target=partial(log_gpu_stats, device, interval))
304-
t.start()
295+
if local_rank == 0 and device.type == "cuda":
296+
# Log GPU memory stats at startup
297+
print(f"Cuda process memory fraction: {cuda_process_memory_fraction}")
298+
print(torch.cuda.memory_summary(device=device))
299+
# Start a thread to log GPU usage if configured
300+
interval = float(os.getenv("LOG_GPU_USAGE_INTERVAL", "0"))
301+
if interval > 0.0:
302+
t = threading.Thread(target=partial(log_gpu_stats, device, interval))
303+
t.start()
305304

306305
if model.compiled:
307-
308306
# trigger pt2 compile for variety of tensor shapes
309307
print("Warming up PyTorch 2 compile...")
310308
warmup_t0 = time.time()

0 commit comments

Comments (0)