|  | 
| 13 | 13 | # See the License for the specific language governing permissions and | 
| 14 | 14 | # limitations under the License. | 
| 15 | 15 | 
 | 
|  | 16 | +import unittest | 
| 16 | 17 | 
 | 
| 17 | 18 | import PIL.Image | 
| 18 | 19 | import torch | 
| 19 | 20 | 
 | 
| 20 | 21 | from diffusers.utils import load_image | 
| 21 | 22 | from diffusers.utils.remote_utils import ( | 
|  | 23 | +    DECODE_ENDPOINT_FLUX, | 
|  | 24 | +    DECODE_ENDPOINT_SD_V1, | 
|  | 25 | +    DECODE_ENDPOINT_SD_XL, | 
|  | 26 | +    ENCODE_ENDPOINT_FLUX, | 
|  | 27 | +    ENCODE_ENDPOINT_SD_V1, | 
|  | 28 | +    ENCODE_ENDPOINT_SD_XL, | 
| 22 | 29 |     remote_decode, | 
| 23 | 30 |     remote_encode, | 
| 24 | 31 | ) | 
| 25 | 32 | from diffusers.utils.testing_utils import ( | 
| 26 | 33 |     enable_full_determinism, | 
|  | 34 | +    slow, | 
| 27 | 35 | ) | 
| 28 | 36 | 
 | 
| 29 | 37 | 
 | 
| @@ -71,40 +79,40 @@ def test_image_input(self): | 
| 71 | 79 |         # TODO: how to test this? encode->decode is lossy. expected slice of encoded latent? | 
| 72 | 80 | 
 | 
| 73 | 81 | 
 | 
| 74 |  | -# class RemoteAutoencoderKLSDv1Tests( | 
| 75 |  | -#     RemoteAutoencoderKLEncodeMixin, | 
| 76 |  | -#     unittest.TestCase, | 
| 77 |  | -# ): | 
| 78 |  | -#     channels = 4 | 
| 79 |  | -#     endpoint = ENCODE_ENDPOINT_SD_V1 | 
| 80 |  | -#     decode_endpoint = DECODE_ENDPOINT_SD_V1 | 
| 81 |  | -#     dtype = torch.float16 | 
| 82 |  | -#     scaling_factor = 0.18215 | 
| 83 |  | -#     shift_factor = None | 
| 84 |  | - | 
| 85 |  | - | 
| 86 |  | -# class RemoteAutoencoderKLSDXLTests( | 
| 87 |  | -#     RemoteAutoencoderKLEncodeMixin, | 
| 88 |  | -#     unittest.TestCase, | 
| 89 |  | -# ): | 
| 90 |  | -#     channels = 4 | 
| 91 |  | -#     endpoint = ENCODE_ENDPOINT_SD_XL | 
| 92 |  | -#     decode_endpoint = DECODE_ENDPOINT_SD_XL | 
| 93 |  | -#     dtype = torch.float16 | 
| 94 |  | -#     scaling_factor = 0.13025 | 
| 95 |  | -#     shift_factor = None | 
| 96 |  | - | 
| 97 |  | - | 
| 98 |  | -# class RemoteAutoencoderKLFluxTests( | 
| 99 |  | -#     RemoteAutoencoderKLEncodeMixin, | 
| 100 |  | -#     unittest.TestCase, | 
| 101 |  | -# ): | 
| 102 |  | -#     channels = 16 | 
| 103 |  | -#     endpoint = DECODE_ENDPOINT_FLUX | 
| 104 |  | -#     decode_endpoint = ENCODE_ENDPOINT_FLUX | 
| 105 |  | -#     dtype = torch.bfloat16 | 
| 106 |  | -#     scaling_factor = 0.3611 | 
| 107 |  | -#     shift_factor = 0.1159 | 
|  | 82 | +class RemoteAutoencoderKLSDv1Tests( | 
|  | 83 | +    RemoteAutoencoderKLEncodeMixin, | 
|  | 84 | +    unittest.TestCase, | 
|  | 85 | +): | 
|  | 86 | +    channels = 4 | 
|  | 87 | +    endpoint = ENCODE_ENDPOINT_SD_V1 | 
|  | 88 | +    decode_endpoint = DECODE_ENDPOINT_SD_V1 | 
|  | 89 | +    dtype = torch.float16 | 
|  | 90 | +    scaling_factor = 0.18215 | 
|  | 91 | +    shift_factor = None | 
|  | 92 | + | 
|  | 93 | + | 
|  | 94 | +class RemoteAutoencoderKLSDXLTests( | 
|  | 95 | +    RemoteAutoencoderKLEncodeMixin, | 
|  | 96 | +    unittest.TestCase, | 
|  | 97 | +): | 
|  | 98 | +    channels = 4 | 
|  | 99 | +    endpoint = ENCODE_ENDPOINT_SD_XL | 
|  | 100 | +    decode_endpoint = DECODE_ENDPOINT_SD_XL | 
|  | 101 | +    dtype = torch.float16 | 
|  | 102 | +    scaling_factor = 0.13025 | 
|  | 103 | +    shift_factor = None | 
|  | 104 | + | 
|  | 105 | + | 
|  | 106 | +class RemoteAutoencoderKLFluxTests( | 
|  | 107 | +    RemoteAutoencoderKLEncodeMixin, | 
|  | 108 | +    unittest.TestCase, | 
|  | 109 | +): | 
|  | 110 | +    channels = 16 | 
|  | 111 | +    endpoint = ENCODE_ENDPOINT_FLUX | 
|  | 112 | +    decode_endpoint = DECODE_ENDPOINT_FLUX | 
|  | 113 | +    dtype = torch.bfloat16 | 
|  | 114 | +    scaling_factor = 0.3611 | 
|  | 115 | +    shift_factor = 0.1159 | 
| 108 | 116 | 
 | 
| 109 | 117 | 
 | 
| 110 | 118 | class RemoteAutoencoderKLEncodeSlowTestMixin: | 
| @@ -177,38 +185,38 @@ def test_multi_res(self): | 
| 177 | 185 |                 decoded.save(f"test_multi_res_{height}_{width}.png") | 
| 178 | 186 | 
 | 
| 179 | 187 | 
 | 
| 180 |  | -# @slow | 
| 181 |  | -# class RemoteAutoencoderKLSDv1SlowTests( | 
| 182 |  | -#     RemoteAutoencoderKLEncodeSlowTestMixin, | 
| 183 |  | -#     unittest.TestCase, | 
| 184 |  | -# ): | 
| 185 |  | -#     endpoint = ENCODE_ENDPOINT_SD_V1 | 
| 186 |  | -#     decode_endpoint = DECODE_ENDPOINT_SD_V1 | 
| 187 |  | -#     dtype = torch.float16 | 
| 188 |  | -#     scaling_factor = 0.18215 | 
| 189 |  | -#     shift_factor = None | 
| 190 |  | - | 
| 191 |  | - | 
| 192 |  | -# @slow | 
| 193 |  | -# class RemoteAutoencoderKLSDXLSlowTests( | 
| 194 |  | -#     RemoteAutoencoderKLEncodeSlowTestMixin, | 
| 195 |  | -#     unittest.TestCase, | 
| 196 |  | -# ): | 
| 197 |  | -#     endpoint = ENCODE_ENDPOINT_SD_XL | 
| 198 |  | -#     decode_endpoint = DECODE_ENDPOINT_SD_XL | 
| 199 |  | -#     dtype = torch.float16 | 
| 200 |  | -#     scaling_factor = 0.13025 | 
| 201 |  | -#     shift_factor = None | 
| 202 |  | - | 
| 203 |  | - | 
| 204 |  | -# @slow | 
| 205 |  | -# class RemoteAutoencoderKLFluxSlowTests( | 
| 206 |  | -#     RemoteAutoencoderKLEncodeSlowTestMixin, | 
| 207 |  | -#     unittest.TestCase, | 
| 208 |  | -# ): | 
| 209 |  | -#     channels = 16 | 
| 210 |  | -#     endpoint = ENCODE_ENDPOINT_FLUX | 
| 211 |  | -#     decode_endpoint = DECODE_ENDPOINT_FLUX | 
| 212 |  | -#     dtype = torch.bfloat16 | 
| 213 |  | -#     scaling_factor = 0.3611 | 
| 214 |  | -#     shift_factor = 0.1159 | 
|  | 188 | +@slow | 
|  | 189 | +class RemoteAutoencoderKLSDv1SlowTests( | 
|  | 190 | +    RemoteAutoencoderKLEncodeSlowTestMixin, | 
|  | 191 | +    unittest.TestCase, | 
|  | 192 | +): | 
|  | 193 | +    endpoint = ENCODE_ENDPOINT_SD_V1 | 
|  | 194 | +    decode_endpoint = DECODE_ENDPOINT_SD_V1 | 
|  | 195 | +    dtype = torch.float16 | 
|  | 196 | +    scaling_factor = 0.18215 | 
|  | 197 | +    shift_factor = None | 
|  | 198 | + | 
|  | 199 | + | 
|  | 200 | +@slow | 
|  | 201 | +class RemoteAutoencoderKLSDXLSlowTests( | 
|  | 202 | +    RemoteAutoencoderKLEncodeSlowTestMixin, | 
|  | 203 | +    unittest.TestCase, | 
|  | 204 | +): | 
|  | 205 | +    endpoint = ENCODE_ENDPOINT_SD_XL | 
|  | 206 | +    decode_endpoint = DECODE_ENDPOINT_SD_XL | 
|  | 207 | +    dtype = torch.float16 | 
|  | 208 | +    scaling_factor = 0.13025 | 
|  | 209 | +    shift_factor = None | 
|  | 210 | + | 
|  | 211 | + | 
|  | 212 | +@slow | 
|  | 213 | +class RemoteAutoencoderKLFluxSlowTests( | 
|  | 214 | +    RemoteAutoencoderKLEncodeSlowTestMixin, | 
|  | 215 | +    unittest.TestCase, | 
|  | 216 | +): | 
|  | 217 | +    channels = 16 | 
|  | 218 | +    endpoint = ENCODE_ENDPOINT_FLUX | 
|  | 219 | +    decode_endpoint = DECODE_ENDPOINT_FLUX | 
|  | 220 | +    dtype = torch.bfloat16 | 
|  | 221 | +    scaling_factor = 0.3611 | 
|  | 222 | +    shift_factor = 0.1159 | 
0 commit comments