         >>> import torch
         >>> from diffusers.utils import load_image
         >>> from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
         >>> base_model_path = "Qwen/Qwen-Image"
         >>> controlnet_model_path = "InstantX/Qwen-Image-ControlNet-Inpainting"
         >>> controlnet = QwenImageControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.bfloat16)
-        >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16).to("cuda")
-        >>> image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png")
-        >>> mask_image = load_image("https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png")
+        >>> pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
+        ...     base_model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
+        ... ).to("cuda")
+        >>> image = load_image(
+        ...     "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
+        ... )
+        >>> mask_image = load_image(
+        ...     "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
+        ... )
         >>> prompt = "一辆绿色的出租车行驶在路上"
         >>> result = pipe(
         ...     prompt=prompt,
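
The docstring example is cut off at the hunk boundary (the prompt means "a green taxi driving on the road"). For orientation, a hedged sketch of how a complete invocation might look; the `control_image`, `control_mask`, and `controlnet_conditioning_scale` parameter names are assumptions based on the SD3 ControlNet inpainting pipeline this file copies helpers from, not confirmed by this hunk:

```python
# Hedged sketch, not the file's verbatim example; parameter names assumed.
result = pipe(
    prompt=prompt,
    control_image=image,  # original image to be inpainted
    control_mask=mask_image,  # white (1) marks the region to regenerate
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    true_cfg_scale=4.0,  # assumed; QwenImage pipelines expose true_cfg_scale
)
result.images[0].save("inpainted.png")
```
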
@@ -80,6 +87,7 @@ def calculate_shift(
     mu = image_seq_len * m + b
     return mu
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
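
The hunk above ends with the tail of `calculate_shift`. A self-contained sketch of the full linear map for reference; the default constants below are typical diffusers values and are assumed, not read from this diff:

```python
# Hedged reconstruction of calculate_shift (constants assumed).
def calculate_shift_sketch(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)  # slope
    b = base_shift - m * base_seq_len  # intercept
    return image_seq_len * m + b  # mu grows linearly with sequence length

print(calculate_shift_sketch(1024))  # ~0.63 with these constants
```
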
@@ -93,6 +101,7 @@ def retrieve_latents(
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
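
For reference, a minimal usage sketch for `retrieve_latents`; it assumes the function above is in scope and uses `DiagonalGaussianDistribution` to mimic the object `AutoencoderKL.encode` returns (import path may vary across diffusers versions):

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

class FakeEncoderOutput:
    # Stand-in for vae.encode() output; only .latent_dist is accessed.
    def __init__(self, parameters: torch.Tensor):
        self.latent_dist = DiagonalGaussianDistribution(parameters)

params = torch.randn(1, 32, 8, 8)  # mean and logvar stacked along dim 1
latents = retrieve_latents(FakeEncoderOutput(params), sample_mode="argmax")
print(latents.shape)  # torch.Size([1, 16, 8, 8]); "argmax" takes the mode
```
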
@@ -105,6 +114,7 @@ def retrieve_timesteps(
     r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
     Args:
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
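
The two helpers work together: `calculate_shift` produces `mu`, which `retrieve_timesteps` forwards to `scheduler.set_timesteps` via kwargs. A hedged sketch, assuming `retrieve_timesteps` is in scope and a `FlowMatchEulerDiscreteScheduler` configured for dynamic shifting:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
mu = 0.63  # e.g. calculate_shift output for a 1024-token image (assumed value)
timesteps, num_steps = retrieve_timesteps(
    scheduler, num_inference_steps=30, device="cpu", mu=mu
)
print(num_steps, timesteps[:3])
```
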
@@ -154,6 +164,7 @@ def retrieve_timesteps(
 class QwenImageControlNetInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
     r"""
     The QwenImage pipeline for text-to-image generation.
+
     Args:
         transformer ([`QwenImageTransformer2DModel`]):
             Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
@@ -472,7 +483,7 @@ def prepare_image(
             image = torch.cat([image] * 2)
 
         return image
-
+
     # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetPipeline.prepare_image_with_mask
     def prepare_image_with_mask(
         self,
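
The surviving context above shows `prepare_image` ending with `image = torch.cat([image] * 2)`, which doubles the control batch, presumably for classifier-free guidance (the guard condition lies outside this hunk). A tiny illustration:

```python
import torch

image = torch.randn(1, 3, 512, 512)
image = torch.cat([image] * 2)  # (2, 3, 512, 512): one copy per guidance branch
print(image.shape)
```
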
@@ -501,45 +512,47 @@ def prepare_image_with_mask(
             repeat_by = num_images_per_prompt
 
         image = image.repeat_interleave(repeat_by, dim=0)
-        image = image.to(device=device, dtype=dtype) # (bsz, 3, height_ori, width_ori)
+        image = image.to(device=device, dtype=dtype)  # (bsz, 3, height_ori, width_ori)
 
         # Prepare mask
         if isinstance(mask, torch.Tensor):
             pass
         else:
             mask = self.mask_processor.preprocess(mask, height=height, width=width)
         mask = mask.repeat_interleave(repeat_by, dim=0)
-        mask = mask.to(device=device, dtype=dtype) # (bsz, 1, height_ori, width_ori)
+        mask = mask.to(device=device, dtype=dtype)  # (bsz, 1, height_ori, width_ori)
 
         if image.ndim == 4:
             image = image.unsqueeze(2)
-
+
         if mask.ndim == 4:
             mask = mask.unsqueeze(2)
 
         # Get masked image
         masked_image = image.clone()
-        masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1 # (bsz, 3, 1, height_ori, width_ori)
-
+        masked_image[(mask > 0.5).repeat(1, 3, 1, 1, 1)] = -1  # (bsz, 3, 1, height_ori, width_ori)
+
         self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
         latents_mean = (torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1)).to(device)
-        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(device)
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            device
+        )
 
         # Encode to latents
         image_latents = self.vae.encode(masked_image.to(self.vae.dtype)).latent_dist.sample()
-        image_latents = (
-            image_latents - latents_mean
-        ) * latents_std
-        image_latents = image_latents.to(dtype) # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
+        image_latents = (image_latents - latents_mean) * latents_std
+        image_latents = image_latents.to(dtype)  # torch.Size([1, 16, 1, height_ori//8, width_ori//8])
 
         mask = torch.nn.functional.interpolate(
             mask, size=(image_latents.shape[-3], image_latents.shape[-2], image_latents.shape[-1])
         )
-        mask = 1 - mask # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
+        mask = 1 - mask  # torch.Size([1, 1, 1, height_ori//8, width_ori//8])
 
-        control_image = torch.cat([image_latents, mask], dim=1) # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
+        control_image = torch.cat(
+            [image_latents, mask], dim=1
+        )  # torch.Size([1, 16+1, 1, height_ori//8, width_ori//8])
 
-        control_image = control_image.permute(0, 2, 1, 3, 4) # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
+        control_image = control_image.permute(0, 2, 1, 3, 4)  # torch.Size([1, 1, 16+1, height_ori//8, width_ori//8])
 
         # pack
         control_image = self._pack_latents(
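
A standalone shape walk-through of `prepare_image_with_mask` as modified above, with sizes assumed (batch 1, 1024x1024 input, 8x VAE spatial downsampling, `z_dim=16`):

```python
import torch

bsz, h, w = 1, 1024, 1024
image_latents = torch.randn(bsz, 16, 1, h // 8, w // 8)  # encoded masked image
mask = torch.rand(bsz, 1, 1, h, w)  # 1 = region to inpaint
mask = torch.nn.functional.interpolate(mask, size=image_latents.shape[-3:])
mask = 1 - mask  # inverted: 1 = region to keep
control = torch.cat([image_latents, mask], dim=1)  # (1, 17, 1, 128, 128)
control = control.permute(0, 2, 1, 3, 4)  # (1, 1, 17, 128, 128)
print(control.shape)
```
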
@@ -608,6 +621,7 @@ def __call__(
     ):
         r"""
         Function invoked when calling the pipeline for generation.
+
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
@@ -670,8 +684,7 @@ def __call__(
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-        Examples:
-        Returns:
+        Examples: Returns:
             [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
             [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images.
@@ -839,7 +852,7 @@ def __call__(
                     txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
                     return_dict=False,
                 )
-
+
                 with self.transformer.cache_context("cond"):
                     noise_pred = self.transformer(
                         hidden_states=latents,