@@ -133,31 +133,31 @@ def __init__(
133133 embedding_path = None ,
134134 # just to keep track of this parameter when regenerating prompt
135135 latent_diffusion_weights = False ,
136- device = 'cuda' ,
137136 ):
138- self .iterations = iterations
139- self .width = width
140- self .height = height
141- self .steps = steps
142- self .cfg_scale = cfg_scale
143- self .weights = weights
144- self .config = config
145- self .sampler_name = sampler_name
146- self .latent_channels = latent_channels
147- self .downsampling_factor = downsampling_factor
148- self .grid = grid
149- self .ddim_eta = ddim_eta
150- self .precision = precision
151- self .full_precision = full_precision
152- self .strength = strength
153- self .embedding_path = embedding_path
154- self .model = None # empty for now
155- self .sampler = None
137+ self .iterations = iterations
138+ self .width = width
139+ self .height = height
140+ self .steps = steps
141+ self .cfg_scale = cfg_scale
142+ self .weights = weights
143+ self .config = config
144+ self .sampler_name = sampler_name
145+ self .latent_channels = latent_channels
146+ self .downsampling_factor = downsampling_factor
147+ self .grid = grid
148+ self .ddim_eta = ddim_eta
149+ self .precision = precision
150+ self .full_precision = full_precision
151+ self .strength = strength
152+ self .embedding_path = embedding_path
153+ self .model = None # empty for now
154+ self .sampler = None
155+ self .device = None
156156 self .latent_diffusion_weights = latent_diffusion_weights
157- self .device = device
158157
159158 # for VRAM usage statistics
160- self .session_peakmem = torch .cuda .max_memory_allocated () if self .device == 'cuda' else None
159+ device_type = choose_torch_device ()
160+ self .session_peakmem = torch .cuda .max_memory_allocated () if device_type == 'cuda' else None
161161
162162 if seed is None :
163163 self .seed = self ._new_seed ()
@@ -251,14 +251,15 @@ def process_image(image,seed):
251251 to create the requested output directory, select a unique informative name for each image, and
252252 write the prompt into the PNG metadata.
253253 """
254- steps = steps or self .steps
255- seed = seed or self .seed
256- width = width or self .width
257- height = height or self .height
258- cfg_scale = cfg_scale or self .cfg_scale
259- ddim_eta = ddim_eta or self .ddim_eta
260- iterations = iterations or self .iterations
261- strength = strength or self .strength
254+ # TODO: convert this into a getattr() loop
255+ steps = steps or self .steps
256+ seed = seed or self .seed
257+ width = width or self .width
258+ height = height or self .height
259+ cfg_scale = cfg_scale or self .cfg_scale
260+ ddim_eta = ddim_eta or self .ddim_eta
261+ iterations = iterations or self .iterations
262+ strength = strength or self .strength
262263 self .log_tokenization = log_tokenization
263264
264265 model = (
@@ -279,7 +280,7 @@ def process_image(image,seed):
279280 self ._set_sampler ()
280281
281282 tic = time .time ()
282- torch .cuda .reset_peak_memory_stats () if self .device == 'cuda' else None
283+ torch .cuda .reset_peak_memory_stats () if self .device . type == 'cuda' else None
283284 results = list ()
284285
285286 try :
0 commit comments