Modify the implementation of retrieve_timesteps in CogView4-Control. #11125
Merged
56 commits
- a97fca2: 1 (zRzRzRzRzRzRzR)
- c30ca7a: change to channel 1 (zRzRzRzRzRzRzR)
- 5c25cd2: cogview4 control training (zRzRzRzRzRzRzR)
- 44bfd4c: add CacheMixin (zRzRzRzRzRzRzR)
- a9f448e: 1 (zRzRzRzRzRzRzR)
- 2cbdf35: remove initial_input_channels change for val (zRzRzRzRzRzRzR)
- df83bf2: 1 (zRzRzRzRzRzRzR)
- 8bba67a: update (zRzRzRzRzRzRzR)
- b9d864b: use 3.5 (zRzRzRzRzRzRzR)
- 5d2e994: new loss (zRzRzRzRzRzRzR)
- ebeb1e4: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 95e8504: 1 (zRzRzRzRzRzRzR)
- 940c23b: Merge branch 'cogview4_control' of https://github.com/zRzRzRzRzRzRzR/… (zRzRzRzRzRzRzR)
- 7a68a3e: use imagetoken (zRzRzRzRzRzRzR)
- 2a81772: for megatron convert (zRzRzRzRzRzRzR)
- 1d91a24: 1 (zRzRzRzRzRzRzR)
- dff4b29: train con and uc (zRzRzRzRzRzRzR)
- 050b97c: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- b007be0: 2 (zRzRzRzRzRzRzR)
- 25f4e4b: remove guidance_scale (zRzRzRzRzRzRzR)
- 7ffecbc: Update pipeline_cogview4_control.py (zRzRzRzRzRzRzR)
- b4e11e7: fix (zRzRzRzRzRzRzR)
- efa0f41: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- f55e3cc: use cogview4 pipeline with timestep (zRzRzRzRzRzRzR)
- 9410e46: Merge branch 'cogview4_control' of https://github.com/zRzRzRzRzRzRzR/… (zRzRzRzRzRzRzR)
- 29b0c81: update shift_factor (zRzRzRzRzRzRzR)
- 52d4ebf: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 65b3719: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 90830ed: remove the uncond (zRzRzRzRzRzRzR)
- 71f9235: add max length (zRzRzRzRzRzRzR)
- 19d7d27: change convert and use GLMModel instead of GLMForCasualLM (zRzRzRzRzRzRzR)
- fe6287a: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 2f74c4e: fix (zRzRzRzRzRzRzR)
- 264060e: [cogview4] Add attention mask support to transformer model (OleehyO)
- 9a10ceb: [fix] Add attention mask for padded token (OleehyO)
- b6e10e7: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 692e5cc: update (zRzRzRzRzRzRzR)
- fc3830c: remove padding type (zRzRzRzRzRzRzR)
- 98a2417: Update train_control_cogview4.py (zRzRzRzRzRzRzR)
- c774f45: resolve conflicts with #10981 (zRzRzRzRzRzRzR)
- 687faa4: Merge branch 'main' into cogview4_control (zRzRzRzRzRzRzR)
- 8abca19: add control convert (zRzRzRzRzRzRzR)
- cbfeb0b: Merge branch 'cogview4_control' of https://github.com/zRzRzRzRzRzRzR/… (zRzRzRzRzRzRzR)
- 347dd17: use control format (zRzRzRzRzRzRzR)
- 775bb8c: fix (zRzRzRzRzRzRzR)
- 985baa9: add missing import (zRzRzRzRzRzRzR)
- c2a1985: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 88abb39: update with cogview4 formate (zRzRzRzRzRzRzR)
- 3e3387e: make style (yiyixuxu)
- ddb31d3: Merge branch 'huggingface:main' into cogview4_control (zRzRzRzRzRzRzR)
- 64637ef: Update pipeline_cogview4_control.py (zRzRzRzRzRzRzR)
- 4174736: Update pipeline_cogview4_control.py (zRzRzRzRzRzRzR)
- 8cdd36f: remove (zRzRzRzRzRzRzR)
- 785e230: Update pipeline_cogview4_control.py (zRzRzRzRzRzRzR)
- 07ef22e: put back (zRzRzRzRzRzRzR)
- 1e93e98: Apply style fixes (github-actions[bot])
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|  |  | ``@@ -68,7 +68,7 @@ def calculate_shift(`` |
| 68 | 68 | ``     return mu`` |
| 69 | 69 |  |
| 70 | 70 |  |
| 71 |  | ``-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps`` |
|  | 71 | ``+# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps`` |
| 72 | 72 | `` def retrieve_timesteps(`` |
| 73 | 73 | ``     scheduler,`` |
| 74 | 74 | ``     num_inference_steps: Optional[int] = None,`` |
|  |  | ``@@ -100,10 +100,19 @@ def retrieve_timesteps(`` |
| 100 | 100 | ``         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the`` |
| 101 | 101 | ``         second element is the number of inference steps.`` |
| 102 | 102 | ``     """`` |
|  | 103 | ``+    accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())`` |
|  |  | Review comment: Let's update the … |
|  |  | Reply: Is that how I understand it? |
|  | 104 | ``+    accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())`` |
|  | 105 | ``+`` |
| 103 | 106 | ``     if timesteps is not None and sigmas is not None:`` |
| 104 |  | ``-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")`` |
| 105 |  | ``-    if timesteps is not None:`` |
| 106 |  | ``-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())`` |
|  | 107 | ``+        if not accepts_timesteps and not accepts_sigmas:`` |
|  | 108 | ``+            raise ValueError(`` |
|  | 109 | ``+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"`` |
|  | 110 | ``+                f" timestep or sigma schedules. Please check whether you are using the correct scheduler."`` |
|  | 111 | ``+            )`` |
|  | 112 | ``+        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)`` |
|  | 113 | ``+        timesteps = scheduler.timesteps`` |
|  | 114 | ``+        num_inference_steps = len(timesteps)`` |
|  | 115 | ``+    elif timesteps is not None and sigmas is None:`` |
| 107 | 116 | ``         if not accepts_timesteps:`` |
| 108 | 117 | ``             raise ValueError(`` |
| 109 | 118 | ``                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"`` |
|  |  | ``@@ -112,9 +121,8 @@ def retrieve_timesteps(`` |
| 112 | 121 | ``         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)`` |
| 113 | 122 | ``         timesteps = scheduler.timesteps`` |
| 114 | 123 | ``         num_inference_steps = len(timesteps)`` |
| 115 |  | ``-    elif sigmas is not None:`` |
| 116 |  | ``-        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())`` |
| 117 |  | ``-        if not accept_sigmas:`` |
|  | 124 | ``+    elif timesteps is None and sigmas is not None:`` |
|  | 125 | ``+        if not accepts_sigmas:`` |
| 118 | 126 | ``             raise ValueError(`` |
| 119 | 127 | ``                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"`` |
| 120 | 128 | ``                 f" sigmas schedules. Please check whether you are using the correct scheduler."`` |
|  |  | ``@@ -515,8 +523,8 @@ def __call__(`` |
| 515 | 523 | ``                 The output format of the generate image. Choose between`` |
| 516 | 524 | ``                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.`` |
| 517 | 525 | ``             return_dict (`bool`, *optional*, defaults to `True`):`` |
| 518 |  | ``-                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead`` |
| 519 |  | ``-                of a plain tuple.`` |
|  | 526 | ``+                Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain`` |
|  | 527 | ``+                tuple.`` |
| 520 | 528 | ``             attention_kwargs (`dict`, *optional*):`` |
| 521 | 529 | ``                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under`` |
| 522 | 530 | ``                 `self.processor` in`` |
|  |  | ``@@ -532,7 +540,6 @@ def __call__(`` |
| 532 | 540 | ``                 `._callback_tensor_inputs` attribute of your pipeline class.`` |
| 533 | 541 | ``             max_sequence_length (`int`, defaults to `224`):`` |
| 534 | 542 | ``                 Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.`` |
| 535 |  | ``-`` |
| 536 | 543 | ``         Examples:`` |
| 537 | 544 |  |
| 538 | 545 | ``         Returns:`` |
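
The branching that the diff introduces in `retrieve_timesteps` can be tried in isolation. The following is a minimal sketch under stated assumptions, not the diffusers implementation: `ToyScheduler`, `StepsOnlyScheduler`, and `retrieve_timesteps_sketch` are hypothetical names invented here, and the bodies of the sigmas-only and default branches do not appear in the diff, so they are assumed to follow the same pattern as the timesteps-only branch. The sketch mirrors the conditions as shown in the diff, including the `and` in the combined check, which means that branch raises only when the scheduler accepts neither argument.

```python
import inspect


class ToyScheduler:
    """Hypothetical scheduler whose set_timesteps accepts timesteps and sigmas."""

    def __init__(self):
        self.timesteps = []

    def set_timesteps(self, num_inference_steps=None, timesteps=None, sigmas=None, device=None):
        if timesteps is not None:
            self.timesteps = list(timesteps)
        elif sigmas is not None:
            # Toy mapping from sigma to timestep, chosen only for illustration.
            self.timesteps = [s * 1000 for s in sigmas]
        else:
            self.timesteps = list(range(num_inference_steps - 1, -1, -1))


class StepsOnlyScheduler:
    """Hypothetical scheduler that supports neither custom timesteps nor sigmas."""

    def __init__(self):
        self.timesteps = []

    def set_timesteps(self, num_inference_steps=None, device=None):
        self.timesteps = list(range(num_inference_steps - 1, -1, -1))


def retrieve_timesteps_sketch(scheduler, num_inference_steps=None, device=None,
                              timesteps=None, sigmas=None, **kwargs):
    # Inspect the scheduler once, up front, as the revised code does.
    params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    accepts_timesteps = "timesteps" in params
    accepts_sigmas = "sigmas" in params

    if timesteps is not None and sigmas is not None:
        # Both given: forwarded together instead of raising, as in the diff.
        if not accepts_timesteps and not accepts_sigmas:
            raise ValueError("set_timesteps does not support custom timestep or sigma schedules")
        scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is not None and sigmas is None:
        if not accepts_timesteps:
            raise ValueError("set_timesteps does not support custom timestep schedules")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif timesteps is None and sigmas is not None:
        if not accepts_sigmas:
            raise ValueError("set_timesteps does not support custom sigmas schedules")
        # Assumed body: not visible in the diff.
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Assumed default branch: not visible in the diff.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


if __name__ == "__main__":
    print(retrieve_timesteps_sketch(ToyScheduler(), timesteps=[900, 500, 100], sigmas=[0.9, 0.5, 0.1]))
    print(retrieve_timesteps_sketch(ToyScheduler(), num_inference_steps=4))
```

Before this change, passing both `timesteps` and `sigmas` raised a `ValueError`; in the sketch, as in the diff, the pair is handed to `set_timesteps` in a single call.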
Just the change here is correct @zRzRzRzRzRzRzR