@@ -133,31 +133,31 @@ def __init__(
133133 embedding_path = None ,
134134 # just to keep track of this parameter when regenerating prompt
135135 latent_diffusion_weights = False ,
136- device = 'cuda' ,
137136 ):
138- self .iterations = iterations
139- self .width = width
140- self .height = height
141- self .steps = steps
142- self .cfg_scale = cfg_scale
143- self .weights = weights
144- self .config = config
145- self .sampler_name = sampler_name
146- self .latent_channels = latent_channels
147- self .downsampling_factor = downsampling_factor
148- self .grid = grid
149- self .ddim_eta = ddim_eta
150- self .precision = precision
151- self .full_precision = full_precision
152- self .strength = strength
153- self .embedding_path = embedding_path
154- self .model = None # empty for now
155- self .sampler = None
137+ self .iterations = iterations
138+ self .width = width
139+ self .height = height
140+ self .steps = steps
141+ self .cfg_scale = cfg_scale
142+ self .weights = weights
143+ self .config = config
144+ self .sampler_name = sampler_name
145+ self .latent_channels = latent_channels
146+ self .downsampling_factor = downsampling_factor
147+ self .grid = grid
148+ self .ddim_eta = ddim_eta
149+ self .precision = precision
150+ self .full_precision = full_precision
151+ self .strength = strength
152+ self .embedding_path = embedding_path
153+ self .model = None # empty for now
154+ self .sampler = None
155+ self .device = None
156156 self .latent_diffusion_weights = latent_diffusion_weights
157- self .device = device
158157
159158 # for VRAM usage statistics
160- self .session_peakmem = torch .cuda .max_memory_allocated () if self .device == 'cuda' else None
159+ device_type = choose_torch_device ()
160+ self .session_peakmem = torch .cuda .max_memory_allocated () if device_type == 'cuda' else None
161161
162162 if seed is None :
163163 self .seed = self ._new_seed ()
@@ -250,14 +250,15 @@ def process_image(image,seed):
250250 to create the requested output directory, select a unique informative name for each image, and
251251 write the prompt into the PNG metadata.
252252 """
253- steps = steps or self .steps
254- seed = seed or self .seed
255- width = width or self .width
256- height = height or self .height
257- cfg_scale = cfg_scale or self .cfg_scale
258- ddim_eta = ddim_eta or self .ddim_eta
259- iterations = iterations or self .iterations
260- strength = strength or self .strength
253+ # TODO: convert this into a getattr() loop
254+ steps = steps or self .steps
255+ seed = seed or self .seed
256+ width = width or self .width
257+ height = height or self .height
258+ cfg_scale = cfg_scale or self .cfg_scale
259+ ddim_eta = ddim_eta or self .ddim_eta
260+ iterations = iterations or self .iterations
261+ strength = strength or self .strength
261262 self .log_tokenization = log_tokenization
262263
263264 model = (
0 commit comments