@@ -45,6 +45,8 @@ class RemoteAutoencoderKLMixin:
4545 output_pt_slice: torch.Tensor = None
4646 partial_postprocess_return_pt_slice: torch.Tensor = None
4747 return_pt_slice: torch.Tensor = None
48+ width: int = None
49+ height: int = None
4850
4951 def get_dummy_inputs(self):
5052 inputs = {
@@ -57,6 +59,8 @@ def get_dummy_inputs(self):
5759 ),
5860 "scaling_factor": self.scaling_factor,
5961 "shift_factor": self.shift_factor,
62+ "height": self.height,
63+ "width": self.width,
6064 }
6165 return inputs
6266
@@ -65,6 +69,7 @@ def test_output_type_pt(self):
6569 processor = self.processor_cls()
6670 output = remote_decode(output_type="pt", processor=processor, **inputs)
6771 assert isinstance(output, PIL.Image.Image)
72+ output.save("test_output_type_pt.png")
6873 self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
6974 self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
7075 self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -77,13 +82,15 @@ def test_output_type_pt(self):
7782 def test_output_type_pil(self):
7883 inputs = self.get_dummy_inputs()
7984 output = remote_decode(output_type="pil", **inputs)
85+ output.save("test_output_type_pil.png")
8086 self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
8187 self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
8288 self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
8389
8490 def test_output_type_pil_image_format(self):
8591 inputs = self.get_dummy_inputs()
8692 output = remote_decode(output_type="pil", image_format="png", **inputs)
93+ output.save("test_output_type_pil_image_format.png")
8794 self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
8895 self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
8996 self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -96,6 +103,7 @@ def test_output_type_pil_image_format(self):
96103 def test_output_type_pt_partial_postprocess(self):
97104 inputs = self.get_dummy_inputs()
98105 output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
106+ output.save("test_output_type_pt_partial_postprocess.png")
99107 self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
100108 self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
101109 self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -221,7 +229,6 @@ class RemoteAutoencoderKLFluxTests(
221229 RemoteAutoencoderKLMixin,
222230 unittest.TestCase,
223231):
224- # TODO: packed
225232 shape = (
226233 1,
227234 16,
@@ -242,3 +249,31 @@ class RemoteAutoencoderKLFluxTests(
242249 [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
243250 )
244251 return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
252+
253+
class RemoteAutoencoderKLFluxPackedTests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    """Remote VAE decoding tests for Flux with a packed latent input.

    Unlike ``RemoteAutoencoderKLFluxTests`` (which uses a ``(B, C, H, W)``
    latent), the dummy latent here is 3-D ``(1, 4096, 64)`` — presumably
    Flux's packed patch layout — so ``height``/``width`` are set and
    forwarded through ``get_dummy_inputs`` for the endpoint to unpack.
    TODO confirm the packed-layout assumption against ``remote_decode``.
    """

    # 3-D packed latent: (batch, 4096 patches, 64 channels).
    shape = (
        1,
        4096,
        64,
    )
    # Expected (height, width) of the decoded output image.
    out_hw = (
        1024,
        1024,
    )
    # Target image size; passed in the request inputs (see get_dummy_inputs).
    height = 1024
    width = 1024
    endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
    dtype = torch.bfloat16
    scaling_factor = 0.3611
    shift_factor = 0.1159
    processor_cls = VaeImageProcessor
    # slices are different due to randn on different shape. we can pack the latent instead if we want the same
    output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor(
        [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
    )
    return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
0 commit comments