94 changes: 93 additions & 1 deletion tests/lora/test_lora_layers_flux.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import sys
import tempfile
@@ -23,7 +24,14 @@
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
from diffusers.utils.testing_utils import (
    floats_tensor,
    is_peft_available,
    require_peft_backend,
    require_torch_gpu,
    slow,
    torch_device,
)


if is_peft_available():
@@ -145,3 +153,87 @@ def test_with_alpha_in_state_dict(self):
"Loading from saved checkpoints should give same results.",
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))


@slow
@require_torch_gpu
@require_peft_backend
class FluxLoRAIntegrationTests(unittest.TestCase):
Member Author commented:
Can add a skip tag like we have for the Flux pipeline tests:

@unittest.skip("We cannot run inference on this model with the current CI hardware")

"""internal note: The integration slices were obtained on audace."""

num_inference_steps = 10
seed = 0

    def setUp(self):
        super().setUp()

        gc.collect()
        torch.cuda.empty_cache()

        self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

    def tearDown(self):
        super().tearDown()

        gc.collect()
        torch.cuda.empty_cache()

    def test_flux_the_last_ben(self):
        self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
        self.pipeline.fuse_lora()
        self.pipeline.unload_lora_weights()
        self.pipeline.enable_model_cpu_offload()

        prompt = "jon snow eating pizza with ketchup"

        out = self.pipeline(
            prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=4.0,
            output_type="np",
            generator=torch.manual_seed(self.seed),
        ).images
        out_slice = out[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.1719, 0.1719, 0.1699, 0.1719, 0.1719, 0.1738, 0.1641, 0.1621, 0.2090])

        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_flux_kohya(self):
        self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
        self.pipeline.fuse_lora()
        self.pipeline.unload_lora_weights()
        self.pipeline.enable_model_cpu_offload()

        prompt = "The cat with a brain slug earring"
        out = self.pipeline(
            prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=4.5,
            output_type="np",
            generator=torch.manual_seed(self.seed),
        ).images

        out_slice = out[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])

        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_flux_xlabs(self):
        self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
        self.pipeline.fuse_lora()
        self.pipeline.unload_lora_weights()
        self.pipeline.enable_model_cpu_offload()

        prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"

        out = self.pipeline(
            prompt,
            num_inference_steps=self.num_inference_steps,
            guidance_scale=3.5,
            output_type="np",
            generator=torch.manual_seed(self.seed),
        ).images
        out_slice = out[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.3984, 0.4199, 0.4453, 0.4102, 0.4375, 0.4590, 0.4141, 0.4355, 0.4980])

        assert np.allclose(out_slice, expected_slice, atol=1e-4, rtol=1e-4)
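
Because the class is gated by @slow and @require_torch_gpu, these integration tests are skipped on standard CI runs. Assuming diffusers' usual RUN_SLOW gate for slow tests, they can be exercised locally on a sufficiently large GPU with, for example, RUN_SLOW=1 pytest tests/lora/test_lora_layers_flux.py -k FluxLoRAIntegrationTests (the exact invocation is illustrative, not prescribed by this PR).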