Skip to content

Commit e70bdb2

Browse files
committed
Add endpoints, tests
1 parent 140e0c2 commit e70bdb2

File tree

2 files changed

+80
-72
lines changed

2 files changed

+80
-72
lines changed

src/diffusers/utils/remote_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@
4949
DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
5050

5151

52-
ENCODE_ENDPOINT_SD_V1 = ""
53-
ENCODE_ENDPOINT_SD_XL = ""
54-
ENCODE_ENDPOINT_FLUX = ""
52+
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
53+
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
54+
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
5555

5656

5757
def detect_image_type(data: bytes) -> str:

tests/remote/test_remote_encode.py

Lines changed: 77 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,25 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import unittest
1617

1718
import PIL.Image
1819
import torch
1920

2021
from diffusers.utils import load_image
2122
from diffusers.utils.remote_utils import (
23+
DECODE_ENDPOINT_FLUX,
24+
DECODE_ENDPOINT_SD_V1,
25+
DECODE_ENDPOINT_SD_XL,
26+
ENCODE_ENDPOINT_FLUX,
27+
ENCODE_ENDPOINT_SD_V1,
28+
ENCODE_ENDPOINT_SD_XL,
2229
remote_decode,
2330
remote_encode,
2431
)
2532
from diffusers.utils.testing_utils import (
2633
enable_full_determinism,
34+
slow,
2735
)
2836

2937

@@ -71,40 +79,40 @@ def test_image_input(self):
7179
# TODO: how to test this? encode->decode is lossy. expected slice of encoded latent?
7280

7381

74-
# class RemoteAutoencoderKLSDv1Tests(
75-
# RemoteAutoencoderKLEncodeMixin,
76-
# unittest.TestCase,
77-
# ):
78-
# channels = 4
79-
# endpoint = ENCODE_ENDPOINT_SD_V1
80-
# decode_endpoint = DECODE_ENDPOINT_SD_V1
81-
# dtype = torch.float16
82-
# scaling_factor = 0.18215
83-
# shift_factor = None
84-
85-
86-
# class RemoteAutoencoderKLSDXLTests(
87-
# RemoteAutoencoderKLEncodeMixin,
88-
# unittest.TestCase,
89-
# ):
90-
# channels = 4
91-
# endpoint = ENCODE_ENDPOINT_SD_XL
92-
# decode_endpoint = DECODE_ENDPOINT_SD_XL
93-
# dtype = torch.float16
94-
# scaling_factor = 0.13025
95-
# shift_factor = None
96-
97-
98-
# class RemoteAutoencoderKLFluxTests(
99-
# RemoteAutoencoderKLEncodeMixin,
100-
# unittest.TestCase,
101-
# ):
102-
# channels = 16
103-
# endpoint = DECODE_ENDPOINT_FLUX
104-
# decode_endpoint = ENCODE_ENDPOINT_FLUX
105-
# dtype = torch.bfloat16
106-
# scaling_factor = 0.3611
107-
# shift_factor = 0.1159
82+
class RemoteAutoencoderKLSDv1Tests(
83+
RemoteAutoencoderKLEncodeMixin,
84+
unittest.TestCase,
85+
):
86+
channels = 4
87+
endpoint = ENCODE_ENDPOINT_SD_V1
88+
decode_endpoint = DECODE_ENDPOINT_SD_V1
89+
dtype = torch.float16
90+
scaling_factor = 0.18215
91+
shift_factor = None
92+
93+
94+
class RemoteAutoencoderKLSDXLTests(
95+
RemoteAutoencoderKLEncodeMixin,
96+
unittest.TestCase,
97+
):
98+
channels = 4
99+
endpoint = ENCODE_ENDPOINT_SD_XL
100+
decode_endpoint = DECODE_ENDPOINT_SD_XL
101+
dtype = torch.float16
102+
scaling_factor = 0.13025
103+
shift_factor = None
104+
105+
106+
class RemoteAutoencoderKLFluxTests(
107+
RemoteAutoencoderKLEncodeMixin,
108+
unittest.TestCase,
109+
):
110+
channels = 16
111+
endpoint = ENCODE_ENDPOINT_FLUX
112+
decode_endpoint = DECODE_ENDPOINT_FLUX
113+
dtype = torch.bfloat16
114+
scaling_factor = 0.3611
115+
shift_factor = 0.1159
108116

109117

110118
class RemoteAutoencoderKLEncodeSlowTestMixin:
@@ -177,38 +185,38 @@ def test_multi_res(self):
177185
decoded.save(f"test_multi_res_{height}_{width}.png")
178186

179187

180-
# @slow
181-
# class RemoteAutoencoderKLSDv1SlowTests(
182-
# RemoteAutoencoderKLEncodeSlowTestMixin,
183-
# unittest.TestCase,
184-
# ):
185-
# endpoint = ENCODE_ENDPOINT_SD_V1
186-
# decode_endpoint = DECODE_ENDPOINT_SD_V1
187-
# dtype = torch.float16
188-
# scaling_factor = 0.18215
189-
# shift_factor = None
190-
191-
192-
# @slow
193-
# class RemoteAutoencoderKLSDXLSlowTests(
194-
# RemoteAutoencoderKLEncodeSlowTestMixin,
195-
# unittest.TestCase,
196-
# ):
197-
# endpoint = ENCODE_ENDPOINT_SD_XL
198-
# decode_endpoint = DECODE_ENDPOINT_SD_XL
199-
# dtype = torch.float16
200-
# scaling_factor = 0.13025
201-
# shift_factor = None
202-
203-
204-
# @slow
205-
# class RemoteAutoencoderKLFluxSlowTests(
206-
# RemoteAutoencoderKLEncodeSlowTestMixin,
207-
# unittest.TestCase,
208-
# ):
209-
# channels = 16
210-
# endpoint = ENCODE_ENDPOINT_FLUX
211-
# decode_endpoint = DECODE_ENDPOINT_FLUX
212-
# dtype = torch.bfloat16
213-
# scaling_factor = 0.3611
214-
# shift_factor = 0.1159
188+
@slow
189+
class RemoteAutoencoderKLSDv1SlowTests(
190+
RemoteAutoencoderKLEncodeSlowTestMixin,
191+
unittest.TestCase,
192+
):
193+
endpoint = ENCODE_ENDPOINT_SD_V1
194+
decode_endpoint = DECODE_ENDPOINT_SD_V1
195+
dtype = torch.float16
196+
scaling_factor = 0.18215
197+
shift_factor = None
198+
199+
200+
@slow
201+
class RemoteAutoencoderKLSDXLSlowTests(
202+
RemoteAutoencoderKLEncodeSlowTestMixin,
203+
unittest.TestCase,
204+
):
205+
endpoint = ENCODE_ENDPOINT_SD_XL
206+
decode_endpoint = DECODE_ENDPOINT_SD_XL
207+
dtype = torch.float16
208+
scaling_factor = 0.13025
209+
shift_factor = None
210+
211+
212+
@slow
213+
class RemoteAutoencoderKLFluxSlowTests(
214+
RemoteAutoencoderKLEncodeSlowTestMixin,
215+
unittest.TestCase,
216+
):
217+
channels = 16
218+
endpoint = ENCODE_ENDPOINT_FLUX
219+
decode_endpoint = DECODE_ENDPOINT_FLUX
220+
dtype = torch.bfloat16
221+
scaling_factor = 0.3611
222+
shift_factor = 0.1159

0 commit comments

Comments
 (0)