diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 972233bd987d..3c891f916035 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -56,7 +56,7 @@
     _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
    _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
    _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
-    _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
+    _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel", "UNetMultiControlNetXSModel"]
    _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
    _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
    _import_structure["embeddings"] = ["ImageProjection"]
@@ -152,6 +152,7 @@
            SD3MultiControlNetModel,
            SparseControlNetModel,
            UNetControlNetXSModel,
+            UNetMultiControlNetXSModel,
        )
        from .embeddings import ImageProjection
        from .modeling_utils import ModelMixin
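With the export wiring above in place, the new class sits next to the existing single-adapter variant in the public `diffusers.models` namespace. A minimal sketch of the resulting import surface (assuming this patch is applied):

```python
# Sketch only: verifies the new export added by this patch.
from diffusers.models import (
    ControlNetXSAdapter,         # one ControlNet-XS adapter
    UNetControlNetXSModel,       # existing: UNet fused with a single adapter
    UNetMultiControlNetXSModel,  # new: UNet fused with several adapters
)
```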
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index aabae709e988..ca289024f441 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -128,19 +128,19 @@ def get_down_block_adapter(
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers

    for i in range(num_layers):
-        base_in_channels = base_in_channels if i == 0 else base_out_channels
-        ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+        cur_base_in_channels = base_in_channels if i == 0 else base_out_channels
+        cur_ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

        # Before the resnet/attention application, information is concatted from base to control.
        # Concat doesn't require change in number of channels
-        base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+        base_to_ctrl.append(make_zero_conv(cur_base_in_channels, cur_base_in_channels))

        resnets.append(
            ResnetBlock2D(
-                in_channels=ctrl_in_channels + base_in_channels,  # information from base is concatted to ctrl
+                in_channels=cur_ctrl_in_channels + cur_base_in_channels,  # information from base is concatted to ctrl
                out_channels=ctrl_out_channels,
                temb_channels=temb_channels,
-                groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),
+                groups=find_largest_factor(cur_ctrl_in_channels + cur_base_in_channels, max_factor=max_norm_num_groups),
                groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                eps=1e-5,
            )
@@ -1252,16 +1252,16 @@ def __init__(
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
-            base_in_channels = base_in_channels if i == 0 else base_out_channels
-            ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+            cur_base_in_channels = base_in_channels if i == 0 else base_out_channels
+            cur_ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

            # Before the resnet/attention application, information is concatted from base to control.
            # Concat doesn't require change in number of channels
-            base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+            base_to_ctrl.append(make_zero_conv(cur_base_in_channels, cur_base_in_channels))

            base_resnets.append(
                ResnetBlock2D(
-                    in_channels=base_in_channels,
+                    in_channels=cur_base_in_channels,
                    out_channels=base_out_channels,
                    temb_channels=temb_channels,
                    groups=norm_num_groups,
@@ -1269,11 +1269,11 @@ def __init__(
            )
            ctrl_resnets.append(
                ResnetBlock2D(
-                    in_channels=ctrl_in_channels + base_in_channels,  # information from base is concatted to ctrl
+                    in_channels=cur_ctrl_in_channels + cur_base_in_channels,  # information from base is concatted to ctrl
                    out_channels=ctrl_out_channels,
                    temb_channels=temb_channels,
                    groups=find_largest_factor(
-                        ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
+                        cur_ctrl_in_channels + cur_base_in_channels, max_factor=ctrl_max_norm_num_groups
                    ),
                    groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                    eps=1e-5,
@@ -1884,7 +1884,1266 @@ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            hidden_states = self.upsamplers(hidden_states, upsample_size)

        return hidden_states
+
+
+class UNetMultiControlNetXSModel(ModelMixin, ConfigMixin):
+    r"""
+    A UNet fused with multiple ControlNet-XS adapters.
+
+    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
+    methods implemented for all models (such as downloading or saving).
+
+    `UNetMultiControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
+    compatible with StableDiffusion.
+
+    Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly as in
+    `ControlNetXSAdapter`. See their documentation for details.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        # unet configs
+        sample_size: Optional[int] = 96,
+        down_block_types: Tuple[str] = (
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "CrossAttnDownBlock2D",
+            "DownBlock2D",
+        ),
+        up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+        norm_num_groups: Optional[int] = 32,
+        cross_attention_dim: Union[int, Tuple[int]] = 1024,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        num_attention_heads: Union[int, Tuple[int]] = 8,
+        addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
+        upcast_attention: bool = True,
+        use_linear_projection: bool = True,
+        time_cond_proj_dim: Optional[int] = None,
+        projection_class_embeddings_input_dim: Optional[int] = None,
+        # additional controlnet configs
+        time_embedding_mix: float = 1.0,
+        controlnets: Optional[List[dict]] = None,
+    ):
+        super().__init__()
+
+        if time_embedding_mix < 0 or time_embedding_mix > 1:
+            raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
+        if addition_embed_type is not None and addition_embed_type != "text_time":
+            raise ValueError(
+                "As `UNetMultiControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
+            )
+
+        if not isinstance(transformer_layers_per_block, (list, tuple)):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+        if not isinstance(cross_attention_dim, (list, tuple)):
+            cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+        if not isinstance(num_attention_heads, (list, tuple)):
+            num_attention_heads = [num_attention_heads] * len(down_block_types)
+        base_num_attention_heads = num_attention_heads
+
+        if controlnets is None:
+            controlnets = []
+        self.num_of_controlnet = len(controlnets)
+
+        self.in_channels = 4
+
+        # # Input
+        self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
+        # one entry per controlnet; ModuleLists so the submodules are registered and follow `.to()` / dtype casts
+        self.ctrl_conv_ins = nn.ModuleList()
+        self.controlnet_cond_embeddings = nn.ModuleList()
+        self.control_to_base_for_conv_ins = nn.ModuleList()
+
+        for controlnet in controlnets:
+            ctrl_block_out_channels = controlnet.get("ctrl_block_out_channels")
+            ctrl_conditioning_embedding_out_channels = controlnet.get("ctrl_conditioning_embedding_out_channels")
+            ctrl_conditioning_channels = controlnet.get("ctrl_conditioning_channels")
+            self.ctrl_conv_ins.append(nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1))
+
+            controlnet_cond_embedding = ControlNetConditioningEmbedding(
+                conditioning_embedding_channels=ctrl_block_out_channels[0],
+                block_out_channels=ctrl_conditioning_embedding_out_channels,
+                conditioning_channels=ctrl_conditioning_channels,
+            )
+
+            self.controlnet_cond_embeddings.append(controlnet_cond_embedding)
+            # used to add the control signal to the base stream right after `base_conv_in`
+            self.control_to_base_for_conv_ins.append(make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0]))
+
+        # # Time
+        time_embed_input_dim = block_out_channels[0]
+        time_embed_dim = block_out_channels[0] * 4
+        self.base_time_proj = Timesteps(time_embed_input_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.base_time_embedding = TimestepEmbedding(
+            time_embed_input_dim,
+            time_embed_dim,
+            cond_proj_dim=time_cond_proj_dim,
+        )
+
+        if addition_embed_type is None:
+            self.base_add_time_proj = None
+            self.base_add_embedding = None
+        else:
+            self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+            self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+        # per-controlnet time embeddings; entries are None for adapters that don't learn their own time embedding
+        self.ctrl_time_embeddings = nn.ModuleList()
+        self.ctrl_time_embedding_mix = []
+        for controlnet in controlnets:
+            self.ctrl_time_embedding_mix.append(controlnet.get("time_embedding_mix"))
+            if controlnet.get("ctrl_learn_time_embedding"):
+                self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim))
+            else:
+                self.ctrl_time_embeddings.append(None)
+
+        # the down/mid/up blocks are created from the base UNet and the adapters in `from_unet`
+        self.down_blocks = []
+        self.mid_block = None
+        self.up_blocks = []
+
+        self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
+        self.base_conv_act = nn.SiLU()
+        self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
+
+    def num_of_controlnets(self):
+        return self.num_of_controlnet
+
+    @classmethod
+    def from_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        controlnets: List[ControlNetXSAdapter],
+    ):
+        r"""
+        Instantiate a [`UNetMultiControlNetXSModel`] from a [`UNet2DConditionModel`] and a list of
+        [`ControlNetXSAdapter`]s.
+
+        Parameters:
+            unet (`UNet2DConditionModel`):
+                The UNet model we want to control.
+            controlnets (`List[ControlNetXSAdapter]`):
+                The ControlNet-XS adapters with which the UNet will be fused.
+        """
+        # # get params
+        params_for_unet = [
+            "sample_size",
+            "down_block_types",
+            "up_block_types",
+            "block_out_channels",
+            "norm_num_groups",
+            "cross_attention_dim",
+            "transformer_layers_per_block",
+            "addition_embed_type",
+            "addition_time_embed_dim",
+            "upcast_attention",
+            "use_linear_projection",
+            "time_cond_proj_dim",
+            "projection_class_embeddings_input_dim",
+        ]
+        params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
+
+        # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+        params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
+
+        params_for_controlnets = []
+
+        params_for_controlnet = [
+            "conditioning_channels",
+            "conditioning_embedding_out_channels",
+            "conditioning_channel_order",
+            "learn_time_embedding",
+            "block_out_channels",
+            "num_attention_heads",
+            "max_norm_num_groups",
+        ]
+
+        for controlnet in controlnets:
+            params = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
+            params["time_embedding_mix"] = controlnet.config.time_embedding_mix
+            params_for_controlnets.append(params)
+
+        full_config = {
+            **params_for_unet,
+            "controlnets": params_for_controlnets,
+        }
+
+        # # create model
+        model = cls.from_config(full_config)
+
+        # # load weights
+        # from unet
+        modules_from_unet = [
+            "time_embedding",
+            "conv_in",
+            "conv_norm_out",
+            "conv_out",
+        ]
+
+        for m in modules_from_unet:
+            getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+        optional_modules_from_unet = [
+            "add_time_proj",
+            "add_embedding",
+        ]
+
+        for m in optional_modules_from_unet:
+            if hasattr(unet, m) and getattr(unet, m) is not None:
+                getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+        # from controlnets
+        for index in range(len(controlnets)):
+            controlnet = controlnets[index]
+            model.controlnet_cond_embeddings[index].load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
+            model.ctrl_conv_ins[index].load_state_dict(controlnet.conv_in.state_dict())
+            if controlnet.time_embedding is not None:
+                model.ctrl_time_embeddings[index].load_state_dict(controlnet.time_embedding.state_dict())
+            model.control_to_base_for_conv_ins[index].load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
+
+        # create down_blocks, mid_block and up_blocks from the unet and the adapters;
+        # first group the adapters' down blocks per unet down-block layer
+        num_down_block_layer = len(unet.down_blocks)
+        layered_controlnet_downblock_list = []
+        for layer_index in range(num_down_block_layer):
+            layered_controlnet_downblock = []
+            for controlnet in controlnets:
+                layered_controlnet_downblock.append(controlnet.down_blocks[layer_index])
+            layered_controlnet_downblock_list.append(layered_controlnet_downblock)
+
+        # from both
+        model.down_blocks = nn.ModuleList(
+            MultiControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
+            for b, c in zip(unet.down_blocks, layered_controlnet_downblock_list)
+        )
+
+        controlnet_midblock_list = []
+        for controlnet in controlnets:
+            controlnet_midblock_list.append(controlnet.mid_block)
+        model.mid_block = MultiControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet_midblock_list)
+
+        num_up_block_layer = len(unet.up_blocks)
+        layered_controlnet_upblock_list = []
+        for layer_index in range(num_up_block_layer):
+            layered_controlnet_upblock = []
+            for controlnet in controlnets:
+                layered_controlnet_upblock.append(controlnet.up_connections[layer_index])
+            layered_controlnet_upblock_list.append(layered_controlnet_upblock)
+        model.up_blocks = nn.ModuleList(
+            MultiControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
+            for b, c in zip(unet.up_blocks, layered_controlnet_upblock_list)
+        )
+
+        # ensure that the UNetMultiControlNetXSModel is the same dtype as the UNet2DConditionModel
+        model.to(unet.dtype)
+
+        return model
+
+    def freeze_unet_params(self) -> None:
+        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
+        unfrozen for fine tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze the parts belonging to the base UNet
+        base_parts = [
+            "base_time_proj",
+            "base_time_embedding",
+            "base_add_time_proj",
+            "base_add_embedding",
+            "base_conv_in",
+            "base_conv_norm_out",
+            "base_conv_act",
+            "base_conv_out",
+        ]
+        base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
+        for part in base_parts:
+            for param in part.parameters():
+                param.requires_grad = False
+        for d in self.down_blocks:
+            d.freeze_base_params()
+        self.mid_block.freeze_base_params()
+        for u in self.up_blocks:
+            u.freeze_base_params()
+
+    def forward(
+        self,
+        sample: Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.Tensor],  # cannot be None
+        conditioning_scale: Optional[List[float]] = None,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        return_dict: bool = True,
+        apply_control: bool = True,
+    ) -> Union[ControlNetXSOutput, Tuple]:
+        if len(self.ctrl_conv_ins) != len(controlnet_cond):
+            raise ValueError(
+                "Each controlnet takes exactly one conditioning image, so `controlnet_cond` must contain as many images as there are controlnets."
+            )
+
+        # Prepare the base unet and controlnet inputs.
+        # controlnet_cond[index] is the conditioning image of the index-th controlnet.
+        h_base = self.base_conv_in(sample)
+        h_ctrls = []
+        for index in range(len(self.ctrl_conv_ins)):
+            h_ctrl = self.ctrl_conv_ins[index](sample)
+
+            # h_ctrl is a combination of the embedded conditioning image and the base sample
+            guided_hint = self.controlnet_cond_embeddings[index](controlnet_cond[index])
+            h_ctrl += guided_hint
+
+            # add the control signal to the base stream
+            if apply_control:
+                h_base = h_base + self.control_to_base_for_conv_ins[index](h_ctrl) * conditioning_scale[index]
+
+            h_ctrls.append(h_ctrl)
+
+        # Prepare time embeddings
+        # added time & text embeddings
+        aug_emb = None
+        # this is added to the base time embedding, not the ctrl time embeddings
+        if self.config.addition_embed_type is None:
+            pass
+        elif self.config.addition_embed_type == "text_time":
+            # SDXL - style
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.base_add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(sample.dtype)
+            aug_emb = self.base_add_embedding(add_embeds)
+        else:
+            raise ValueError(
+                f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
+            )
+
+        timesteps = timestep
+        if not torch.is_tensor(timesteps):
+            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
+            if isinstance(timestep, float):
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+            else:
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+
+        t_emb = self.base_time_proj(timesteps)
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+
+        base_temb = self.base_time_embedding(t_emb, timestep_cond)
+        time_embedding_list = []
+        for index in range(len(self.ctrl_time_embeddings)):
+            temb_module = self.ctrl_time_embeddings[index]
+            if temb_module is not None and apply_control:
+                ctrl_temb = temb_module(t_emb, timestep_cond)
+                interpolation_param = self.ctrl_time_embedding_mix[index] ** 0.3
+                temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
+            else:
+                temb = base_temb
+            temb = temb + aug_emb if aug_emb is not None else temb
+            time_embedding_list.append(temb)
+        # The blocks below take a single time embedding, so the embedding computed for the last
+        # controlnet is used for base and control alike.
+        temb = time_embedding_list[-1]
+
+        # text embeddings
+        cemb = encoder_hidden_states
+
+        # 1 - down
+        # skip-connection history; detaching here would break gradient flow through the skips
+        history_h_bases = [h_base]
+        history_h_ctrls = [list(h_ctrls)]
+
+        for down in self.down_blocks:
+            h_base, h_ctrls, residual_hb, residual_hc = down(
+                hidden_states_base=h_base,
+                hidden_states_ctrl=h_ctrls,
+                temb=temb,
+                encoder_hidden_states=cemb,
+                conditioning_scale=conditioning_scale,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                apply_control=apply_control,
+            )
+            history_h_bases.extend(residual_hb)
+            history_h_ctrls.extend(residual_hc)
+
+        # 2 - mid
+        h_base, h_ctrls = self.mid_block(
+            hidden_states_base=h_base,
+            hidden_states_ctrl=h_ctrls,
+            temb=temb,
+            encoder_hidden_states=cemb,
+            conditioning_scale=conditioning_scale,
+            cross_attention_kwargs=cross_attention_kwargs,
+            attention_mask=attention_mask,
+            apply_control=apply_control,
+        )
+
+        # 3 - up
+        for up in self.up_blocks:
+            n_resnets = len(up.resnets)
+            skips_hb = history_h_bases[-n_resnets:]
+            skips_hc = history_h_ctrls[-n_resnets:]
+            history_h_bases = history_h_bases[:-n_resnets]
+            history_h_ctrls = history_h_ctrls[:-n_resnets]
+            h_base = up(
+                hidden_states=h_base,
+                res_hidden_states_tuple_base=skips_hb,
+                res_hidden_states_tuple_ctrl=skips_hc,
+                temb=temb,
+                encoder_hidden_states=cemb,
+                conditioning_scale=conditioning_scale,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                apply_control=apply_control,
+            )
+
+        # 4 - conv out
+        h_base = self.base_conv_norm_out(h_base)
+        h_base = self.base_conv_act(h_base)
+        h_base = self.base_conv_out(h_base)
+
+        if not return_dict:
+            return (h_base,)
+
+        return ControlNetXSOutput(sample=h_base)
+
+
+class MultiControlNetXSCrossAttnDownBlock2D(nn.Module):
+    def __init__(
+        self,
+        # parameters for unet
+        base_in_channels: int,
+        base_out_channels: int,
+        temb_channels: int,
+        norm_num_groups: int,
+        has_crossattn,
+        transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
+        base_num_attention_heads: Optional[int] = 1,
+        cross_attention_dim: Optional[int] = 1024,
+        add_downsample: bool = True,
+        upcast_attention: Optional[bool] = False,
+        use_linear_projection: Optional[bool] = True,
+        # parameters for controlnets
+        ctrl_in_channels_list: Optional[List[int]] = None,
+        ctrl_out_channels_list: Optional[List[int]] = None,
+        ctrl_max_norm_num_groups: Optional[List[int]] = None,
+        ctrl_num_attention_heads: Optional[List[int]] = None,
+    ):
+        super().__init__()
+        base_resnets = []
+        base_attentions = []
+        # one list of resnets/attentions per controlnet
+        ctrl_resnets_list = []
+        ctrl_attentions_list = []
+        # one list of ctrl_to_base zero convs per controlnet
+        ctrl_to_base_list = []
+        # one list of base_to_ctrl zero convs per controlnet
+        base_to_ctrl_list = []
+        ctrl_downsamplers_list = []
+        num_layers = 2  # SD and SDXL both use two layers per down block
+        # used for base unet
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for index in range(len(ctrl_in_channels_list)):
+            ctrl_in_channels = ctrl_in_channels_list[index]
+            ctrl_out_channels = ctrl_out_channels_list[index]
+
+            base_to_ctrl = []
+            ctrl_resnets = []
+            ctrl_attentions = []
+            ctrl_to_base = []
+
+            # build the down block for this controlnet
+            for i in range(num_layers):
+                cur_base_in_channels = base_in_channels if i == 0 else base_out_channels
+                cur_ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+
+                # Before the resnet/attention application, information is concatted from base to control.
+                # Concat doesn't require change in number of channels
+                base_to_ctrl.append(make_zero_conv(cur_base_in_channels, cur_base_in_channels))
+
+                ctrl_resnets.append(
+                    ResnetBlock2D(
+                        in_channels=cur_ctrl_in_channels + cur_base_in_channels,  # information from base is concatted to ctrl
+                        out_channels=ctrl_out_channels,
+                        temb_channels=temb_channels,
+                        groups=find_largest_factor(
+                            cur_ctrl_in_channels + cur_base_in_channels, max_factor=ctrl_max_norm_num_groups[index]
+                        ),
+                        groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups[index]),
+                        eps=1e-5,
+                    )
+                )
+
+                if has_crossattn:
+                    ctrl_attentions.append(
+                        Transformer2DModel(
+                            ctrl_num_attention_heads[index],
+                            ctrl_out_channels // ctrl_num_attention_heads[index],
+                            in_channels=ctrl_out_channels,
+                            num_layers=transformer_layers_per_block[i],
+                            cross_attention_dim=cross_attention_dim,
+                            use_linear_projection=use_linear_projection,
+                            upcast_attention=upcast_attention,
+                            norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups[index]),
+                        )
+                    )
+                else:
+                    ctrl_attentions.append(None)
+
+                # After the resnet/attention application, information is added from control to base
+                # Addition requires change in number of channels
+                ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+
+            if add_downsample:
+                # Before the downsampler application, information is concatted from base to control
+                # Concat doesn't require change in number of channels
+                base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
+                ctrl_downsampler = Downsample2D(
+                    ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
+                )
+                # After the downsampler application, information is added from control to base
+                # Addition requires change in number of channels
+                ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+            else:
+                ctrl_downsampler = None
+            ctrl_downsamplers_list.append(ctrl_downsampler)
+            base_to_ctrl_list.append(base_to_ctrl)
+            ctrl_resnets_list.append(ctrl_resnets)
+            ctrl_attentions_list.append(ctrl_attentions)
+            ctrl_to_base_list.append(ctrl_to_base)
+
+        for i in range(num_layers):
+            cur_in_channels = base_in_channels if i == 0 else base_out_channels
+            base_resnets.append(
+                ResnetBlock2D(
+                    in_channels=cur_in_channels,
+                    out_channels=base_out_channels,
+                    temb_channels=temb_channels,
+                    groups=norm_num_groups,
+                )
+            )
+
+            if has_crossattn:
+                base_attentions.append(
+                    Transformer2DModel(
+                        base_num_attention_heads,
+                        base_out_channels // base_num_attention_heads,
+                        in_channels=base_out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        norm_num_groups=norm_num_groups,
+                    )
+                )
+            else:
+                base_attentions.append(None)
+
+        if add_downsample:
+            self.base_downsampler = Downsample2D(
+                base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
+            )
+        else:
+            self.base_downsampler = None
+
+        self.base_resnets = nn.ModuleList(base_resnets)
+        self.base_attentions = nn.ModuleList(base_attentions)
+
+        self.ctrl_resnets_list = nn.ModuleList([nn.ModuleList(sublist) for sublist in ctrl_resnets_list])
+        self.ctrl_attentions_list = nn.ModuleList([nn.ModuleList(sublist) for sublist in ctrl_attentions_list])
+
+        self.base_to_ctrl = nn.ModuleList([nn.ModuleList(sublist) for sublist in base_to_ctrl_list])
+        self.ctrl_to_base = nn.ModuleList([nn.ModuleList(sublist) for sublist in ctrl_to_base_list])
+
+        self.ctrl_downsamplers = nn.ModuleList(ctrl_downsamplers_list)
+
+        self.gradient_checkpointing = False
+
+    @classmethod
+    def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblocks: List[DownBlockControlNetXSAdapter]):
+        # get params
+        def get_first_cross_attention(block):
+            # attn1 is the self-attention block, attn2 the cross-attention block
+            return block.attentions[0].transformer_blocks[0].attn2
+
+        # parameters for the unet base_downblock
+        base_in_channels = base_downblock.resnets[0].in_channels
+        base_out_channels = base_downblock.resnets[0].out_channels
+        temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
+        num_groups = base_downblock.resnets[0].norm1.num_groups
+        # parameters for the controlnet downblocks
+        ctrl_in_channels_list = []
+        ctrl_out_channels_list = []
+        ctrl_num_groups_list = []
+
+        for i in range(len(ctrl_downblocks)):
+            ctrl_downblock = ctrl_downblocks[i]
+            ctrl_in_channels = ctrl_downblock.resnets[0].in_channels - base_in_channels  # base channels are concatted to ctrl channels in init
+            ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
+            ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
+
+            ctrl_in_channels_list.append(ctrl_in_channels)
+            ctrl_out_channels_list.append(ctrl_out_channels)
+            ctrl_num_groups_list.append(ctrl_num_groups)
+
+        if hasattr(base_downblock, "attentions"):
+            has_crossattn = True
+            transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
+            base_num_attention_heads = get_first_cross_attention(base_downblock).heads
+
+            ctrl_num_attention_heads = []
+            for i in range(len(ctrl_downblocks)):
+                ctrl_num_attention_heads.append(get_first_cross_attention(ctrl_downblocks[i]).heads)
+
+            cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
+            upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+            use_linear_projection = base_downblock.attentions[0].use_linear_projection
+        else:
+            has_crossattn = False
+            transformer_layers_per_block = None
+            base_num_attention_heads = None
+            ctrl_num_attention_heads = None
+            cross_attention_dim = None
+            upcast_attention = None
+            use_linear_projection = None
+
+        add_downsample = base_downblock.downsamplers is not None
+
+        model = cls(
+            # parameters for unet
+            base_in_channels=base_in_channels,
+            base_out_channels=base_out_channels,
+            temb_channels=temb_channels,
+            norm_num_groups=num_groups,
+            has_crossattn=has_crossattn,
+            transformer_layers_per_block=transformer_layers_per_block,
+            base_num_attention_heads=base_num_attention_heads,
+            cross_attention_dim=cross_attention_dim,
+            add_downsample=add_downsample,
+            upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
+            # parameters for controlnets
+            ctrl_in_channels_list=ctrl_in_channels_list,
+            ctrl_out_channels_list=ctrl_out_channels_list,
+            ctrl_max_norm_num_groups=ctrl_num_groups_list,
+            ctrl_num_attention_heads=ctrl_num_attention_heads,
+        )
+
+        # # load weights
+        # load resnets
+        model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
+        for index in range(len(ctrl_downblocks)):
+            model.ctrl_resnets_list[index].load_state_dict(ctrl_downblocks[index].resnets.state_dict())
+        # load attention blocks
+        if has_crossattn:
+            model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
+            for index in range(len(ctrl_downblocks)):
+                model.ctrl_attentions_list[index].load_state_dict(ctrl_downblocks[index].attentions.state_dict())
+        # load downsamplers
+        if add_downsample:
+            model.base_downsampler.load_state_dict(base_downblock.downsamplers[0].state_dict())
+            for index in range(len(ctrl_downblocks)):
+                model.ctrl_downsamplers[index].load_state_dict(ctrl_downblocks[index].downsamplers.state_dict())
+        # load base_to_ctrl and ctrl_to_base zero convolutions
+        for index in range(len(ctrl_downblocks)):
+            model.base_to_ctrl[index].load_state_dict(ctrl_downblocks[index].base_to_ctrl.state_dict())
+            model.ctrl_to_base[index].load_state_dict(ctrl_downblocks[index].ctrl_to_base.state_dict())
+        return model
+
+    def freeze_base_params(self) -> None:
+        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
+        unfrozen for fine tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze base part
+        base_parts = [self.base_resnets]
+        if isinstance(self.base_attentions, nn.ModuleList):  # attentions can be a list of Nones
+            base_parts.append(self.base_attentions)
+        if self.base_downsampler is not None:
+            base_parts.append(self.base_downsampler)
+        for part in base_parts:
+            for param in part.parameters():
+                param.requires_grad = False
+
+    def forward(
+        self,
+        hidden_states_base: Tensor,  # latent sample
+        temb: Tensor,
+        encoder_hidden_states: Optional[Tensor] = None,  # text embedding
+        hidden_states_ctrl: Optional[List[Tensor]] = None,  # ctrl embeddings
+        conditioning_scale: Optional[List[float]] = None,
+        attention_mask: Optional[Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        apply_control: bool = True,
+    ) -> Tuple[Tensor, List[Tensor], Tuple[Tensor, ...], Tuple[List[Tensor], ...]]:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        h_base = hidden_states_base
+        h_ctrl_list = list(hidden_states_ctrl)  # copy so the caller's list isn't mutated
+        base_output_states = ()
+        ctrl_output_states = ()
+
+        base_blocks = list(zip(self.base_resnets, self.base_attentions))
+        # ctrl_blocks_list[index] handles hidden_states_ctrl[index]
+        ctrl_blocks_list = []
+
+        for index in range(len(self.ctrl_resnets_list)):
+            # both curr_ctrl_resnet and curr_ctrl_attention are nn.ModuleList
+            curr_ctrl_resnet = self.ctrl_resnets_list[index]
+            curr_ctrl_attention = self.ctrl_attentions_list[index]
+            ctrl_blocks = list(zip(curr_ctrl_resnet, curr_ctrl_attention))
+            ctrl_blocks_list.append(ctrl_blocks)
+
+        for layer_index in range(len(base_blocks)):
+            b_res, b_attn = base_blocks[layer_index]
+
+            # concat base -> ctrl
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    b2c = self.base_to_ctrl[index][layer_index]
+                    h_ctrl_list[index] = torch.concat([h_ctrl_list[index], b2c(h_base)], dim=1)
+
+            # apply base subblock
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
+            else:
+                h_base = b_res(h_base, temb)
+
+            if b_attn is not None:
+                h_base = b_attn(
+                    h_base,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+
+            # apply ctrl subblock
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    c_res, c_attn = ctrl_blocks_list[index][layer_index]
+                    if torch.is_grad_enabled() and self.gradient_checkpointing:
+                        h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], temb)
+                    else:
+                        h_ctrl_list[index] = c_res(h_ctrl_list[index], temb)
+                    if c_attn is not None:
+                        h_ctrl_list[index] = c_attn(
+                            h_ctrl_list[index],
+                            encoder_hidden_states=encoder_hidden_states,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            attention_mask=attention_mask,
+                            encoder_attention_mask=encoder_attention_mask,
+                            return_dict=False,
+                        )[0]
+
+            # add ctrl -> base
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    c2b = self.ctrl_to_base[index][layer_index]
+                    h_base = h_base + c2b(h_ctrl_list[index]) * conditioning_scale[index]
+            # store references (not detached copies) so gradients can flow through the skips
+            base_output_states = base_output_states + (h_base,)
+            ctrl_output_states = ctrl_output_states + (list(h_ctrl_list),)
+
+        if self.base_downsampler is not None:
+            # concat base -> ctrl
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    b2c = self.base_to_ctrl[index][-1]
+                    h_ctrl_list[index] = torch.cat([h_ctrl_list[index], b2c(h_base)], dim=1)
+
+            # apply base subblock
+            h_base = self.base_downsampler(h_base)
+
+            # apply ctrl subblock
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    h_ctrl_list[index] = self.ctrl_downsamplers[index](h_ctrl_list[index])
+
+            # add ctrl -> base
+            if apply_control:
+                for index in range(len(h_ctrl_list)):
+                    c2b = self.ctrl_to_base[index][-1]
+                    h_base = h_base + c2b(h_ctrl_list[index]) * conditioning_scale[index]
+            base_output_states = base_output_states + (h_base,)
+            ctrl_output_states = ctrl_output_states + (list(h_ctrl_list),)
+
+        return h_base, h_ctrl_list, base_output_states, ctrl_output_states
+
+
+class MultiControlNetXSCrossAttnMidBlock2D(nn.Module):
+    def __init__(
+        self,
+        # parameters for unet
+        base_channels: int,
+        transformer_layers_per_block: int = 1,
+        temb_channels: Optional[int] = None,
+        norm_num_groups: int = 32,
+        base_num_attention_heads: Optional[int] = 1,
+        cross_attention_dim: Optional[int] = 1024,
+        upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
+        # parameters for controlnets
+        ctrl_channels_list: Optional[List[int]] = None,
+        ctrl_max_norm_num_groups_list: Optional[List[int]] = None,
+        ctrl_num_attention_heads_list: Optional[List[int]] = None,
+    ):
+        super().__init__()
+
+        # Before the midblock application, information is concatted from base to control.
+        # Concat doesn't require change in number of channels
+        self.base_to_ctrl = nn.ModuleList()
+        for i in range(len(ctrl_channels_list)):
+            self.base_to_ctrl.append(make_zero_conv(base_channels, base_channels))
+
+        self.base_midblock = UNetMidBlock2DCrossAttn(
+            transformer_layers_per_block=transformer_layers_per_block,
+            in_channels=base_channels,
+            temb_channels=temb_channels,
+            resnet_groups=norm_num_groups,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=base_num_attention_heads,
+            use_linear_projection=use_linear_projection,
+            upcast_attention=upcast_attention,
+        )
+
+        self.ctrl_midblocks = nn.ModuleList()
+        for i in range(len(ctrl_channels_list)):
+            ctrl_channels = ctrl_channels_list[i]
+            ctrl_max_norm_num_groups = ctrl_max_norm_num_groups_list[i]
+            ctrl_num_attention_heads = ctrl_num_attention_heads_list[i]
+
+            ctrl_midblock = UNetMidBlock2DCrossAttn(
+                transformer_layers_per_block=transformer_layers_per_block,
+                in_channels=ctrl_channels + base_channels,
+                out_channels=ctrl_channels,
+                temb_channels=temb_channels,
+                # number of norm groups must divide both in_channels and out_channels
+                resnet_groups=find_largest_factor(
+                    gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
+                ),
+                cross_attention_dim=cross_attention_dim,
+                num_attention_heads=ctrl_num_attention_heads,
+                use_linear_projection=use_linear_projection,
+                upcast_attention=upcast_attention,
+            )
+
+            self.ctrl_midblocks.append(ctrl_midblock)
+
+        # After the midblock application, information is added from control to base
+        # Addition requires change in number of channels
+        self.ctrl_to_base = nn.ModuleList()
+        for i in range(len(ctrl_channels_list)):
+            ctrl_channels = ctrl_channels_list[i]
+            self.ctrl_to_base.append(make_zero_conv(ctrl_channels, base_channels))
+
+        self.gradient_checkpointing = False
+
+    @classmethod
+    def from_modules(
+        cls,
+        base_midblock: UNetMidBlock2DCrossAttn,
+        ctrl_midblock_list: List[MidBlockControlNetXSAdapter],
+    ):
+        # get params
+        def get_first_cross_attention(midblock):
+            return midblock.attentions[0].transformer_blocks[0].attn2
+
+        # parameters from base_midblock
+        base_channels = base_midblock.in_channels
+        transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
+        temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
+        num_groups = base_midblock.resnets[0].norm1.num_groups
+        base_num_attention_heads = get_first_cross_attention(base_midblock).heads
+        cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
+        upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+        use_linear_projection = base_midblock.attentions[0].use_linear_projection
+
+        ctrl_channels_list = []
+        ctrl_num_groups_list = []
+        ctrl_num_attention_heads_list = []
+
+        for adapter_midblock in ctrl_midblock_list:
+            ctrl_to_base = adapter_midblock.ctrl_to_base
+            ctrl_midblock = adapter_midblock.midblock
+
+            ctrl_channels_list.append(ctrl_to_base.in_channels)
+            ctrl_num_groups_list.append(ctrl_midblock.resnets[0].norm1.num_groups)
+            ctrl_num_attention_heads_list.append(get_first_cross_attention(ctrl_midblock).heads)
+
+        # create model
+        model = cls(
+            base_channels=base_channels,
+            transformer_layers_per_block=transformer_layers_per_block,
+            temb_channels=temb_channels,
+            norm_num_groups=num_groups,
+            base_num_attention_heads=base_num_attention_heads,
+            cross_attention_dim=cross_attention_dim,
+            upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
+            ctrl_channels_list=ctrl_channels_list,
+            ctrl_max_norm_num_groups_list=ctrl_num_groups_list,
+            ctrl_num_attention_heads_list=ctrl_num_attention_heads_list,
+        )
+
+        # load weights
+        for i in range(len(ctrl_midblock_list)):
+            ctrl_midblock = ctrl_midblock_list[i]
+            model.base_to_ctrl[i].load_state_dict(ctrl_midblock.base_to_ctrl.state_dict())
+            model.ctrl_midblocks[i].load_state_dict(ctrl_midblock.midblock.state_dict())
+            model.ctrl_to_base[i].load_state_dict(ctrl_midblock.ctrl_to_base.state_dict())
+        model.base_midblock.load_state_dict(base_midblock.state_dict())
+        return model
+
+    def freeze_base_params(self) -> None:
+        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
+        unfrozen for fine tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze base part
+        for param in self.base_midblock.parameters():
+            param.requires_grad = False
+
+    def forward(
+        self,
+        hidden_states_base: Tensor,
+        temb: Tensor,
+        encoder_hidden_states: Tensor,
+        hidden_states_ctrl: List[Tensor],
+        conditioning_scale: List[float],
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_mask: Optional[Tensor] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        apply_control: bool = True,
+    ) -> Tuple[Tensor, List[Tensor]]:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        h_base = hidden_states_base
+        h_ctrl_list = list(hidden_states_ctrl)  # copy so the caller's list isn't mutated
+
+        joint_args = {
+            "temb": temb,
+            "encoder_hidden_states": encoder_hidden_states,
+            "attention_mask": attention_mask,
+            "cross_attention_kwargs": cross_attention_kwargs,
+            "encoder_attention_mask": encoder_attention_mask,
+        }
+
+        # concat base -> ctrl
+        if apply_control:
+            for i in range(len(h_ctrl_list)):
+                b2c = self.base_to_ctrl[i]
+                h_ctrl_list[i] = torch.cat([h_ctrl_list[i], b2c(h_base)], dim=1)
+        # apply base mid block
+        h_base = self.base_midblock(h_base, **joint_args)
+        if apply_control:
+            for i in range(len(h_ctrl_list)):
+                h_ctrl_list[i] = self.ctrl_midblocks[i](h_ctrl_list[i], **joint_args)  # apply ctrl mid block
+                c2b = self.ctrl_to_base[i]
+                h_base = h_base + c2b(h_ctrl_list[i]) * conditioning_scale[i]  # add ctrl -> base
+
+        return h_base, h_ctrl_list
+
+
+class MultiControlNetXSCrossAttnUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        # parameters for unet
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        norm_num_groups: int = 32,
+        resolution_idx: Optional[int] = None,
+        has_crossattn=True,
+        transformer_layers_per_block: int = 1,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1024,
+        add_upsample: bool = True,
+        upcast_attention: bool = False,
+        use_linear_projection: Optional[bool] = True,
+        # parameters for controlnets
+        ctrl_skip_channels_list: Optional[List[List[int]]] = None,
+    ):
+        super().__init__()
+
+        resnets = []
+        attentions = []
+        layered_ctrl_to_base = []
+
+        num_layers = 3  # SD and SDXL both use three layers per up block
+        self.has_cross_attention = has_crossattn
+        self.num_attention_heads = num_attention_heads
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            ctrl_to_base = []
+            for j in range(len(ctrl_skip_channels_list)):
+                # the zero conv for skip connection i of controlnet j
+                current_ctrl_skip_channels = ctrl_skip_channels_list[j]
+                ctrl_to_base.append(make_zero_conv(current_ctrl_skip_channels[i], resnet_in_channels))
+            layered_ctrl_to_base.append(ctrl_to_base)
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    groups=norm_num_groups,
+                )
+            )
+
+            if has_crossattn:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block[i],
+                        cross_attention_dim=cross_attention_dim,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        norm_num_groups=norm_num_groups,
+                    )
+                )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
+        self.layered_ctrl_to_base = nn.ModuleList([nn.ModuleList(sublist) for sublist in layered_ctrl_to_base])
+
+        if add_upsample:
+            self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def freeze_base_params(self) -> None:
+        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else
+        unfrozen for fine tuning."""
+        # Unfreeze everything
+        for param in self.parameters():
+            param.requires_grad = True
+
+        # Freeze base part
+        base_parts = [self.resnets]
+        if isinstance(self.attentions, nn.ModuleList):  # attentions can be a list of Nones
+            base_parts.append(self.attentions)
+        if self.upsamplers is not None:
+            base_parts.append(self.upsamplers)
+        for part in base_parts:
+            for param in part.parameters():
+                param.requires_grad = False
+
+    @classmethod
+    def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock_list: List[UpBlockControlNetXSAdapter]):
+        # get params
+        def get_first_cross_attention(block):
+            return block.attentions[0].transformer_blocks[0].attn2
+
+        # parameters for unet
+        out_channels = base_upblock.resnets[0].out_channels
+        in_channels = base_upblock.resnets[-1].in_channels - out_channels
+        prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
+        temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
+        num_groups = base_upblock.resnets[0].norm1.num_groups
+        resolution_idx = base_upblock.resolution_idx
+
+        # parameters for controlnets
+        ctrl_skip_channels_list = []
+        for ctrl_upblock in ctrl_upblock_list:
+            ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
+            channel_skipped = [c.in_channels for c in ctrl_to_base_skip_connections]
+            ctrl_skip_channels_list.append(channel_skipped)
+
+        if hasattr(base_upblock, "attentions"):
+            has_crossattn = True
+            transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
+            num_attention_heads = get_first_cross_attention(base_upblock).heads
+            cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
+            upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+            use_linear_projection = base_upblock.attentions[0].use_linear_projection
+        else:
+            has_crossattn = False
+            transformer_layers_per_block = None
+            num_attention_heads = None
+            cross_attention_dim = None
+            upcast_attention = None
+            use_linear_projection = None
+        add_upsample = base_upblock.upsamplers is not None
+
+        # create model
+        model = cls(
+            # parameters for unet
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channels,
+            temb_channels=temb_channels,
+            norm_num_groups=num_groups,
+            resolution_idx=resolution_idx,
+            has_crossattn=has_crossattn,
+            transformer_layers_per_block=transformer_layers_per_block,
+            num_attention_heads=num_attention_heads,
+            cross_attention_dim=cross_attention_dim,
+            add_upsample=add_upsample,
+            upcast_attention=upcast_attention,
+            use_linear_projection=use_linear_projection,
+            # parameters for ctrl skip channels
+            ctrl_skip_channels_list=ctrl_skip_channels_list,
+        )
+
+        # load weights
+        model.resnets.load_state_dict(base_upblock.resnets.state_dict())
+        if has_crossattn:
+            model.attentions.load_state_dict(base_upblock.attentions.state_dict())
+        if add_upsample:
+            model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
+        for i in range(len(ctrl_upblock_list)):
+            ctrl_upblock = ctrl_upblock_list[i]
+            ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
+            for j in range(len(ctrl_to_base_skip_connections)):
+                # layered_ctrl_to_base is indexed [layer][controlnet]
+                model.layered_ctrl_to_base[j][i].load_state_dict(ctrl_to_base_skip_connections[j].state_dict())
+        return model
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        res_hidden_states_tuple_base: Tuple[Tensor, ...],
+        res_hidden_states_tuple_ctrl: Tuple[List[Tensor], ...],
+        temb: Tensor,
+        encoder_hidden_states: Optional[Tensor] = None,
+        conditioning_scale: Optional[List[float]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_mask: Optional[Tensor] = None,
+        upsample_size: Optional[int] = None,
+        encoder_attention_mask: Optional[Tensor] = None,
+        apply_control: bool = True,
+    ) -> Tensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                return apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_h_base,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+            else:
+                return hidden_states, res_h_base
+
+        base_blocks = list(zip(self.resnets, self.attentions))
+
+        reversed_h_bases = res_hidden_states_tuple_base[::-1]
+        reversed_h_ctrl_list = res_hidden_states_tuple_ctrl[::-1]
+
+        for layer_index in range(len(base_blocks)):
+            resnet, attn = base_blocks[layer_index]
+            ctrl_to_base_list = self.layered_ctrl_to_base[layer_index]
+            h_ctrls = reversed_h_ctrl_list[layer_index]
+
+            # add ctrl -> base
+            if apply_control:
+                for index in range(len(ctrl_to_base_list)):
+                    c2b = ctrl_to_base_list[index]
+                    h_ctrl = h_ctrls[index]
+                    hidden_states += c2b(h_ctrl) * conditioning_scale[index]
+
+            res_h_base = reversed_h_bases[layer_index]
+            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
+            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
+
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
+            if attn is not None:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+
+        if self.upsamplers is not None:
+            hidden_states = self.upsamplers(hidden_states, upsample_size)
+        return hidden_states


 def make_zero_conv(in_channels, out_channels=None):
     return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
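Before the pipeline changes below, a sketch of how the new fusion entry point is meant to be used. The checkpoint ids are illustrative assumptions (any SDXL UNet plus SDXL-sized ControlNet-XS adapters should work); only `from_unet` itself is defined by this patch:

```python
import torch

from diffusers import ControlNetXSAdapter, UNet2DConditionModel
from diffusers.models import UNetMultiControlNetXSModel

# Illustrative checkpoints, not part of this patch.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
adapter_a = ControlNetXSAdapter.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
adapter_b = ControlNetXSAdapter.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16)  # hypothetical id

# `from_unet` copies the UNet weights into the fused model and wires one pair of
# base->ctrl / ctrl->base zero convolutions per adapter and per subblock.
model = UNetMultiControlNetXSModel.from_unet(unet, [adapter_a, adapter_b])
```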
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index c10931a0f44a..c98254818d37 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -31,7 +31,7 @@
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel, UNetMultiControlNetXSModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
@@ -186,7 +186,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
-        controlnet: ControlNetXSAdapter,
+        controlnet: Union[ControlNetXSAdapter, List[ControlNetXSAdapter]],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
@@ -195,7 +195,10 @@ def __init__(
         super().__init__()

         if isinstance(unet, UNet2DConditionModel):
-            unet = UNetControlNetXSModel.from_unet(unet, controlnet)
+            if isinstance(controlnet, ControlNetXSAdapter):
+                unet = UNetControlNetXSModel.from_unet(unet, controlnet)
+            else:
+                unet = UNetMultiControlNetXSModel.from_unet(unet, controlnet)

         self.register_modules(
             vae=vae,
@@ -556,11 +559,12 @@ def check_inputs(
             )
         if (
             isinstance(self.unet, UNetControlNetXSModel)
+            or isinstance(self.unet, UNetMultiControlNetXSModel)
             or is_compiled
             and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
         ):
             self.check_image(image, prompt, prompt_embeds)
-            if not isinstance(controlnet_conditioning_scale, float):
+            if not isinstance(controlnet_conditioning_scale, float) and isinstance(self.unet, UNetControlNetXSModel):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
         else:
             assert False
@@ -584,6 +588,9 @@ def check_image(self, image, prompt, prompt_embeds):
         image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
         image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

+        unet_is_with_single_controlnet = isinstance(self.unet, UNetControlNetXSModel)
+        unet_is_with_multiple_controlnets = isinstance(self.unet, UNetMultiControlNetXSModel)
+
         if (
             not image_is_pil
             and not image_is_tensor
@@ -608,11 +615,16 @@ def check_image(self, image, prompt, prompt_embeds):
         elif prompt_embeds is not None:
             prompt_batch_size = prompt_embeds.shape[0]

-        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+        if unet_is_with_single_controlnet and image_batch_size != 1 and image_batch_size != prompt_batch_size:
             raise ValueError(
                 f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
             )
+
+        if unet_is_with_multiple_controlnets and image_batch_size != 1 and image_batch_size != self.unet.num_of_controlnets():
+            raise ValueError(
+                f"For multiple controlnets, one conditioning image per controlnet is expected. image batch size: {image_batch_size}, number of controlnets: {self.unet.num_of_controlnets()}"
+            )
+
     def prepare_image(
         self,
         image,
@@ -642,6 +654,37 @@ def prepare_image(

         return image

+    def prepare_image_list(
+        self,
+        images,
+        width,
+        height,
+        batch_size,
+        num_images_per_prompt,
+        device,
+        dtype,
+        do_classifier_free_guidance=False,
+    ):
+        image_list = []
+        for image in images:
+            image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+            image_list.append(image)
+        image_batch_size = len(image_list)
+
+        if image_batch_size == 1:
+            repeat_by = batch_size
+        else:
+            # one conditioning image per controlnet
+            repeat_by = num_images_per_prompt
+
+        for index in range(len(image_list)):
+            image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
+            image_list[index] = image_list[index].to(device=device, dtype=dtype)
+            if do_classifier_free_guidance:
+                image_list[index] = torch.cat([image_list[index]] * 2)
+        return image_list
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (
@@ -971,6 +1014,18 @@ def __call__(
                 do_classifier_free_guidance=do_classifier_free_guidance,
             )
             height, width = image.shape[-2:]
+        elif isinstance(unet, UNetMultiControlNetXSModel):
+            image = self.prepare_image_list(
+                images=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=unet.dtype,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+            )
+            height, width = image[0].shape[-2:]
         else:
             assert False
@@ -1050,6 +1105,8 @@ def __call__(
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

+                latent_model_input = latent_model_input.to(dtype=self.unet.dtype)
+
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                 # predict the noise residual
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 22efaccec140..fef32ba8d6fa 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -208,6 +208,10 @@ def register_modules(self, **kwargs):
             # retrieve library
             if module is None or isinstance(module, (tuple, list)) and module[0] is None:
                 register_dict = {name: (None, None)}
+            elif isinstance(module, list) and module[0] is not None:
+                # a list of modules (e.g. several ControlNet-XS adapters) is registered under
+                # the library/class of its first entry; the entries are assumed to be homogeneous
+                library, class_name = _fetch_class_library_tuple(module[0])
+                register_dict = {name: (library, class_name)}
             else:
                 library, class_name = _fetch_class_library_tuple(module)
                 register_dict = {name: (library, class_name)}
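Taken together, the changes let `StableDiffusionXLControlNetXSPipeline` accept a list of adapters plus one conditioning image per adapter. An end-to-end sketch of the intended call pattern; the checkpoint names and image URLs are placeholders, and `controlnet_conditioning_scale` is a list with one entry per adapter:

```python
import torch

from diffusers import ControlNetXSAdapter, StableDiffusionXLControlNetXSPipeline
from diffusers.utils import load_image

# Placeholder checkpoints; substitute real SDXL ControlNet-XS adapters.
controlnets = [
    ControlNetXSAdapter.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16),
    ControlNetXSAdapter.from_pretrained("UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16),  # hypothetical id
]
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnets, torch_dtype=torch.float16
).to("cuda")

# One conditioning image per adapter, in the same order as `controlnets`.
canny_image = load_image("https://example.com/canny.png")  # placeholder URL
depth_image = load_image("https://example.com/depth.png")  # placeholder URL

image = pipe(
    "an astronaut riding a horse on the moon",
    image=[canny_image, depth_image],
    controlnet_conditioning_scale=[0.5, 0.5],
    num_inference_steps=30,
).images[0]
```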