Skip to content

Commit 01e05a9

Browse files
committed
this fixes the inconsistent use of self.device, sometimes a str and sometimes an obj
1 parent dc30adf commit 01e05a9

File tree

2 files changed

+31
-38
lines changed

2 files changed

+31
-38
lines changed

ldm/simplet2i.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -133,31 +133,31 @@ def __init__(
133133
embedding_path=None,
134134
# just to keep track of this parameter when regenerating prompt
135135
latent_diffusion_weights=False,
136-
device='cuda',
137136
):
138-
self.iterations = iterations
139-
self.width = width
140-
self.height = height
141-
self.steps = steps
142-
self.cfg_scale = cfg_scale
143-
self.weights = weights
144-
self.config = config
145-
self.sampler_name = sampler_name
146-
self.latent_channels = latent_channels
147-
self.downsampling_factor = downsampling_factor
148-
self.grid = grid
149-
self.ddim_eta = ddim_eta
150-
self.precision = precision
151-
self.full_precision = full_precision
152-
self.strength = strength
153-
self.embedding_path = embedding_path
154-
self.model = None # empty for now
155-
self.sampler = None
137+
self.iterations = iterations
138+
self.width = width
139+
self.height = height
140+
self.steps = steps
141+
self.cfg_scale = cfg_scale
142+
self.weights = weights
143+
self.config = config
144+
self.sampler_name = sampler_name
145+
self.latent_channels = latent_channels
146+
self.downsampling_factor = downsampling_factor
147+
self.grid = grid
148+
self.ddim_eta = ddim_eta
149+
self.precision = precision
150+
self.full_precision = full_precision
151+
self.strength = strength
152+
self.embedding_path = embedding_path
153+
self.model = None # empty for now
154+
self.sampler = None
155+
self.device = None
156156
self.latent_diffusion_weights = latent_diffusion_weights
157-
self.device = device
158157

159158
# for VRAM usage statistics
160-
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
159+
device_type = choose_torch_device()
160+
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
161161

162162
if seed is None:
163163
self.seed = self._new_seed()
@@ -251,14 +251,15 @@ def process_image(image,seed):
251251
to create the requested output directory, select a unique informative name for each image, and
252252
write the prompt into the PNG metadata.
253253
"""
254-
steps = steps or self.steps
255-
seed = seed or self.seed
256-
width = width or self.width
257-
height = height or self.height
258-
cfg_scale = cfg_scale or self.cfg_scale
259-
ddim_eta = ddim_eta or self.ddim_eta
260-
iterations = iterations or self.iterations
261-
strength = strength or self.strength
254+
# TODO: convert this into a getattr() loop
255+
steps = steps or self.steps
256+
seed = seed or self.seed
257+
width = width or self.width
258+
height = height or self.height
259+
cfg_scale = cfg_scale or self.cfg_scale
260+
ddim_eta = ddim_eta or self.ddim_eta
261+
iterations = iterations or self.iterations
262+
strength = strength or self.strength
262263
self.log_tokenization = log_tokenization
263264

264265
model = (
@@ -279,7 +280,7 @@ def process_image(image,seed):
279280
self._set_sampler()
280281

281282
tic = time.time()
282-
torch.cuda.reset_peak_memory_stats() if self.device == 'cuda' else None
283+
torch.cuda.reset_peak_memory_stats() if self.device.type == 'cuda' else None
283284
results = list()
284285

285286
try:

scripts/dream.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def main():
6060
# this is solely for recreating the prompt
6161
latent_diffusion_weights=opt.laion400m,
6262
embedding_path=opt.embedding_path,
63-
device=opt.device,
6463
)
6564

6665
# make sure the output directory exists
@@ -376,13 +375,6 @@ def create_argv_parser():
376375
type=str,
377376
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
378377
)
379-
parser.add_argument(
380-
'--device',
381-
'-d',
382-
type=str,
383-
default='cuda',
384-
help='Device to run Stable Diffusion on. Defaults to cuda `torch.cuda.current_device()` if avalible',
385-
)
386378
parser.add_argument(
387379
'--prompt_as_dir',
388380
'-p',

0 commit comments

Comments
 (0)