Commit 485d99e

more test
1 parent d16c855 commit 485d99e

2 files changed: +80 additions, -28 deletions


src/diffusers/utils/remote_utils.py

Lines changed: 11 additions & 4 deletions
@@ -193,13 +193,20 @@ def remote_decode(
         processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
             Used with `return_type="pt"`, and `return_type="pil"` for Video models.
         do_scaling (`bool`, default `True`, *optional*):
-            **DEPRECATED**. When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If
+            **DEPRECATED**. **pass `scaling_factor`/`shift_factor` instead.**
+            **still set do_scaling=None/do_scaling=False for no scaling until option is removed**
+            When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If
             `False`, input must be passed with scaling applied.
         scaling_factor (`float`, *optional*):
-            Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`. SD v1: 0.18215 SD XL:
-            0.13025 Flux: 0.3611
+            Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`.
+            - SD v1: 0.18215
+            - SD XL: 0.13025
+            - Flux: 0.3611
+            If `None`, input must be passed with scaling applied.
         shift_factor (`float`, *optional*):
-            Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`. Flux: 0.1159
+            Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
+            - Flux: 0.1159
+            If `None`, input must be passed with scaling applied.
         output_type (`"mp4"` or `"pil"` or `"pt", default `"pil"):
             **Endpoint** output type. Subject to change. Report feedback on preferred type.
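
With `do_scaling` deprecated, callers are expected to pass the factors explicitly. A minimal sketch of such a call, assuming a deployed VAE decode endpoint and SD v1-style latents; the endpoint URL, latent shape, and output filename below are placeholders, not values from this commit:

```python
import torch
from diffusers.utils.remote_utils import remote_decode

# Placeholder latent shape for an SD v1-style VAE; substitute real latents.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

image = remote_decode(
    endpoint="https://<your-endpoint>.endpoints.huggingface.cloud/",  # placeholder URL
    tensor=latents,
    scaling_factor=0.18215,  # SD v1 value listed in the docstring above
    shift_factor=None,
    output_type="pil",
)
image.save("decoded.png")
```
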

tests/remote/test_remote_decode.py

Lines changed: 69 additions & 24 deletions
@@ -16,13 +16,15 @@
 import unittest
 from typing import Tuple, Union

+import numpy as np
 import PIL.Image
 import torch

 from diffusers.image_processor import VaeImageProcessor
 from diffusers.utils.remote_utils import remote_decode
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    torch_all_close,
     torch_device,
 )
 from diffusers.video_processor import VideoProcessor
@@ -39,11 +41,20 @@ class RemoteAutoencoderKLMixin:
     scaling_factor: float = None
     shift_factor: float = None
     processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
+    output_pil_slice: torch.Tensor = None
+    output_pt_slice: torch.Tensor = None
+    partial_postprocess_return_pt_slice: torch.Tensor = None
+    return_pt_slice: torch.Tensor = None

     def get_dummy_inputs(self):
         inputs = {
             "endpoint": self.endpoint,
-            "tensor": torch.randn(self.shape, device=torch_device, dtype=self.dtype),
+            "tensor": torch.randn(
+                self.shape,
+                device=torch_device,
+                dtype=self.dtype,
+                generator=torch.Generator(torch_device).manual_seed(13),
+            ),
             "scaling_factor": self.scaling_factor,
             "shift_factor": self.shift_factor,
         }
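
The dummy input now draws the latent through a seeded `torch.Generator`, so the tensor sent to the endpoint is identical on every run; that determinism is what makes the fixed expected slices added below meaningful. A small standalone sketch of the same idea (CPU, float32, hypothetical helper name):

```python
import torch

def seeded_latent(shape, seed=13, device="cpu", dtype=torch.float32):
    # Same pattern as get_dummy_inputs: a Generator seeded with a constant
    # makes torch.randn return the same values on every call.
    generator = torch.Generator(device).manual_seed(seed)
    return torch.randn(shape, device=device, dtype=dtype, generator=generator)

assert torch.equal(seeded_latent((1, 4, 64, 64)), seeded_latent((1, 4, 64, 64)))
```
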
@@ -53,10 +64,16 @@ def test_output_type_pt(self):
         inputs = self.get_dummy_inputs()
         processor = self.processor_cls()
         output = remote_decode(output_type="pt", processor=processor, **inputs)
+        assert isinstance(output, PIL.Image.Image)
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )

+    # output is visually the same, slice is flaky?
     def test_output_type_pil(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pil", **inputs)
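
The new assertions compare a 9-value slice of the decoded image against a stored reference such as `output_pt_slice`. The slice is `np.array(output)[0, -3:, -3:]`: the top row's last three pixels across the three RGB channels. A sketch of a helper (hypothetical name) that could be used to print such a slice when regenerating reference values for a new endpoint:

```python
import numpy as np
import PIL.Image
import torch

def reference_slice(image: PIL.Image.Image) -> torch.Tensor:
    # np.array(image) is an HxWxC uint8 array; [0, -3:, -3:] takes row 0,
    # the last three columns, and the last three (i.e. all RGB) channels,
    # flattened to the 9 values stored as output_pt_slice in the test classes.
    return torch.from_numpy(np.array(image)[0, -3:, -3:].flatten())
```
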
@@ -71,13 +88,21 @@ def test_output_type_pil_image_format(self):
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
         self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )

     def test_output_type_pt_partial_postprocess(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
         self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
         self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
         self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.width}")
+        output_slice = torch.from_numpy(np.array(output)[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1e-2), f"{output_slice}"
+        )

     def test_output_type_pt_return_type_pt(self):
         inputs = self.get_dummy_inputs()
@@ -89,6 +114,11 @@ def test_output_type_pt_return_type_pt(self):
         self.assertEqual(
             output.shape[3], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[3]}"
         )
+        output_slice = output[0, 0, -3:, -3:].flatten()
+        self.assertTrue(
+            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
+            f"{output_slice}",
+        )

     def test_output_type_pt_partial_postprocess_return_type_pt(self):
         inputs = self.get_dummy_inputs()
@@ -100,6 +130,11 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
         self.assertEqual(
             output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[2]}"
         )
+        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
+        self.assertTrue(
+            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
+            f"{output_slice}",
+        )

     def test_do_scaling_deprecation(self):
         inputs = self.get_dummy_inputs()
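
The two `return_type="pt"` tests index the output differently, which matches the two tensor layouts the shape checks imply: without `partial_postprocess` the output appears to be a float NCHW tensor (width at `shape[3]`), so the slice is `output[0, 0, -3:, -3:]`; with `partial_postprocess=True` it appears to be a uint8 NHWC tensor (width at `shape[2]`), so the slice is `output[0, -3:, -3:, 0]`. A small sketch of the indexing on dummy tensors (shapes are assumed for illustration, not taken from the endpoint):

```python
import torch

nchw = torch.randn(1, 3, 512, 512)                       # assumed layout without partial_postprocess
nhwc = torch.zeros(1, 512, 512, 3, dtype=torch.uint8)    # assumed layout with partial_postprocess=True

slice_nchw = nchw[0, 0, -3:, -3:].flatten()   # 9 floats: channel 0, bottom-right 3x3 corner
slice_nhwc = nhwc[0, -3:, -3:, 0].flatten()   # 9 uint8 values: bottom-right 3x3 corner, channel 0
assert slice_nchw.shape == slice_nhwc.shape == (9,)
```
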
@@ -133,6 +168,7 @@ def test_output_tensor_type_base64_deprecation(self):
             str(warning.warnings[0].message),
         )

+
 class RemoteAutoencoderKLSDv1Tests(
     RemoteAutoencoderKLMixin,
     unittest.TestCase,
@@ -152,6 +188,9 @@ class RemoteAutoencoderKLSDv1Tests(
     scaling_factor = 0.18215
     shift_factor = None
     processor_cls = VaeImageProcessor
+    output_pt_slice = torch.tensor([31, 15, 11, 55, 30, 21, 66, 42, 30], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor([100, 130, 99, 133, 106, 112, 97, 100, 121], dtype=torch.uint8)
+    return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543])


 class RemoteAutoencoderKLSDXLTests(
@@ -168,30 +207,36 @@ class RemoteAutoencoderKLSDXLTests(
         1024,
         1024,
     )
-    endpoint = ""
+    endpoint = "https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/"
     dtype = torch.float16
     scaling_factor = 0.13025
     shift_factor = None
     processor_cls = VaeImageProcessor
-
-
-class RemoteAutoencoderKLFluxTests(
-    RemoteAutoencoderKLMixin,
-    unittest.TestCase,
-):
-    # TODO: packed
-    shape = (
-        1,
-        16,
-        128,
-        128,
-    )
-    out_hw = (
-        1024,
-        1024,
-    )
-    endpoint = ""
-    dtype = torch.bfloat16
-    scaling_factor = 0.3611
-    shift_factor = 0.1159
-    processor_cls = VaeImageProcessor
+    output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+    return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
+
+
+# class RemoteAutoencoderKLFluxTests(
+# RemoteAutoencoderKLMixin,
+# unittest.TestCase,
+# ):
+# # TODO: packed
+# shape = (
+# 1,
+# 16,
+# 128,
+# 128,
+# )
+# out_hw = (
+# 1024,
+# 1024,
+# )
+# endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+# dtype = torch.bfloat16
+# scaling_factor = 0.3611
+# shift_factor = 0.1159
+# processor_cls = VaeImageProcessor
+# output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+# partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+# return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
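
The active test classes above differ only in class attributes (latent shape, endpoint, dtype, scaling/shift factors, expected slices), so covering another remote VAE amounts to adding one more subclass of the mixin. A sketch of the pattern; the class name, endpoint URL, shape, and slice values below are placeholders, not real reference data:

```python
class RemoteAutoencoderKLMyModelTests(  # hypothetical name
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    shape = (1, 4, 64, 64)       # placeholder latent shape
    out_hw = (512, 512)          # placeholder decoded image size
    endpoint = "https://<your-endpoint>.endpoints.huggingface.cloud/"  # placeholder
    dtype = torch.float16
    scaling_factor = 0.18215     # use the model's VAE config value
    shift_factor = None
    processor_cls = VaeImageProcessor
    # Placeholder references: regenerate from a known-good decode before relying on them.
    output_pt_slice = torch.tensor([0] * 9, dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor([0] * 9, dtype=torch.uint8)
    return_pt_slice = torch.tensor([0.0] * 9)
```
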
