import base64
import io
import json
from typing import List, Literal, Optional, Union, cast

import requests
import torch
from PIL import Image
from safetensors.torch import _tobytes

from ..image_processor import VaeImageProcessor
from ..video_processor import VideoProcessor

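# Maps the dtype string returned by the endpoint to the torch dtype used to rebuild the output tensor.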
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
}

def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
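    """
    Send a latent `tensor` to a remote VAE decode `endpoint` and return the decoded result.

    The tensor is serialized to raw bytes and posted either base64-encoded in a JSON body or as a raw binary body,
    per `input_tensor_type`, with its shape and dtype passed as query parameters. Depending on `output_type` and
    whether a `processor` is supplied, the result is a `PIL.Image.Image`, a list of images, a `torch.Tensor`, or
    raw MP4 bytes. `height` and `width` are required when `tensor` holds packed (3-dimensional) latents.
    """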
    if tensor.ndim == 3 and height is None and width is None:
        raise ValueError("`height` and `width` are required for packed latents.")
    if output_type == "pt" and partial_postprocess is False and processor is None:
        raise ValueError("`processor` is required with `output_type='pt'` and `partial_postprocess=False`.")
    headers = {}
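    # Query parameters tell the endpoint how to interpret the raw bytes and how far to post-process the result.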
    parameters = {
        "do_scaling": do_scaling,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width
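    # Serialize the tensor to raw bytes; the server reconstructs it from the shape/dtype parameters above.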
    tensor_data = _tobytes(tensor, "tensor")
    if input_tensor_type == "base64":
        headers["Content-Type"] = "tensor/base64"
    elif input_tensor_type == "binary":
        headers["Content-Type"] = "tensor/binary"
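    # The Accept header selects the response format: an encoded image, a serialized tensor, or MP4 bytes.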
    if output_type == "pil" and image_format == "jpg" and processor is None:
        headers["Accept"] = "image/jpeg"
    elif output_type == "pil" and image_format == "png" and processor is None:
        headers["Accept"] = "image/png"
    elif (output_tensor_type == "base64" and output_type == "pt") or (
        output_tensor_type == "base64" and output_type == "pil" and processor is not None
    ):
        headers["Accept"] = "tensor/base64"
    elif (output_tensor_type == "binary" and output_type == "pt") or (
        output_tensor_type == "binary" and output_type == "pil" and processor is not None
    ):
        headers["Accept"] = "tensor/binary"
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"
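    # Send the latents either base64-encoded inside a JSON body or as the raw request body.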
    if input_tensor_type == "base64":
        kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
    elif input_tensor_type == "binary":
        kwargs = {"data": tensor_data}
    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
    if not response.ok:
        raise RuntimeError(response.json())
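    # For tensor responses, recover the payload plus its shape/dtype metadata (from the JSON body for base64,
    # from the response headers for binary) and rebuild the torch tensor.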
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        if output_tensor_type == "base64":
            content = response.json()
            output_tensor = base64.b64decode(content["inputs"])
            parameters = content["parameters"]
            shape = parameters["shape"]
            dtype = parameters["dtype"]
        elif output_tensor_type == "binary":
            output_tensor = response.content
            parameters = response.headers
            shape = json.loads(parameters["shape"])
            dtype = parameters["dtype"]
        torch_dtype = DTYPE_MAP[dtype]
        output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
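    # Build the final output: fully post-processed images, a raw or processor-decoded tensor, or MP4 bytes.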
    if output_type == "pt":
        if partial_postprocess:
            output = [Image.fromarray(image.numpy()) for image in output_tensor]
            if len(output) == 1:
                output = output[0]
        else:
            if processor is None:
                output = output_tensor
            else:
                if isinstance(processor, VideoProcessor):
                    output = cast(
                        List[Image.Image],
                        processor.postprocess_video(output_tensor, output_type="pil")[0],
                    )
                else:
                    output = cast(
                        Image.Image,
                        processor.postprocess(output_tensor, output_type="pil")[0],
                    )
    elif output_type == "pil" and processor is None:
        output = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif output_type == "pil" and processor is not None:
        output = [
            Image.fromarray(image)
            for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
        ]
    elif output_type == "mp4":
        output = response.content
    return output