Skip to content

Commit 3ed98a1

Browse files
committed
fix
1 parent e818907 commit 3ed98a1

File tree

3 files changed

+9
-12
lines changed

3 files changed

+9
-12
lines changed

tests/lora/test_lora_layers_flux.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_flux_the_last_ben(self):
202202
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
203203
self.pipeline.fuse_lora()
204204
self.pipeline.unload_lora_weights()
205-
self.pipeline = self.pipeline.to("cuda")
205+
self.pipeline = self.pipeline.to(torch_device)
206206

207207
prompt = "jon snow eating pizza with ketchup"
208208

@@ -227,7 +227,7 @@ def test_flux_kohya(self):
227227
# Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
228228
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
229229
# `enable_model_cpu_offload()`.
230-
self.pipeline = self.pipeline.to("cuda")
230+
self.pipeline = self.pipeline.to(torch_device)
231231

232232
prompt = "The cat with a brain slug earring"
233233
out = self.pipeline(
@@ -249,7 +249,7 @@ def test_flux_kohya_with_text_encoder(self):
249249
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
250250
self.pipeline.fuse_lora()
251251
self.pipeline.unload_lora_weights()
252-
self.pipeline = self.pipeline.to("cuda")
252+
self.pipeline = self.pipeline.to(torch_device)
253253

254254
prompt = "optimus is cleaning the house with broomstick"
255255
out = self.pipeline(
@@ -271,7 +271,7 @@ def test_flux_xlabs(self):
271271
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
272272
self.pipeline.fuse_lora()
273273
self.pipeline.unload_lora_weights()
274-
self.pipeline = self.pipeline.to("cuda")
274+
self.pipeline = self.pipeline.to(torch_device)
275275

276276
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
277277

tests/lora/test_lora_layers_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_sd3_img2img_lora(self):
177177
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
178178
pipe.fuse_lora()
179179
pipe.unload_lora_weights()
180-
pipe = pipe.to("cuda")
180+
pipe = pipe.to(torch_device)
181181

182182
inputs = self.get_inputs(torch_device)
183183

tests/pipelines/flux/test_pipeline_flux.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,16 @@ def tearDown(self):
212212
torch.cuda.empty_cache()
213213

214214
def get_inputs(self, device, seed=0):
215-
if str(device).startswith("mps"):
216-
generator = torch.manual_seed(seed)
217-
else:
218-
generator = torch.Generator(device="cpu").manual_seed(seed)
215+
generator = torch.Generator(device="cpu").manual_seed(seed)
219216

220217
prompt_embeds = torch.load(
221218
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
222-
)
219+
).to(torch_device)
223220
pooled_prompt_embeds = torch.load(
224221
hf_hub_download(
225222
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
226223
)
227-
)
224+
).to(torch_device)
228225
return {
229226
"prompt_embeds": prompt_embeds,
230227
"pooled_prompt_embeds": pooled_prompt_embeds,
@@ -238,7 +235,7 @@ def get_inputs(self, device, seed=0):
238235
def test_flux_inference(self):
239236
pipe = self.pipeline_class.from_pretrained(
240237
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
241-
).to("cuda")
238+
).to(torch_device)
242239

243240
inputs = self.get_inputs(torch_device)
244241

0 commit comments

Comments (0)