@@ -8,6 +8,7 @@
 import numpy as np
 import random
 import os
+import traceback
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
@@ -28,7 +29,7 @@
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
 from ldm.dream.image_util import InitImageResizer
-from ldm.dream.devices import choose_torch_device
+from ldm.dream.devices import choose_autocast_device, choose_torch_device
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -114,26 +115,28 @@ class T2I:
114115"""
115116
116117 def __init__ (
117- self ,
118- iterations = 1 ,
119- steps = 50 ,
120- seed = None ,
121- cfg_scale = 7.5 ,
122- weights = 'models/ldm/stable-diffusion-v1/model.ckpt' ,
123- config = 'configs/stable-diffusion/v1-inference.yaml' ,
124- grid = False ,
125- width = 512 ,
126- height = 512 ,
127- sampler_name = 'k_lms' ,
128- latent_channels = 4 ,
129- downsampling_factor = 8 ,
130- ddim_eta = 0.0 , # deterministic
131- precision = 'autocast' ,
132- full_precision = False ,
133- strength = 0.75 , # default in scripts/img2img.py
134- embedding_path = None ,
135- # just to keep track of this parameter when regenerating prompt
136- latent_diffusion_weights = False ,
118+ self ,
119+ iterations = 1 ,
120+ steps = 50 ,
121+ seed = None ,
122+ cfg_scale = 7.5 ,
123+ weights = 'models/ldm/stable-diffusion-v1/model.ckpt' ,
124+ config = 'configs/stable-diffusion/v1-inference.yaml' ,
125+ grid = False ,
126+ width = 512 ,
127+ height = 512 ,
128+ sampler_name = 'k_lms' ,
129+ latent_channels = 4 ,
130+ downsampling_factor = 8 ,
131+ ddim_eta = 0.0 , # deterministic
132+ precision = 'autocast' ,
133+ full_precision = False ,
134+ strength = 0.75 , # default in scripts/img2img.py
135+ embedding_path = None ,
136+ device_type = 'cuda' ,
137+ # just to keep track of this parameter when regenerating prompt
138+ # needs to be replaced when new configuration system implemented.
139+ latent_diffusion_weights = False ,
137140 ):
138141 self .iterations = iterations
139142 self .width = width
@@ -151,11 +154,17 @@ def __init__(
         self.full_precision = full_precision
         self.strength = strength
         self.embedding_path = embedding_path
+        self.device_type = device_type
         self.model = None  # empty for now
         self.sampler = None
         self.device = None
         self.latent_diffusion_weights = latent_diffusion_weights
 
+        if device_type == 'cuda' and not torch.cuda.is_available():
+            device_type = choose_torch_device()
+            print(">> cuda not available, using device", device_type)
+        self.device = torch.device(device_type)
+
         # for VRAM usage statistics
         device_type = choose_torch_device()
         self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
@@ -312,8 +321,9 @@ def process_image(image,seed):
                     callback=step_callback,
                 )
 
-            with scope(self.device.type), self.model.ema_scope():
-                for n in trange(iterations, desc='>> Generating'):
+            device_type = choose_autocast_device(self.device)
+            with scope(device_type), self.model.ema_scope():
+                for n in trange(iterations, desc='Generating'):
                     seed_everything(seed)
                     image = next(images_iterator)
                     results.append([image, seed])
@@ -346,7 +356,7 @@ def process_image(image,seed):
                         )
                     except Exception as e:
                         print(
-                            f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
+                            f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
                         )
                     if image_callback is not None:
                         if save_original:
@@ -359,11 +369,11 @@ def process_image(image,seed):
         except KeyboardInterrupt:
             print('*interrupted*')
             print(
-                'Partial results will be returned; if --grid was requested, nothing will be returned.'
+                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
             )
         except RuntimeError as e:
-            print(str(e))
-            print('Are you sure your system has an adequate NVIDIA GPU?')
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Are you sure your system has an adequate NVIDIA GPU?')
 
         toc = time.time()
         print('>> Usage stats:')
@@ -464,7 +474,6 @@ def _img2img(
         )
 
         t_enc = int(strength * steps)
-        # print(f"target t_enc is {t_enc} steps")
 
         while True:
             uc, c = self._get_uc_and_c(prompt, skip_normalize)
@@ -515,7 +524,7 @@ def _sample_to_image(self, samples):
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
             raise Exception(
-                f'expected to get a single image, but got {len(x_samples)}')
+                f'>> expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -525,17 +534,12 @@ def _new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed
 
-    def _get_device(self):
-        device_type = choose_torch_device()
-        return torch.device(device_type)
-
     def load_model(self):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if self.model is None:
             seed_everything(self.seed)
             try:
                 config = OmegaConf.load(self.config)
-                self.device = self._get_device()
                 model = self._load_model_from_config(config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
@@ -544,12 +548,10 @@ def load_model(self):
                 self.model = model.to(self.device)
                 # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                 self.model.cond_stage_model.device = self.device
-            except AttributeError:
-                import traceback
-                print(
-                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+            except AttributeError as e:
+                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
-                raise SystemExit
+                raise SystemExit from e
 
             self._set_sampler()
 
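For context, a minimal, hypothetical usage sketch of the constructor change follows. It is not part of the commit: the module path `ldm.simplet2i` and the `prompt2image()` call are assumptions based on the surrounding file, while the `device_type='cuda'` argument, the CUDA fallback through `choose_torch_device()`, `load_model()`, and the `[image, seed]` result pairs come directly from this diff.

```python
# Hypothetical usage sketch, not part of the commit. Assumes T2I is importable
# from ldm.simplet2i and exposes prompt2image(); only the device handling
# mirrors the constructor logic added in this diff.
from ldm.simplet2i import T2I

t2i = T2I(
    weights='models/ldm/stable-diffusion-v1/model.ckpt',
    config='configs/stable-diffusion/v1-inference.yaml',
    sampler_name='k_lms',
    device_type='cuda',  # if torch.cuda.is_available() is False, __init__ now
                         # falls back to whatever choose_torch_device() reports
)
t2i.load_model()         # weights are moved onto the t2i.device chosen above

# prompt2image() is assumed here; results are [image, seed] pairs, as appended
# in the generation loop shown in this diff.
results = t2i.prompt2image('a photograph of a distant lighthouse at dawn')
for image, seed in results:
    image.save(f'output-{seed}.png')
```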