Skip to content

Commit 4b560b5

Browse files
authored
fix AttributeError crash when running on non-CUDA systems (#256)
* Fix AttributeError crash when running on non-CUDA systems; closes issues #234 and #250. * Although this prevents the dream.py script from crashing immediately on MPS systems, MPS support is still very much a work in progress.
1 parent 9ad7920 commit 4b560b5

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

ldm/simplet2i.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ def __init__(
157157
self.latent_diffusion_weights = latent_diffusion_weights
158158
self.device = device
159159

160-
self.session_peakmem = torch.cuda.max_memory_allocated()
160+
# for VRAM usage statistics
161+
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
162+
161163
if seed is None:
162164
self.seed = self._new_seed()
163165
else:
@@ -363,9 +365,6 @@ def process_image(image,seed):
363365
print('Are you sure your system has an adequate NVIDIA GPU?')
364366

365367
toc = time.time()
366-
self.session_peakmem = max(
367-
self.session_peakmem, torch.cuda.max_memory_allocated()
368-
)
369368
print('Usage stats:')
370369
print(
371370
f' {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
@@ -374,10 +373,15 @@ def process_image(image,seed):
374373
f' Max VRAM used for this generation:',
375374
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
376375
)
377-
print(
378-
f' Max VRAM used since script start: ',
379-
'%4.2fG' % (self.session_peakmem / 1e9),
380-
)
376+
377+
if self.session_peakmem:
378+
self.session_peakmem = max(
379+
self.session_peakmem, torch.cuda.max_memory_allocated()
380+
)
381+
print(
382+
f' Max VRAM used since script start: ',
383+
'%4.2fG' % (self.session_peakmem / 1e9),
384+
)
381385
return results
382386

383387
@torch.no_grad()

0 commit comments

Comments (0)