Skip to content

Commit c2a2daf

Browse files
committed
init test_remote_decode
1 parent 05b39ab commit c2a2daf

File tree

2 files changed: +99 −0 lines changed

tests/remote/__init__.py

Whitespace-only changes.

tests/remote/test_remote_decode.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
from typing import Tuple, Union
18+
19+
import PIL.Image
20+
import torch
21+
22+
from diffusers.image_processor import VaeImageProcessor
23+
from diffusers.utils.remote_utils import remote_decode
24+
from diffusers.utils.testing_utils import (
25+
enable_full_determinism,
26+
torch_device,
27+
)
28+
from diffusers.video_processor import VideoProcessor
29+
30+
31+
# Force deterministic torch/cuDNN behavior so decoded outputs are reproducible
# across test runs (helper imported from diffusers.utils.testing_utils).
enable_full_determinism()
32+
33+
34+
class RemoteAutoencoderKLMixin:
    """Shared test cases for a remote VAE `decode` endpoint.

    Subclasses supply the endpoint URL and tensor geometry via the class
    attributes below; each test posts a random latent tensor to the remote
    endpoint through ``remote_decode`` and checks the decoded image's type
    and dimensions.
    """

    # Latent tensor shape posted to the endpoint, e.g. (1, 4, 64, 64).
    shape: Tuple[int, ...] = None
    # Expected (height, width) of the decoded output image.
    out_hw: Tuple[int, int] = None
    # URL of the remote decode endpoint.
    endpoint: str = None
    # dtype of the random latent tensor sent to the endpoint.
    dtype: torch.dtype = None
    # Model-specific VAE scaling factor (not used by these tests directly).
    scale_factor: float = None
    # Model-specific VAE shift factor, or None when the model has none.
    shift_factor: float = None
    # Processor class used for client-side postprocessing of "pt" outputs.
    processor_cls: Union[VaeImageProcessor, VideoProcessor] = None

    def get_dummy_inputs(self):
        """Return the kwargs shared by every ``remote_decode`` call."""
        inputs = {
            "endpoint": self.endpoint,
            "tensor": torch.randn(self.shape, device=torch_device, dtype=self.dtype),
        }
        return inputs

    def _assert_image_output(self, output):
        """Assert *output* is a PIL image with the expected height and width."""
        self.assertTrue(isinstance(output, PIL.Image.Image), f"Expected `PIL.Image.Image` output, got {type(output)}")
        self.assertEqual(output.height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.height}")
        # Bug fix: the original width message interpolated out_hw[0] and
        # output.height (copy-paste from the height assertion above).
        self.assertEqual(output.width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.width}")

    def test_output_type_pt(self):
        # "pt" output plus a processor is postprocessed client-side into a PIL image.
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pt", processor=processor, **inputs)
        self._assert_image_output(output)

    def test_output_type_pil(self):
        # "pil" output is decoded server-side and returned as an image directly.
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pil", **inputs)
        self._assert_image_output(output)

    def test_output_type_pil_image_format(self):
        # Explicitly request PNG encoding from the endpoint.
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pil", image_format="png", **inputs)
        self._assert_image_output(output)
        # NOTE(review): PIL normally reports formats uppercase ("PNG") — confirm
        # remote_decode preserves the lowercase name before relying on this.
        self.assertEqual(output.format, "png", f"Expected image format `png`, got {output.format}")

    def test_output_type_pt_partial_postprocess(self):
        # partial_postprocess=True moves postprocessing server-side; the
        # response is still materialized as a PIL image.
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
        self._assert_image_output(output)
79+
80+
81+
class RemoteAutoencoderKLSDv1Tests(
    RemoteAutoencoderKLMixin,
    unittest.TestCase,
):
    """Run the remote-decode test suite against a Stable Diffusion v1 VAE endpoint."""

    # A single 4-channel 64x64 latent decodes to a 512x512 image.
    shape = (1, 4, 64, 64)
    out_hw = (512, 512)
    endpoint = "https://bz0b3zkoojf30bhx.us-east-1.aws.endpoints.huggingface.cloud/"
    dtype = torch.float16
    scale_factor = 0.18215
    shift_factor = None
    processor_cls = VaeImageProcessor

0 commit comments

Comments
 (0)