# Re-register the "ella_encoder" model folder so its file search accepts all
# supported PyTorch checkpoint extensions (presumably ComfyUI's folder_paths
# API — the extension set previously registered is discarded).
# NOTE(review): assumes folder_paths.folder_names_and_paths["ella_encoder"]
# already holds a (paths, extensions) pair at this point — confirm earlier in file.
current_paths, _ = folder_paths.folder_names_and_paths["ella_encoder"]
folder_paths.folder_names_and_paths["ella_encoder"] = (current_paths, folder_paths.supported_pt_extensions)
3131
32-
32+ # === device/dtype alignment helpers ===
33+ def _infer_float_dtype_from_embeds (d : dict ):
34+ import torch
35+ for v in d .values ():
36+ if torch .is_tensor (v ) and v .is_floating_point ():
37+ return v .dtype
38+ if isinstance (v , (list , tuple )):
39+ for t in v :
40+ if torch .is_tensor (t ) and t .is_floating_point ():
41+ return t .dtype
42+ if isinstance (v , dict ):
43+ dt = _infer_float_dtype_from_embeds (v )
44+ if dt is not None :
45+ return dt
46+ return None
47+
48+ def _align_to_model_device_dtype (x , device , dtype ):
49+ import torch
50+ if x is None :
51+ return None
52+ if torch .is_tensor (x ):
53+ if x .is_floating_point ():
54+ return x .to (device = device , dtype = dtype , non_blocking = True )
55+ return x .to (device = device , non_blocking = True )
56+ if isinstance (x , (list , tuple )):
57+ return type (x )(_align_to_model_device_dtype (xx , device , dtype ) for xx in x )
58+ if isinstance (x , dict ):
59+ return {k : _align_to_model_device_dtype (v , device , dtype ) for k , v in x .items ()}
60+ return x
61+
62+ # === /helpers ===
3363def ella_encode (ella : ELLA , timesteps : torch .Tensor , embeds : dict ):
3464 num_steps = len (timesteps ) - 1
3565 # print(f"creating ELLA conds for {num_steps} timesteps")
@@ -39,7 +69,14 @@ def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
3969 start = i / num_steps # Start percentage is calculated based on the index
4070 end = (i + 1 ) / num_steps # End percentage is calculated based on the next index
4171
42- cond_ella = ella (timestep , ** embeds )
72+ # cond_ella = ella(timestep, **embeds)
73+ # align dtype/device to ELLA model
74+ device = getattr (ella , "output_device" , timesteps .device )
75+ want_dtype = _infer_float_dtype_from_embeds (embeds ) or torch .float16
76+ _t = timestep .to (device = device , dtype = want_dtype )
77+ _embeds = _align_to_model_device_dtype (embeds , device , want_dtype )
78+
79+ cond_ella = ella (_t , ** _embeds )
4380
4481 cond_ella_dict = {"start_percent" : start , "end_percent" : end }
4582 conds .append ([cond_ella , cond_ella_dict ])
@@ -69,13 +106,24 @@ def __init__(
69106 self .embeds [i ][k ] = CONDCrossAttn (self .embeds [i ][k ])
70107
71108 def process_cond (self , embeds : Dict [str , CONDCrossAttn ], batch_size , ** kwargs ):
72- return {k : v .process_cond (batch_size , self .ella .output_device , ** kwargs ).cond for k , v in embeds .items ()}
109+ # return {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
110+ out = {k : v .process_cond (batch_size , self .ella .output_device , ** kwargs ).cond for k , v in embeds .items ()}
111+ # align floats to a common dtype inferred from outputs (or fallback fp16)
112+ want_dtype = _infer_float_dtype_from_embeds (out ) or torch .float16
113+ return _align_to_model_device_dtype (out , self .ella .output_device , want_dtype )
73114
74115 def prepare_conds (self ):
116+
75117 cond_embeds = self .process_cond (self .embeds [0 ], 1 )
76- cond = self .ella (torch .Tensor ([999 ]), ** cond_embeds )
118+ want_dtype = _infer_float_dtype_from_embeds (cond_embeds ) or torch .float16
119+ t999 = torch .tensor ([999.0 ], device = self .ella .output_device , dtype = want_dtype )
120+ cond = self .ella (t999 , ** cond_embeds )
121+
77122 uncond_embeds = self .process_cond (self .embeds [1 ], 1 )
78- uncond = self .ella (torch .Tensor ([999 ]), ** uncond_embeds )
123+ # same dtype for consistency
124+ t999u = t999
125+ uncond = self .ella (t999u , ** uncond_embeds )
126+
79127 if self .mode == APPLY_MODE_ELLA_ONLY :
80128 return cond , uncond
81129 if "clip_embeds" not in cond_embeds or "clip_embeds" not in uncond_embeds :
@@ -94,22 +142,28 @@ def __call__(self, apply_model, kwargs: dict):
94142 _device = c ["c_crossattn" ].device
95143
96144 time_aware_encoder_hidden_states = []
97- for i in cond_or_uncond :
145+ # get the dtype of the target model from the cond-data of the first group
146+ # (process_cond has already aligned device to self.ella.output_device)
147+ for idx , i in enumerate (cond_or_uncond ):
98148 cond_embeds = self .process_cond (self .embeds [i ], input_x .size (0 ) // len (cond_or_uncond ))
99- h = self .ella (
100- self .model_sampling .timestep (timestep_ [0 ]),
101- ** cond_embeds ,
102- )
103- if self .mode == APPLY_MODE_ELLA_ONLY :
149+ want_dtype = _infer_float_dtype_from_embeds (cond_embeds ) or torch .float16
150+
151+ # timestep from sampler can be on CPU and in fp32 - we will align it
152+ t_model = self .model_sampling .timestep (timestep_ [0 ])
153+ t_model = t_model .to (device = self .ella .output_device , dtype = want_dtype )
154+
155+ h = self .ella (t_model , ** cond_embeds )
156+
157+ if self .mode == APPLY_MODE_ELLA_ONLY or "clip_embeds" not in cond_embeds :
104158 time_aware_encoder_hidden_states .append (h )
105- continue
106- if "clip_embeds" not in cond_embeds :
159+ else :
160+ h = torch . concat ([ h , cond_embeds [ "clip_embeds" ]], dim = 1 )
107161 time_aware_encoder_hidden_states .append (h )
108- continue
109- h = torch .concat ([h , cond_embeds ["clip_embeds" ]], dim = 1 )
110- time_aware_encoder_hidden_states .append (h )
111162
112- c ["c_crossattn" ] = torch .cat (time_aware_encoder_hidden_states , dim = 0 ).to (_device )
163+ # build a batch and move it under the downstream-UNet device
164+ hidden = torch .cat (time_aware_encoder_hidden_states , dim = 0 )
165+ c ["c_crossattn" ] = hidden .to (_device )
166+
113167
114168 return apply_model (input_x , timestep_ , ** c )
115169
0 commit comments