@@ -45,6 +45,8 @@ class RemoteAutoencoderKLMixin:
     output_pt_slice: torch.Tensor = None
     partial_postprocess_return_pt_slice: torch.Tensor = None
     return_pt_slice: torch.Tensor = None
+    width: int = None
+    height: int = None
 
     def get_dummy_inputs(self):
         inputs = {
@@ -57,6 +59,8 @@ def get_dummy_inputs(self):
             ),
             "scaling_factor": self.scaling_factor,
             "shift_factor": self.shift_factor,
+            "height": self.height,
+            "width": self.width,
         }
         return inputs
 
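The `height`/`width` entries added to the dummy inputs are forwarded to `remote_decode`, which presumably needs them to restore a packed latent to a spatial layout before the VAE decode (the unpacked test classes leave both attributes as `None`). A minimal usage sketch under that assumption; the endpoint URL and the latent below are illustrative placeholders, not values from this PR:

```python
import torch

from diffusers.utils.remote_utils import remote_decode

# Packed Flux latent for a 1024x1024 image: (1, (128/2)*(128/2), 16*4) = (1, 4096, 64).
packed_latent = torch.randn(1, 4096, 64, dtype=torch.bfloat16)

image = remote_decode(
    endpoint="https://<your-endpoint>.endpoints.huggingface.cloud/",  # placeholder
    tensor=packed_latent,
    scaling_factor=0.3611,
    shift_factor=0.1159,
    height=1024,  # required so the packed latent can be unpacked before decoding
    width=1024,
    output_type="pil",
)
image.save("packed_decode.png")
```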
@@ -65,6 +69,7 @@ def test_output_type_pt(self):
         processor = self.processor_cls()
         output = remote_decode(output_type="pt", processor=processor, **inputs)
         assert isinstance(output, PIL.Image.Image)
+        output.save("test_output_type_pt.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
@@ -77,13 +82,15 @@ def test_output_type_pt(self):
     def test_output_type_pil(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", **inputs)
+        output.save("test_output_type_pil.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
 
     def test_output_type_pil_image_format(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", image_format="png", **inputs)
+        output.save("test_output_type_pil_image_format.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
@@ -96,6 +103,7 @@ def test_output_type_pil_image_format(self):
     def test_output_type_pt_partial_postprocess(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+        output.save("test_output_type_pt_partial_postprocess.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
@@ -221,7 +229,6 @@ class RemoteAutoencoderKLFluxTests(
     RemoteAutoencoderKLMixin,
     unittest.TestCase,
 ):
-    # TODO: packed
     shape = (
         1,
         16,
@@ -242,3 +249,31 @@ class RemoteAutoencoderKLFluxTests(
         [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
     )
     return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
+
+
+class RemoteAutoencoderKLFluxPackedTests(
+    RemoteAutoencoderKLMixin,
+    unittest.TestCase,
+):
+    shape = (
+        1,
+        4096,
+        64,
+    )
+    out_hw = (
+        1024,
+        1024,
+    )
+    height = 1024
+    width = 1024
+    endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+    dtype = torch.bfloat16
+    scaling_factor = 0.3611
+    shift_factor = 0.1159
+    processor_cls = VaeImageProcessor
+    # slices differ because randn is drawn on a different shape; packing the latent instead would reproduce the same values
+    output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor(
+        [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
+    )
+    return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
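For context on the comment above: Flux packs the (B, 16, H/8, W/8) VAE latent by folding 2×2 spatial patches into the channel dimension, giving (B, (H/16)·(W/16), 64); for 1024×1024 that is exactly this class's (1, 4096, 64) shape. A minimal sketch of that rearrangement, mirroring `FluxPipeline._pack_latents` (illustrative, not code from this PR):

```python
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    """Fold 2x2 spatial patches into channels: (B, C, H, W) -> (B, H//2 * W//2, C * 4)."""
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)


# 1024x1024 image -> VAE latent (1, 16, 128, 128) -> packed (1, 4096, 64)
unpacked = torch.randn(1, 16, 128, 128, generator=torch.Generator("cpu").manual_seed(13))
packed = pack_latents(unpacked)
assert packed.shape == (1, 4096, 64)
```

Packing a seeded unpacked latent this way, rather than drawing `randn` directly at the packed shape as `get_dummy_inputs` does, is what the comment suggests would make the expected slices match `RemoteAutoencoderKLFluxTests`.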