23 | 23 | from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict  | 
24 | 24 | from ..utils import (  | 
25 | 25 |     USE_PEFT_BACKEND,  | 
 | 26 | +    _get_detailed_type,  | 
26 | 27 |     _get_model_file,  | 
 | 28 | +    _is_valid_type,  | 
27 | 29 |     is_accelerate_available,  | 
28 | 30 |     is_torch_version,  | 
29 | 31 |     is_transformers_available,  | 
@@ -577,29 +579,36 @@ def LinearStrengthModel(start, finish, size):  | 
577 | 579 |         pipeline.set_ip_adapter_scale(ip_strengths)  | 
578 | 580 |         ```  | 
579 | 581 |         """  | 
580 |  | -        transformer = self.transformer  | 
581 |  | -        if not isinstance(scale, list):  | 
582 |  | -            scale = [[scale] * transformer.config.num_layers]  | 
583 |  | -        elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):  | 
584 |  | -            if len(scale) != transformer.config.num_layers:  | 
585 |  | -                raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")  | 
 | 582 | + | 
 | 583 | +        scale_type = Union[int, float]  | 
 | 584 | +        num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters  | 
 | 585 | +        num_layers = self.transformer.config.num_layers  | 
 | 586 | + | 
 | 587 | +        # Single value for all layers of all IP-Adapters  | 
 | 588 | +        if isinstance(scale, scale_type):  | 
 | 589 | +            scale = [scale for _ in range(num_ip_adapters)]  | 
 | 590 | +        # List of per-layer scales for a single IP-Adapter  | 
 | 591 | +        elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:  | 
586 | 592 |             scale = [scale]  | 
 | 593 | +        # Invalid scale type  | 
 | 594 | +        elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):  | 
 | 595 | +            raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")  | 
587 | 596 |  | 
588 |  | -        scale_configs = scale  | 
 | 597 | +        if len(scale) != num_ip_adapters:  | 
 | 598 | +            raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")  | 
589 | 599 |  | 
590 |  | -        key_id = 0  | 
591 |  | -        for attn_name, attn_processor in transformer.attn_processors.items():  | 
592 |  | -            if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):  | 
593 |  | -                if len(scale_configs) != len(attn_processor.scale):  | 
594 |  | -                    raise ValueError(  | 
595 |  | -                        f"Cannot assign {len(scale_configs)} scale_configs to "  | 
596 |  | -                        f"{len(attn_processor.scale)} IP-Adapter."  | 
597 |  | -                    )  | 
598 |  | -                elif len(scale_configs) == 1:  | 
599 |  | -                    scale_configs = scale_configs * len(attn_processor.scale)  | 
600 |  | -                for i, scale_config in enumerate(scale_configs):  | 
601 |  | -                    attn_processor.scale[i] = scale_config[key_id]  | 
602 |  | -                key_id += 1  | 
 | 600 | +        if any(len(s) != num_layers for s in scale if isinstance(s, list)):  | 
 | 601 | +            invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}  | 
 | 602 | +            raise ValueError(  | 
 | 603 | +                f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."  | 
 | 604 | +            )  | 
 | 605 | + | 
 | 606 | +        # Scalars are transformed to lists with length num_layers  | 
 | 607 | +        scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]  | 
 | 608 | + | 
 | 609 | +        # Set scales. Zipping over scale_configs stops after num_layers iterations, so the single transformer blocks (which have no IP-Adapter processors) are skipped  | 
 | 610 | +        for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):  | 
 | 611 | +            attn_processor.scale = scale  | 
603 | 612 |  | 
604 | 613 |     def unload_ip_adapter(self):  | 
605 | 614 |         """  | 
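
For context beyond the hunk: after this change, `set_ip_adapter_scale` accepts a single scalar, one per-layer list, or a per-adapter list mixing both. Below is a minimal sketch of the accepted call shapes; the checkpoint repo ids and `load_ip_adapter` arguments are illustrative placeholders, not part of this diff.

```python
import torch
from diffusers import FluxPipeline

# Sketch only: assumes a Flux checkpoint plus an IP-Adapter checkpoint;
# the repo ids and loader arguments below are illustrative.
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipeline.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)

num_layers = pipeline.transformer.config.num_layers  # 19 for FLUX.1-dev

# One scalar: applied to every layer of every loaded IP-Adapter.
pipeline.set_ip_adapter_scale(0.6)

# One per-layer list: only valid when exactly one IP-Adapter is loaded.
pipeline.set_ip_adapter_scale([0.4] * num_layers)

# With two adapters loaded: one entry per adapter, each a scalar or a
# per-layer list. Wrong lengths raise ValueError; wrong types raise
# TypeError with a detailed description of the offending type.
pipeline.set_ip_adapter_scale([0.6, [0.4] * num_layers])
```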
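The validation path leans on `_is_valid_type` and `_get_detailed_type`, imported from `..utils` in the first hunk; their implementation is not shown here. The snippet below is a simplified, hypothetical stand-in illustrating the recursive check the call sites rely on, not the actual diffusers helper.

```python
from typing import List, Union, get_args, get_origin

def _is_valid_type_sketch(obj, expected) -> bool:
    """Simplified stand-in: recursively check `obj` against a
    (possibly subscripted) generic type such as List[Union[int, float]]."""
    origin = get_origin(expected)
    if origin is Union:
        # Union[...]: valid if obj matches any member type.
        return any(_is_valid_type_sketch(obj, arg) for arg in get_args(expected))
    if origin is list:
        # List[T]: obj must be a list whose items all match T.
        (item_type,) = get_args(expected)
        return isinstance(obj, list) and all(
            _is_valid_type_sketch(item, item_type) for item in obj
        )
    # Bare class: plain isinstance check.
    return isinstance(obj, expected)

# Mirrors the validator call in the diff: a mixed per-adapter list passes.
scale_type = Union[int, float]
assert _is_valid_type_sketch([0.6, [0.4, 0.5]], List[Union[scale_type, List[scale_type]]])
```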