File tree Expand file tree Collapse file tree 1 file changed +9
-9
lines changed
src/diffusers/models/controlnets Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -298,15 +298,6 @@ def forward(
298298 )
299299 encoder_hidden_states = self .context_embedder (encoder_hidden_states )
300300
301- if self .union :
302- # union mode
303- if controlnet_mode is None :
304- raise ValueError ("`controlnet_mode` cannot be `None` when applying ControlNet-Union" )
305- # union mode emb
306- controlnet_mode_emb = self .controlnet_mode_embedder (controlnet_mode )
307- encoder_hidden_states = torch .cat ([controlnet_mode_emb , encoder_hidden_states ], dim = 1 )
308- txt_ids = torch .cat ([txt_ids [:1 ], txt_ids ], dim = 0 )
309-
310301 if txt_ids .ndim == 3 :
311302 logger .warning (
312303 "Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -320,6 +311,15 @@ def forward(
320311 )
321312 img_ids = img_ids [0 ]
322313
314+ if self .union :
315+ # union mode
316+ if controlnet_mode is None :
317+ raise ValueError ("`controlnet_mode` cannot be `None` when applying ControlNet-Union" )
318+ # union mode emb
319+ controlnet_mode_emb = self .controlnet_mode_embedder (controlnet_mode )
320+ encoder_hidden_states = torch .cat ([controlnet_mode_emb , encoder_hidden_states ], dim = 1 )
321+ txt_ids = torch .cat ([txt_ids [:1 ], txt_ids ], dim = 0 )
322+
323323 ids = torch .cat ((txt_ids , img_ids ), dim = 0 )
324324 image_rotary_emb = self .pos_embed (ids )
325325
You can’t perform that action at this time.
0 commit comments