
Commit 86c2236

flux packed
1 parent 2937eb2 commit 86c2236

File tree

1 file changed: +36 −1 lines changed

tests/remote/test_remote_decode.py

Lines changed: 36 additions & 1 deletion
@@ -45,6 +45,8 @@ class RemoteAutoencoderKLMixin:
     output_pt_slice: torch.Tensor = None
     partial_postprocess_return_pt_slice: torch.Tensor = None
     return_pt_slice: torch.Tensor = None
+    width: int = None
+    height: int = None
 
     def get_dummy_inputs(self):
         inputs = {
@@ -57,6 +59,8 @@ def get_dummy_inputs(self):
             ),
             "scaling_factor": self.scaling_factor,
             "shift_factor": self.shift_factor,
+            "height": self.height,
+            "width": self.width,
         }
         return inputs
 
@@ -65,6 +69,7 @@ def test_output_type_pt(self):
         processor = self.processor_cls()
         output = remote_decode(output_type="pt", processor=processor, **inputs)
         assert isinstance(output, PIL.Image.Image)
+        output.save("test_output_type_pt.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -77,13 +82,15 @@ def test_output_type_pt(self):
     def test_output_type_pil(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", **inputs)
+        output.save("test_output_type_pil.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
 
     def test_output_type_pil_image_format(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", image_format="png", **inputs)
+        output.save("test_output_type_pil_image_format.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -96,6 +103,7 @@ def test_output_type_pil_image_format(self):
     def test_output_type_pt_partial_postprocess(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+        output.save("test_output_type_pt_partial_postprocess.png")
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
@@ -221,7 +229,6 @@ class RemoteAutoencoderKLFluxTests(
     RemoteAutoencoderKLMixin,
     unittest.TestCase,
 ):
-    # TODO: packed
     shape = (
         1,
         16,
@@ -242,3 +249,31 @@ class RemoteAutoencoderKLFluxTests(
         [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
     )
     return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
+
+
+class RemoteAutoencoderKLFluxPackedTests(
+    RemoteAutoencoderKLMixin,
+    unittest.TestCase,
+):
+    shape = (
+        1,
+        4096,
+        64,
+    )
+    out_hw = (
+        1024,
+        1024,
+    )
+    height = 1024
+    width = 1024
+    endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+    dtype = torch.bfloat16
+    scaling_factor = 0.3611
+    shift_factor = 0.1159
+    processor_cls = VaeImageProcessor
+    # slices are different due to randn on different shape. we can pack the latent instead if we want the same
+    output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor(
+        [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
+    )
+    return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
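
For reference: the new RemoteAutoencoderKLFluxPackedTests feeds the endpoint a packed Flux latent of shape (1, 4096, 64), where RemoteAutoencoderKLFluxTests uses the unpacked (1, 16, H/8, W/8) layout. Below is a minimal sketch of the 2x2 patchify that relates the two layouts, mirroring the packing Flux pipelines apply before the transformer; the helper name is illustrative and not part of this commit.

import torch

def pack_flux_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, (H/2) * (W/2), C * 4): fold each 2x2 spatial patch
    # into the channel dimension, then flatten the patch grid into a sequence.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# A 1024x1024 Flux image has a 16-channel latent at 1024 / 8 = 128 per side,
# so packing yields the (1, 4096, 64) `shape` used by the new test class.
packed = pack_flux_latents(torch.randn(1, 16, 128, 128, dtype=torch.bfloat16))
print(packed.shape)  # torch.Size([1, 4096, 64])

This also explains the in-diff comment: the test draws randn directly in the packed shape rather than packing an unpacked latent, so its expected slices differ from those of the unpacked test class.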
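And a hedged sketch of how the new height/width inputs come into play when decoding a packed latent remotely. Assumptions: remote_decode is importable from diffusers.utils.remote_utils as in this test module, and height/width are what allow the server to unpatchify a packed latent back to the target resolution.

from diffusers.utils.remote_utils import remote_decode

# `packed` is the (1, 4096, 64) tensor from the sketch above.
image = remote_decode(
    endpoint="https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=packed,
    scaling_factor=0.3611,  # Flux VAE scaling factor, as set in the test class
    shift_factor=0.1159,    # Flux VAE shift factor, as set in the test class
    height=1024,            # assumed: needed for packed latents so the server can unpatchify
    width=1024,
    output_type="pil",
)
image.save("flux_packed_remote_decode.png")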
