- 
                Notifications
    
You must be signed in to change notification settings  - Fork 6.5k
 
[hybrid inference 🍯🐝] Wan 2.1 decode #11031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
081e68f
              140e0c2
              e70bdb2
              e5448f2
              15914a9
              0a2231a
              0f5705b
              998c3c6
              b2756ad
              c6ac397
              abb3e3b
              73adcd8
              79620bf
              a1eacb3
              3f44fa1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -26,6 +26,7 @@ | |
| DECODE_ENDPOINT_HUNYUAN_VIDEO, | ||
| DECODE_ENDPOINT_SD_V1, | ||
| DECODE_ENDPOINT_SD_XL, | ||
| DECODE_ENDPOINT_WAN_2_1, | ||
| ) | ||
| from diffusers.utils.remote_utils import ( | ||
| remote_decode, | ||
| 
          
            
          
           | 
    @@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self): | |
| f"{output_slice}", | ||
| ) | ||
| 
     | 
||
| def test_do_scaling_deprecation(self): | ||
| inputs = self.get_dummy_inputs() | ||
| inputs.pop("scaling_factor", None) | ||
| inputs.pop("shift_factor", None) | ||
| with self.assertWarns(FutureWarning) as warning: | ||
| _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs) | ||
| self.assertEqual( | ||
| str(warning.warnings[0].message), | ||
| "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.", | ||
| str(warning.warnings[0].message), | ||
| ) | ||
| 
     | 
||
| def test_input_tensor_type_base64_deprecation(self): | ||
| inputs = self.get_dummy_inputs() | ||
| with self.assertWarns(FutureWarning) as warning: | ||
| 
        
          
        
         | 
    @@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self): | |
| ) | ||
| 
     | 
||
| 
     | 
||
| class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin): | ||
| class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin): | ||
| def test_no_scaling(self): | ||
| inputs = self.get_dummy_inputs() | ||
| if inputs["scaling_factor"] is not None: | ||
| 
        
          
        
         | 
    @@ -221,7 +210,6 @@ def test_no_scaling(self): | |
| processor = self.processor_cls() | ||
| output = remote_decode( | ||
| output_type="pt", | ||
| # required for now, will be removed in next update | ||
| do_scaling=False, | ||
| processor=processor, | ||
| **inputs, | ||
| 
          
            
          
           | 
    @@ -337,6 +325,8 @@ def test_output_type_mp4(self): | |
| inputs = self.get_dummy_inputs() | ||
| output = remote_decode(output_type="mp4", return_type="mp4", **inputs) | ||
| self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}") | ||
| with open("test.mp4", "wb") as f: | ||
| f.write(output) | ||
| 
     | 
||
| 
     | 
||
| class RemoteAutoencoderKLSDv1Tests( | ||
| 
          
            
          
           | 
    @@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests( | |
| 
     | 
||
| 
     | 
||
| class RemoteAutoencoderKLHunyuanVideoTests( | ||
| RemoteAutoencoderKLHunyuanVideoMixin, | ||
| RemoteAutoencoderKLVideoMixin, | ||
| unittest.TestCase, | ||
| ): | ||
| shape = ( | ||
| 
        
          
        
         | 
    @@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests( | |
| return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708]) | ||
| 
     | 
||
| 
     | 
||
| class RemoteAutoencoderKLWanTests( | ||
| RemoteAutoencoderKLVideoMixin, | ||
| unittest.TestCase, | ||
| ): | ||
| shape = ( | ||
| 1, | ||
| 16, | ||
| 3, | ||
| 40, | ||
| 64, | ||
| ) | ||
| out_hw = ( | ||
| 320, | ||
| 512, | ||
| ) | ||
| endpoint = DECODE_ENDPOINT_WAN_2_1 | ||
| dtype = torch.float16 | ||
| 
         There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @yiyixuxu Currently the endpoint is running in float16 and output seems ok (on random latent), the examples use float32 but we noticed in the original code that everything is under bfloat16 autocast context. Can we check with the authors regarding the use of float32 for VAE? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i haven't heard back from them,  | 
||
| processor_cls = VideoProcessor | ||
| output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8) | ||
| partial_postprocess_return_pt_slice = torch.tensor( | ||
| [206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8 | ||
| ) | ||
| return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002]) | ||
| 
     | 
||
| 
     | 
||
| class RemoteAutoencoderKLSlowTestMixin: | ||
| channels: int = 4 | ||
| endpoint: str = None | ||
| 
          
            
          
           | 
    ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Wan VAE doesn't have a `scaling_factor` that we could pass, so this deprecation/removal of
`do_scaling` doesn't work here; we will keep it.