Guiders support for Wan #11211
Changes from 1 commit
```diff
@@ -21,6 +21,7 @@
 from transformers import AutoTokenizer, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
```
```diff
@@ -380,6 +381,7 @@ def __call__(
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        guidance: Optional[GuidanceMixin] = None,
     ):
         r"""
         The call function to the pipeline for generation.
```
```diff
@@ -444,6 +446,10 @@ def __call__(
                 indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
         """
 
+        _raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
+        if guidance is None:
+            guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
+
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
```
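The effect of this default is that the existing `guidance_scale` argument keeps working while callers can opt into an explicit guider object. A minimal usage sketch follows; it assumes the PR branch exposes a `diffusers.guiders` module and uses a public Wan checkpoint name, neither of which is taken from this diff:

```python
# Hypothetical usage sketch for the new `guidance` argument (assumes the PR branch
# exposes diffusers.guiders.ClassifierFreeGuidance; not part of this diff).
import torch
from diffusers import WanPipeline
from diffusers.guiders import ClassifierFreeGuidance

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Old style: a plain guidance_scale now just builds a default ClassifierFreeGuidance guider.
video = pipe(prompt="A cat surfing a wave", guidance_scale=5.0).frames[0]

# New style: pass the guider explicitly (any GuidanceMixin subclass should work).
video = pipe(
    prompt="A cat surfing a wave",
    guidance=ClassifierFreeGuidance(guidance_scale=5.0),
).frames[0]
```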
|
|
```diff
@@ -519,37 +525,38 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
 
+        conds = [prompt_embeds, negative_prompt_embeds]
+        prompt_embeds, negative_prompt_embeds = [[c] for c in conds]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
+                self._current_timestep = t
                 if self.interrupt:
                     continue
 
-                self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                cc.mark_state("cond")
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    cc.mark_state("uncond")
-                    noise_uncond = self.transformer(
-                        hidden_states=latent_model_input,
+                guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+                guidance.prepare_models(self.transformer)
+                latents, prompt_embeds = guidance.prepare_inputs(
+                    latents, (prompt_embeds[0], negative_prompt_embeds[0])
+                )
+
+                for batch_index, (latent, condition) in enumerate(zip(latents, prompt_embeds)):
+                    cc.mark_state(f"batch_{batch_index}")
+                    latent = latent.to(transformer_dtype)
+                    timestep = t.expand(latent.shape[0])
+                    noise_pred = self.transformer(
+                        hidden_states=latent,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=condition,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    guidance.prepare_outputs(noise_pred)
```
Review comment:

So first, let's test very thoroughly for any potential performance difference from this change (SDXL alone is enough for now; vary num_images_per_prompt, machine type, etc.).

Second, code-wise I think it's less confusing with something like this, i.e. explicitly pass the model as input (otherwise it's unclear that there is a model call there), and a function should always return an output if it modifies its input:

    noise_pred = guider.prepare_cond(self.transformer, ...)
    outputs = guider.prepare_guider_output(self.transformer, ...)
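One possible reading of that suggestion, sketched below. Only the names `prepare_cond` and `prepare_guider_output` come from the comment; the argument lists and the list accumulation are assumptions for illustration, and the snippet is a fragment meant to sit inside the pipeline loop:

```python
# Hypothetical shape of the loop under the reviewer's suggestion (not actual PR code):
# the guider is handed the model explicitly and returns what it computes, instead of
# stashing predictions internally via prepare_outputs().
noise_preds = []
for latent, condition in zip(latents, prompt_embeds):
    # prepare_cond would run the model call itself, making the forward pass visible here.
    noise_preds.append(
        guider.prepare_cond(
            self.transformer,
            hidden_states=latent.to(transformer_dtype),
            timestep=t.expand(latent.shape[0]),
            encoder_hidden_states=condition,
            attention_kwargs=attention_kwargs,
        )
    )
# prepare_guider_output would then combine the collected predictions and return the result.
noise_pred = guider.prepare_guider_output(self.transformer, noise_preds)
```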
The hunk continues after the comment:

```diff
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                outputs = guidance.outputs
+                noise_pred = guidance(**outputs)
+                latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
+                guidance.cleanup_models(self.transformer)
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
```
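To make the new control flow easier to follow: the pipeline no longer hard-codes CFG. It asks the guider for a list of (latent, condition) pairs, runs one transformer forward per pair, hands each prediction back via `prepare_outputs`, and finally lets the guider combine them, so a guider can vary the number of forward passes and the combination rule without the pipeline loop changing. Below is a minimal, self-contained sketch of an object satisfying that protocol; it is an illustration only, and the PR's actual `GuidanceMixin` / `ClassifierFreeGuidance` will differ in details:

```python
from typing import Dict, List, Tuple

import torch


class SketchCFGGuider:
    """Toy guider implementing the call protocol used by the loop above (illustration only)."""

    def __init__(self, guidance_scale: float = 5.0):
        self.guidance_scale = guidance_scale
        self._preds: List[torch.Tensor] = []

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.Tensor) -> None:
        # Remember where we are in the schedule and reset per-step storage.
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._preds = []

    def prepare_models(self, model) -> None:
        # Hook for per-step model setup (e.g. toggling adapters); no-op here.
        pass

    def prepare_inputs(
        self, latents: torch.Tensor, conditions: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        # One (latent, condition) pair per forward pass: conditional first, then unconditional.
        cond, uncond = conditions
        return [latents, latents], [cond, uncond]

    def prepare_outputs(self, noise_pred: torch.Tensor) -> None:
        # Collect one prediction per forward pass, in the order produced by prepare_inputs.
        self._preds.append(noise_pred)

    @property
    def outputs(self) -> Dict[str, torch.Tensor]:
        return {"pred_cond": self._preds[0], "pred_uncond": self._preds[1]}

    def __call__(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor) -> torch.Tensor:
        # Standard classifier-free guidance combination.
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)

    def cleanup_models(self, model) -> None:
        # Undo whatever prepare_models changed; no-op here.
        pass
```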
|
|
```diff
@@ -558,8 +565,10 @@ def __call__(
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
+                    negative_prompt_embeds = [
+                        callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
+                    ]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
```
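Because the loop now holds the embeddings as one-element lists, a `callback_on_step_end` that overrides `prompt_embeds` or `negative_prompt_embeds` should keep returning plain tensors; as the hunk shows, the pipeline re-wraps them into lists itself. A small callback sketch reusing `pipe` from the earlier snippet (the callback is hypothetical and only reads `latents`, which is still exposed as a plain tensor at this point):

```python
# Hypothetical callback sketch (not from the PR): logs the latent norm after each step.
def log_latent_norm(pipe, step_index, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step_index}: latent L2 norm = {latents.float().norm().item():.2f}")
    return callback_kwargs

video = pipe(
    prompt="A cat surfing a wave",
    callback_on_step_end=log_latent_norm,
    callback_on_step_end_tensor_inputs=["latents"],
).frames[0]
```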
|
|
||