| 
 | 1 | +"""Client module for interacting with the OmniParser server."""  | 
 | 2 | + | 
 | 3 | +import base64  | 
 | 4 | +import fire  | 
 | 5 | +import requests  | 
 | 6 | + | 
 | 7 | +from loguru import logger  | 
 | 8 | +from PIL import Image, ImageDraw  | 
 | 9 | + | 
 | 10 | + | 
 | 11 | +def image_to_base64(image_path: str) -> str:  | 
 | 12 | +    """Convert an image file to base64 string.  | 
 | 13 | +
  | 
 | 14 | +    Args:  | 
 | 15 | +        image_path: Path to the image file  | 
 | 16 | +
  | 
 | 17 | +    Returns:  | 
 | 18 | +        str: Base64 encoded string of the image  | 
 | 19 | +    """  | 
 | 20 | +    with open(image_path, "rb") as image_file:  | 
 | 21 | +        return base64.b64encode(image_file.read()).decode("utf-8")  | 
 | 22 | + | 
 | 23 | + | 
 | 24 | +def plot_results(  | 
 | 25 | +    original_image_path: str,  | 
 | 26 | +    som_image_base64: str,  | 
 | 27 | +    parsed_content_list: list[dict[str, list[float]]],  | 
 | 28 | +) -> None:  | 
 | 29 | +    """Plot parsing results on the original image.  | 
 | 30 | +
  | 
 | 31 | +    Args:  | 
 | 32 | +        original_image_path: Path to the original image  | 
 | 33 | +        som_image_base64: Base64 encoded SOM image  | 
 | 34 | +        parsed_content_list: List of parsed content with bounding boxes  | 
 | 35 | +    """  | 
 | 36 | +    # Open original image  | 
 | 37 | +    image = Image.open(original_image_path)  | 
 | 38 | +    width, height = image.size  | 
 | 39 | + | 
 | 40 | +    # Create drawable image  | 
 | 41 | +    draw = ImageDraw.Draw(image)  | 
 | 42 | + | 
 | 43 | +    # Draw bounding boxes and labels  | 
 | 44 | +    for item in parsed_content_list:  | 
 | 45 | +        # Get normalized coordinates and convert to pixel coordinates  | 
 | 46 | +        x1, y1, x2, y2 = item["bbox"]  | 
 | 47 | +        x1 = int(x1 * width)  | 
 | 48 | +        y1 = int(y1 * height)  | 
 | 49 | +        x2 = int(x2 * width)  | 
 | 50 | +        y2 = int(y2 * height)  | 
 | 51 | + | 
 | 52 | +        label = item["content"]  | 
 | 53 | + | 
 | 54 | +        # Draw rectangle  | 
 | 55 | +        draw.rectangle([(x1, y1), (x2, y2)], outline="red", width=2)  | 
 | 56 | + | 
 | 57 | +        # Draw label background  | 
 | 58 | +        text_bbox = draw.textbbox((x1, y1), label)  | 
 | 59 | +        draw.rectangle(  | 
 | 60 | +            [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],  | 
 | 61 | +            fill="white",  | 
 | 62 | +        )  | 
 | 63 | + | 
 | 64 | +        # Draw label text  | 
 | 65 | +        draw.text((x1, y1), label, fill="red")  | 
 | 66 | + | 
 | 67 | +    # Show image  | 
 | 68 | +    image.show()  | 
 | 69 | + | 
 | 70 | + | 
 | 71 | +def parse_image(  | 
 | 72 | +    image_path: str,  | 
 | 73 | +    server_url: str,  | 
 | 74 | +) -> None:  | 
 | 75 | +    """Parse an image using the OmniParser server.  | 
 | 76 | +
  | 
 | 77 | +    Args:  | 
 | 78 | +        image_path: Path to the image file  | 
 | 79 | +        server_url: URL of the OmniParser server  | 
 | 80 | +    """  | 
 | 81 | +    # Remove trailing slash from server_url if present  | 
 | 82 | +    server_url = server_url.rstrip("/")  | 
 | 83 | + | 
 | 84 | +    # Convert image to base64  | 
 | 85 | +    base64_image = image_to_base64(image_path)  | 
 | 86 | + | 
 | 87 | +    # Prepare request  | 
 | 88 | +    url = f"{server_url}/parse/"  | 
 | 89 | +    payload = {"base64_image": base64_image}  | 
 | 90 | + | 
 | 91 | +    try:  | 
 | 92 | +        # First, check if the server is available  | 
 | 93 | +        probe_url = f"{server_url}/probe/"  | 
 | 94 | +        probe_response = requests.get(probe_url)  | 
 | 95 | +        probe_response.raise_for_status()  | 
 | 96 | +        logger.info("Server is available")  | 
 | 97 | + | 
 | 98 | +        # Make request to API  | 
 | 99 | +        response = requests.post(url, json=payload)  | 
 | 100 | +        response.raise_for_status()  | 
 | 101 | + | 
 | 102 | +        # Parse response  | 
 | 103 | +        result = response.json()  | 
 | 104 | +        som_image_base64 = result["som_image_base64"]  | 
 | 105 | +        parsed_content_list = result["parsed_content_list"]  | 
 | 106 | + | 
 | 107 | +        # Plot results  | 
 | 108 | +        plot_results(image_path, som_image_base64, parsed_content_list)  | 
 | 109 | + | 
 | 110 | +        # Print latency  | 
 | 111 | +        logger.info(f"API Latency: {result['latency']:.2f} seconds")  | 
 | 112 | + | 
 | 113 | +    except requests.exceptions.ConnectionError:  | 
 | 114 | +        logger.error(f"Error: Could not connect to server at {server_url}")  | 
 | 115 | +        logger.error("Please check if the server is running and the URL is correct")  | 
 | 116 | +    except requests.exceptions.RequestException as e:  | 
 | 117 | +        logger.error(f"Error making request to API: {e}")  | 
 | 118 | +    except Exception as e:  | 
 | 119 | +        logger.error(f"Error: {e}")  | 
 | 120 | + | 
 | 121 | + | 
 | 122 | +def main() -> None:  | 
 | 123 | +    """Main entry point for the client application."""  | 
 | 124 | +    fire.Fire(parse_image)  | 
 | 125 | + | 
 | 126 | + | 
 | 127 | +if __name__ == "__main__":  | 
 | 128 | +    main()  | 
0 commit comments