Skip to content

Commit b10ea13

Browse files
committed
no scaling
1 parent 86c2236 commit b10ea13

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/remote/test_remote_decode.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,34 @@ def get_dummy_inputs(self):
6464
}
6565
return inputs
6666

67+
def test_no_scaling(self):
68+
inputs = self.get_dummy_inputs()
69+
if inputs["scaling_factor"] is not None:
70+
inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
71+
inputs["scaling_factor"] = None
72+
if inputs["shift_factor"] is not None:
73+
inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
74+
inputs["shift_factor"] = None
75+
processor = self.processor_cls()
76+
output = remote_decode(
77+
output_type="pt",
78+
# required for now, will be removed in next update
79+
do_scaling=False,
80+
processor=processor,
81+
**inputs,
82+
)
83+
assert isinstance(output, PIL.Image.Image)
84+
output.save("test_no_scaling.png")
85+
self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
86+
self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
87+
self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
88+
output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
89+
# Increased tolerance for Flux Packed diff [1, 0, 1, 0, 0, 0, 0, 0, 0]
90+
self.assertTrue(
91+
torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
92+
f"{output_slice}",
93+
)
94+
6795
def test_output_type_pt(self):
6896
inputs = self.get_dummy_inputs()
6997
processor = self.processor_cls()

0 commit comments

Comments
 (0)