@@ -100,7 +100,6 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
                 Shape: (1, hidden_size)
         """
         prompt = self.prompt
-        device = TorchDevice.choose_torch_device()
 
         text_encoder_info = context.models.load(self.qwen3_encoder.text_encoder)
         tokenizer_info = context.models.load(self.qwen3_encoder.tokenizer)
@@ -109,6 +108,9 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
             (cached_weights, text_encoder) = exit_stack.enter_context(text_encoder_info.model_on_device())
             (_, tokenizer) = exit_stack.enter_context(tokenizer_info.model_on_device())
 
+            # The device can now be defined here, since the text_encoder has been loaded
+            device = text_encoder.device
+
             # Apply LoRA models to the text encoder
             lora_dtype = TorchDevice.choose_bfloat16_safe_dtype(device)
             exit_stack.enter_context(
@@ -157,28 +159,27 @@ def _encode_prompt(self, context: InvocationContext) -> Tuple[torch.Tensor, torch.Tensor]:
                 max_length=self.max_seq_len,
             )
 
-            input_ids = inputs["input_ids"]
-            attention_mask = inputs["attention_mask"]
+            input_ids = inputs["input_ids"].to(device)
+            attention_mask = inputs["attention_mask"].to(device)
 
             # Move to device
             input_ids = input_ids.to(device)
             attention_mask = attention_mask.to(device)
 
             # Forward pass through the model - matching diffusers exactly
+            # Explicitly move inputs to the same device as the text_encoder
             outputs = text_encoder(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 output_hidden_states=True,
                 use_cache=False,
             )
-
             # Validate hidden_states output
             if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
                 raise RuntimeError(
                     "Text encoder did not return hidden_states. "
                     "Ensure output_hidden_states=True is supported by this model."
                 )
-
             num_hidden_layers = len(outputs.hidden_states)
 
             # Extract and stack hidden states - EXACTLY like diffusers:
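
For context, the pattern this commit adopts is: load the text encoder first, derive the device from the loaded module, then move the tokenizer outputs to that device before the forward pass. Below is a minimal standalone sketch of the same flow; the checkpoint name and placement logic are illustrative assumptions, not InvokeAI's actual model manager:

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint; the real invocation loads the encoder through
# InvokeAI's model manager (context.models.load) instead.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
text_encoder = AutoModel.from_pretrained("Qwen/Qwen3-0.6B")
text_encoder.to("cuda" if torch.cuda.is_available() else "cpu")

# Derive the device from the model itself, after it has been placed,
# so the inputs always match wherever the weights actually live.
device = text_encoder.device

inputs = tokenizer("a prompt", return_tensors="pt")
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

outputs = text_encoder(
    input_ids=input_ids,
    attention_mask=attention_mask,
    output_hidden_states=True,
    use_cache=False,
)
assert outputs.hidden_states is not None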