 import unittest
 from typing import Tuple, Union
 
+import numpy as np
 import PIL.Image
 import torch
 
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils.remote_utils import remote_decode
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    torch_all_close,
     torch_device,
 )
 from diffusers.video_processor import VideoProcessor
@@ -39,11 +41,20 @@ class RemoteAutoencoderKLMixin:
     scaling_factor: float = None
     shift_factor: float = None
     processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
+    output_pil_slice: torch.Tensor = None
+    output_pt_slice: torch.Tensor = None
+    partial_postprocess_return_pt_slice: torch.Tensor = None
+    return_pt_slice: torch.Tensor = None
 
     def get_dummy_inputs(self):
         inputs = {
             "endpoint": self.endpoint,
-            "tensor": torch.randn(self.shape, device=torch_device, dtype=self.dtype),
+            "tensor": torch.randn(
+                self.shape,
+                device=torch_device,
+                dtype=self.dtype,
+                generator=torch.Generator(torch_device).manual_seed(13),
+            ),
             "scaling_factor": self.scaling_factor,
             "shift_factor": self.shift_factor,
         }
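Note: the seeded `torch.Generator` is what makes the hard-coded slice checks below viable; the same seed always yields the same dummy latent. A minimal sketch of the property being relied on (the shape here is illustrative):

import torch

# Identical seeds produce identical tensors, so a decode of this latent
# is stable enough to compare against pinned 9-value slices.
g1 = torch.Generator("cpu").manual_seed(13)
g2 = torch.Generator("cpu").manual_seed(13)
a = torch.randn((1, 4, 64, 64), generator=g1)
b = torch.randn((1, 4, 64, 64), generator=g2)
assert torch.equal(a, b)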
@@ -53,10 +64,16 @@ def test_output_type_pt(self):
         inputs = self.get_dummy_inputs()
         processor = self.processor_cls()
         output = remote_decode(output_type="pt", processor=processor, **inputs)
+        assert isinstance(output, PIL.Image.Image)
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )
 
+    # output is visually the same, but the slice is flaky?
     def test_output_type_pil(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", **inputs)
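Note: for image outputs the reference slice is nine uint8 values: row 0, the last three pixels, all three RGB channels. A standalone sketch of the extraction (the synthetic image is purely illustrative):

import numpy as np
import PIL.Image
import torch

# np.array(img) on an RGB PIL image is (H, W, 3) uint8;
# [0, -3:, -3:] keeps row 0, the last 3 columns, all 3 channels -> 9 values.
img = PIL.Image.new("RGB", (8, 8), (31, 15, 11))
output_slice = torch.from_numpy(np.array(img)[0, -3:, -3:].flatten())
assert output_slice.shape == (9,)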
@@ -71,13 +88,21 @@ def test_output_type_pil_image_format(self):
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
         self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )
 
     def test_output_type_pt_partial_postprocess(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )
 
     def test_output_type_pt_return_type_pt(self):
         inputs = self.get_dummy_inputs()
@@ -89,6 +114,11 @@ def test_output_type_pt_return_type_pt(self):
         self.assertEqual(
             output.shape[3], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[3]}"
         )
+        output_slice = output[0, 0, -3:, -3:].flatten()
+        self.assertTrue(
+            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
+            f"{output_slice}",
+        )
 
     def test_output_type_pt_partial_postprocess_return_type_pt(self):
         inputs = self.get_dummy_inputs()
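Note: with `return_type="pt"` the raw decoded tensor is BCHW floats, so the slice pins the bottom-right 3x3 corner of channel 0, with tighter 1e-3 tolerances than the quantized uint8 paths. A shape-only sketch:

import torch

# BCHW float output: batch 0, channel 0, bottom-right 3x3 corner -> 9 floats.
decoded = torch.zeros(1, 3, 1024, 1024)
corner = decoded[0, 0, -3:, -3:].flatten()
assert torch.allclose(corner, torch.zeros(9), rtol=1e-3, atol=1e-3)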
@@ -100,6 +130,11 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
         self.assertEqual(
             output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[2]}"
         )
+        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
+        self.assertTrue(
+            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
+            f"{output_slice}",
+        )
 
     def test_do_scaling_deprecation(self):
         inputs = self.get_dummy_inputs()
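Note: in the `partial_postprocess=True, return_type="pt"` path the server already denormalizes to uint8 and, as the indexing above implies, the tensor comes back channels-last, hence the `[0, -3:, -3:, 0]` slice and the `.cpu()` move. A shape-only sketch, assuming that BHWC layout:

import torch

# Channels-last uint8 frame: batch 0, bottom-right 3x3 corner, channel 0.
frame = torch.zeros(1, 1024, 1024, 3, dtype=torch.uint8)
corner = frame[0, -3:, -3:, 0].flatten().cpu()
assert corner.shape == (9,)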
@@ -133,6 +168,7 @@ def test_output_tensor_type_base64_deprecation(self):
             str(warning.warnings[0].message),
         )
 
+
 class RemoteAutoencoderKLSDv1Tests(
     RemoteAutoencoderKLMixin,
     unittest.TestCase,
@@ -152,6 +188,9 @@ class RemoteAutoencoderKLSDv1Tests(
     scaling_factor = 0.18215
     shift_factor = None
     processor_cls = VaeImageProcessor
+    output_pt_slice = torch.tensor([31, 15, 11, 55, 30, 21, 66, 42, 30], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor([100, 130, 99, 133, 106, 112, 97, 100, 121], dtype=torch.uint8)
+    return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543])
 
 
 class RemoteAutoencoderKLSDXLTests(
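Note: the pinned slices were presumably captured from a known-good run against each endpoint and would need regenerating if an endpoint's weights change. A hypothetical helper (not part of the PR) to print fresh values for pasting into the class attributes:

import numpy as np
import PIL.Image
import torch

def print_reference_slice(output) -> None:
    # Tensor outputs use the corner-slice convention from the tests above;
    # image outputs use the row-0 RGB slice.
    if isinstance(output, torch.Tensor):
        print(output[0, 0, -3:, -3:].flatten())
    elif isinstance(output, PIL.Image.Image):
        print(torch.from_numpy(np.array(output)[0, -3:, -3:].flatten()))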
@@ -168,30 +207,36 @@ class RemoteAutoencoderKLSDXLTests(
         1024,
         1024,
     )
-    endpoint = ""
+    endpoint = "https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/"
     dtype = torch.float16
     scaling_factor = 0.13025
     shift_factor = None
     processor_cls = VaeImageProcessor
-
-
-class RemoteAutoencoderKLFluxTests(
-    RemoteAutoencoderKLMixin,
-    unittest.TestCase,
-):
-    # TODO: packed
-    shape = (
-        1,
-        16,
-        128,
-        128,
-    )
-    out_hw = (
-        1024,
-        1024,
-    )
-    endpoint = ""
-    dtype = torch.bfloat16
-    scaling_factor = 0.3611
-    shift_factor = 0.1159
-    processor_cls = VaeImageProcessor
+    output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+    return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
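Note: outside the harness, the SDXL endpoint above can be exercised the way the mixin does, assuming the usual SDXL latent shape of (1, 4, 128, 128) for 1024x1024 output:

import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import remote_decode

# Mirrors get_dummy_inputs() plus the output_type="pt" path; a CPU generator
# is used here for portability. This path returns a PIL.Image.Image.
latent = torch.randn(
    (1, 4, 128, 128),
    dtype=torch.float16,
    generator=torch.Generator("cpu").manual_seed(13),
)
image = remote_decode(
    endpoint="https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.13025,
    shift_factor=None,
    output_type="pt",
    processor=VaeImageProcessor(),
)
image.save("decoded.png")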
+
+
+# class RemoteAutoencoderKLFluxTests(
+#     RemoteAutoencoderKLMixin,
+#     unittest.TestCase,
+# ):
+#     # TODO: packed
+#     shape = (
+#         1,
+#         16,
+#         128,
+#         128,
+#     )
+#     out_hw = (
+#         1024,
+#         1024,
+#     )
+#     endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+#     dtype = torch.bfloat16
+#     scaling_factor = 0.3611
+#     shift_factor = 0.1159
+#     processor_cls = VaeImageProcessor
+#     output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+#     partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+#     return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
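Note: the Flux tests stay commented out pending the `TODO: packed` item: Flux endpoints operate on latents packed as (B, (H/2)*(W/2), 4C) rather than BCHW. A hedged sketch of the unpacking a future test might need, assuming the 2x2 patchify layout used by Flux pipelines:

import torch

# Hypothetical unpacking of "packed" Flux latents; not part of the PR.
def unpack_latents(packed: torch.Tensor, height: int, width: int) -> torch.Tensor:
    batch, _, channels = packed.shape  # (B, (H/2)*(W/2), C*4)
    latents = packed.view(batch, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(batch, channels // 4, height, width)

packed = torch.randn(1, 64 * 64, 64)
assert unpack_latents(packed, 128, 128).shape == (1, 16, 128, 128)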