|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import unittest |
16 | 17 |
|
17 | 18 | import PIL.Image |
18 | 19 | import torch |
19 | 20 |
|
20 | 21 | from diffusers.utils import load_image |
21 | 22 | from diffusers.utils.remote_utils import ( |
| 23 | + DECODE_ENDPOINT_FLUX, |
| 24 | + DECODE_ENDPOINT_SD_V1, |
| 25 | + DECODE_ENDPOINT_SD_XL, |
| 26 | + ENCODE_ENDPOINT_FLUX, |
| 27 | + ENCODE_ENDPOINT_SD_V1, |
| 28 | + ENCODE_ENDPOINT_SD_XL, |
22 | 29 | remote_decode, |
23 | 30 | remote_encode, |
24 | 31 | ) |
25 | 32 | from diffusers.utils.testing_utils import ( |
26 | 33 | enable_full_determinism, |
| 34 | + slow, |
27 | 35 | ) |
28 | 36 |
|
29 | 37 |
|
@@ -71,40 +79,40 @@ def test_image_input(self): |
71 | 79 | # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? |
72 | 80 |
|
73 | 81 |
|
74 | | -# class RemoteAutoencoderKLSDv1Tests( |
75 | | -# RemoteAutoencoderKLEncodeMixin, |
76 | | -# unittest.TestCase, |
77 | | -# ): |
78 | | -# channels = 4 |
79 | | -# endpoint = ENCODE_ENDPOINT_SD_V1 |
80 | | -# decode_endpoint = DECODE_ENDPOINT_SD_V1 |
81 | | -# dtype = torch.float16 |
82 | | -# scaling_factor = 0.18215 |
83 | | -# shift_factor = None |
84 | | - |
85 | | - |
86 | | -# class RemoteAutoencoderKLSDXLTests( |
87 | | -# RemoteAutoencoderKLEncodeMixin, |
88 | | -# unittest.TestCase, |
89 | | -# ): |
90 | | -# channels = 4 |
91 | | -# endpoint = ENCODE_ENDPOINT_SD_XL |
92 | | -# decode_endpoint = DECODE_ENDPOINT_SD_XL |
93 | | -# dtype = torch.float16 |
94 | | -# scaling_factor = 0.13025 |
95 | | -# shift_factor = None |
96 | | - |
97 | | - |
98 | | -# class RemoteAutoencoderKLFluxTests( |
99 | | -# RemoteAutoencoderKLEncodeMixin, |
100 | | -# unittest.TestCase, |
101 | | -# ): |
102 | | -# channels = 16 |
103 | | -# endpoint = DECODE_ENDPOINT_FLUX |
104 | | -# decode_endpoint = ENCODE_ENDPOINT_FLUX |
105 | | -# dtype = torch.bfloat16 |
106 | | -# scaling_factor = 0.3611 |
107 | | -# shift_factor = 0.1159 |
| 82 | +class RemoteAutoencoderKLSDv1Tests( |
| 83 | + RemoteAutoencoderKLEncodeMixin, |
| 84 | + unittest.TestCase, |
| 85 | +): |
| 86 | + channels = 4 |
| 87 | + endpoint = ENCODE_ENDPOINT_SD_V1 |
| 88 | + decode_endpoint = DECODE_ENDPOINT_SD_V1 |
| 89 | + dtype = torch.float16 |
| 90 | + scaling_factor = 0.18215 |
| 91 | + shift_factor = None |
| 92 | + |
| 93 | + |
| 94 | +class RemoteAutoencoderKLSDXLTests( |
| 95 | + RemoteAutoencoderKLEncodeMixin, |
| 96 | + unittest.TestCase, |
| 97 | +): |
| 98 | + channels = 4 |
| 99 | + endpoint = ENCODE_ENDPOINT_SD_XL |
| 100 | + decode_endpoint = DECODE_ENDPOINT_SD_XL |
| 101 | + dtype = torch.float16 |
| 102 | + scaling_factor = 0.13025 |
| 103 | + shift_factor = None |
| 104 | + |
| 105 | + |
| 106 | +class RemoteAutoencoderKLFluxTests( |
| 107 | + RemoteAutoencoderKLEncodeMixin, |
| 108 | + unittest.TestCase, |
| 109 | +): |
| 110 | + channels = 16 |
| 111 | + endpoint = ENCODE_ENDPOINT_FLUX |
| 112 | + decode_endpoint = DECODE_ENDPOINT_FLUX |
| 113 | + dtype = torch.bfloat16 |
| 114 | + scaling_factor = 0.3611 |
| 115 | + shift_factor = 0.1159 |
108 | 116 |
|
109 | 117 |
|
110 | 118 | class RemoteAutoencoderKLEncodeSlowTestMixin: |
@@ -177,38 +185,38 @@ def test_multi_res(self): |
177 | 185 | decoded.save(f"test_multi_res_{height}_{width}.png") |
178 | 186 |
|
179 | 187 |
|
180 | | -# @slow |
181 | | -# class RemoteAutoencoderKLSDv1SlowTests( |
182 | | -# RemoteAutoencoderKLEncodeSlowTestMixin, |
183 | | -# unittest.TestCase, |
184 | | -# ): |
185 | | -# endpoint = ENCODE_ENDPOINT_SD_V1 |
186 | | -# decode_endpoint = DECODE_ENDPOINT_SD_V1 |
187 | | -# dtype = torch.float16 |
188 | | -# scaling_factor = 0.18215 |
189 | | -# shift_factor = None |
190 | | - |
191 | | - |
192 | | -# @slow |
193 | | -# class RemoteAutoencoderKLSDXLSlowTests( |
194 | | -# RemoteAutoencoderKLEncodeSlowTestMixin, |
195 | | -# unittest.TestCase, |
196 | | -# ): |
197 | | -# endpoint = ENCODE_ENDPOINT_SD_XL |
198 | | -# decode_endpoint = DECODE_ENDPOINT_SD_XL |
199 | | -# dtype = torch.float16 |
200 | | -# scaling_factor = 0.13025 |
201 | | -# shift_factor = None |
202 | | - |
203 | | - |
204 | | -# @slow |
205 | | -# class RemoteAutoencoderKLFluxSlowTests( |
206 | | -# RemoteAutoencoderKLEncodeSlowTestMixin, |
207 | | -# unittest.TestCase, |
208 | | -# ): |
209 | | -# channels = 16 |
210 | | -# endpoint = ENCODE_ENDPOINT_FLUX |
211 | | -# decode_endpoint = DECODE_ENDPOINT_FLUX |
212 | | -# dtype = torch.bfloat16 |
213 | | -# scaling_factor = 0.3611 |
214 | | -# shift_factor = 0.1159 |
| 188 | +@slow |
| 189 | +class RemoteAutoencoderKLSDv1SlowTests( |
| 190 | + RemoteAutoencoderKLEncodeSlowTestMixin, |
| 191 | + unittest.TestCase, |
| 192 | +): |
| 193 | + endpoint = ENCODE_ENDPOINT_SD_V1 |
| 194 | + decode_endpoint = DECODE_ENDPOINT_SD_V1 |
| 195 | + dtype = torch.float16 |
| 196 | + scaling_factor = 0.18215 |
| 197 | + shift_factor = None |
| 198 | + |
| 199 | + |
| 200 | +@slow |
| 201 | +class RemoteAutoencoderKLSDXLSlowTests( |
| 202 | + RemoteAutoencoderKLEncodeSlowTestMixin, |
| 203 | + unittest.TestCase, |
| 204 | +): |
| 205 | + endpoint = ENCODE_ENDPOINT_SD_XL |
| 206 | + decode_endpoint = DECODE_ENDPOINT_SD_XL |
| 207 | + dtype = torch.float16 |
| 208 | + scaling_factor = 0.13025 |
| 209 | + shift_factor = None |
| 210 | + |
| 211 | + |
| 212 | +@slow |
| 213 | +class RemoteAutoencoderKLFluxSlowTests( |
| 214 | + RemoteAutoencoderKLEncodeSlowTestMixin, |
| 215 | + unittest.TestCase, |
| 216 | +): |
| 217 | + channels = 16 |
| 218 | + endpoint = ENCODE_ENDPOINT_FLUX |
| 219 | + decode_endpoint = DECODE_ENDPOINT_FLUX |
| 220 | + dtype = torch.bfloat16 |
| 221 | + scaling_factor = 0.3611 |
| 222 | + shift_factor = 0.1159 |
0 commit comments