@@ -64,6 +64,34 @@ def get_dummy_inputs(self):
6464 }
6565 return inputs
6666
67+ def test_no_scaling (self ):
68+ inputs = self .get_dummy_inputs ()
69+ if inputs ["scaling_factor" ] is not None :
70+ inputs ["tensor" ] = inputs ["tensor" ] / inputs ["scaling_factor" ]
71+ inputs ["scaling_factor" ] = None
72+ if inputs ["shift_factor" ] is not None :
73+ inputs ["tensor" ] = inputs ["tensor" ] + inputs ["shift_factor" ]
74+ inputs ["shift_factor" ] = None
75+ processor = self .processor_cls ()
76+ output = remote_decode (
77+ output_type = "pt" ,
78+ # required for now, will be removed in next update
79+ do_scaling = False ,
80+ processor = processor ,
81+ ** inputs ,
82+ )
83+ assert isinstance (output , PIL .Image .Image )
84+ output .save ("test_no_scaling.png" )
85+ self .assertTrue (isinstance (output , PIL .Image .Image ), f"Expected `PIL.Image.Image` output, got { type (output )} " )
86+ self .assertEqual (output .height , self .out_hw [0 ], f"Expected image height { self .out_hw [0 ]} , got { output .height } " )
87+ self .assertEqual (output .width , self .out_hw [1 ], f"Expected image width { self .out_hw [0 ]} , got { output .width } " )
88+ output_slice = torch .from_numpy (np .array (output )[0 , - 3 :, - 3 :].flatten ())
89+ # Increased tolerance for Flux Packed diff [1, 0, 1, 0, 0, 0, 0, 0, 0]
90+ self .assertTrue (
91+ torch_all_close (output_slice , self .output_pt_slice .to (output_slice .dtype ), rtol = 1 , atol = 1 ),
92+ f"{ output_slice } " ,
93+ )
94+
6795 def test_output_type_pt (self ):
6896 inputs = self .get_dummy_inputs ()
6997 processor = self .processor_cls ()
0 commit comments