Skip to content

Commit 54b8b47

Browse files
committed
add client.predict and documentation
1 parent 9cce7d7 commit 54b8b47

File tree

2 files changed

+47
-33
lines changed

2 files changed

+47
-33
lines changed

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
# - Git LFS for large file support.
77
# - Required libraries: OpenCV, Hugging Face, Gradio, OpenGL.
88
# - Gradio server on port 7861.
9-
9+
#
1010
# 1. Build the image with CUDA support.
1111
# Example:
1212
# ```bash
1313
# sudo nvidia-docker build -t omniparser .
1414
# ```
15-
15+
#
1616
# 2. Run the Docker container with GPU access and port mapping for Gradio.
1717
# Example:
1818
# ```bash

client.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
2-
This module provides a command-line interface to interact with the OmniParser Gradio server.
2+
This module provides a command-line interface and programmatic API to interact with the OmniParser Gradio server.
33
4-
Usage:
4+
Command-line usage:
55
python client.py "http://<server_ip>:7861" "path/to/image.jpg"
66
77
View results:
@@ -11,6 +11,10 @@
1111
Windows: start output_image_<timestamp>.png
1212
Linux: xdg-open output_image_<timestamp>.png
1313
14+
Programmatic usage:
15+
from omniparse.client import predict
16+
result = predict("http://<server_ip>:7861", "path/to/image.jpg")
17+
1418
Result data format:
1519
{
1620
"label_coordinates": {
@@ -33,30 +37,31 @@
3337
import fire
3438
from gradio_client import Client
3539
from loguru import logger
36-
from PIL import Image
3740
import base64
38-
from io import BytesIO
3941
import os
4042
import shutil
4143
import json
4244
from datetime import datetime
4345

44-
def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_threshold: float = 0.1):
46+
# Define constants for default thresholds
47+
DEFAULT_BOX_THRESHOLD = 0.05
48+
DEFAULT_IOU_THRESHOLD = 0.1
49+
50+
def predict(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD):
4551
"""
4652
Makes a prediction using the OmniParser Gradio client with the provided server URL and image.
47-
4853
Args:
4954
server_url (str): The URL of the OmniParser Gradio server.
5055
image_path (str): Path to the image file to be processed.
5156
box_threshold (float): Box threshold value (default: 0.05).
5257
iou_threshold (float): IOU threshold value (default: 0.1).
58+
Returns:
59+
dict: Parsed result data containing label coordinates and parsed content list.
5360
"""
5461
client = Client(server_url)
55-
56-
# Generate a timestamp for unique file naming
57-
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
5862

5963
# Load and encode the image
64+
image_path = os.path.expanduser(image_path)
6065
with open(image_path, "rb") as image_file:
6166
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
6267

@@ -72,47 +77,56 @@ def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_t
7277
}
7378

7479
# Make the prediction
75-
try:
76-
result = client.predict(
77-
image_input, # image input as dictionary
78-
box_threshold, # box_threshold
79-
iou_threshold, # iou_threshold
80-
api_name="/process"
81-
)
80+
result = client.predict(
81+
image_input,
82+
box_threshold,
83+
iou_threshold,
84+
api_name="/process"
85+
)
8286

83-
# Process and log the results
84-
output_image, result_json = result
85-
86-
logger.info("Prediction completed successfully")
87+
# Process and return the result
88+
output_image, result_json = result
89+
result_data = json.loads(result_json)
8790

88-
# Parse the JSON string into a Python object
89-
result_data = json.loads(result_json)
91+
return {"output_image": output_image, "result_data": result_data}
9092

91-
# Extract label_coordinates and parsed_content_list
92-
label_coordinates = result_data['label_coordinates']
93-
parsed_content_list = result_data['parsed_content_list']
9493

95-
logger.info(f"{label_coordinates=}")
96-
logger.info(f"{parsed_content_list=}")
94+
def predict_and_save(server_url: str, image_path: str, box_threshold: float = DEFAULT_BOX_THRESHOLD, iou_threshold: float = DEFAULT_IOU_THRESHOLD):
95+
"""
96+
Makes a prediction and saves the results to files, including logs and image outputs.
97+
Args:
98+
server_url (str): The URL of the OmniParser Gradio server.
99+
image_path (str): Path to the image file to be processed.
100+
box_threshold (float): Box threshold value (default: 0.05).
101+
iou_threshold (float): IOU threshold value (default: 0.1).
102+
"""
103+
# Generate a timestamp for unique file naming
104+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
105+
106+
# Call the predict function to get prediction data
107+
try:
108+
result = predict(server_url, image_path, box_threshold, iou_threshold)
109+
output_image = result["output_image"]
110+
result_data = result["result_data"]
97111

98112
# Save result data to JSON file
99113
result_data_path = f"result_data_{timestamp}.json"
100114
with open(result_data_path, "w") as json_file:
101115
json.dump(result_data, json_file, indent=4)
102116
logger.info(f"Parsed content saved to: {result_data_path}")
103-
117+
104118
# Save the output image
105119
output_image_path = f"output_image_{timestamp}.png"
106120
if isinstance(output_image, str) and os.path.exists(output_image):
107121
shutil.copy(output_image, output_image_path)
108122
logger.info(f"Output image saved to: {output_image_path}")
109123
else:
110124
logger.warning(f"Unexpected output_image format or file not found: {output_image}")
111-
125+
112126
except Exception as e:
113127
logger.error(f"An error occurred: {str(e)}")
114128
logger.exception("Traceback:")
115129

116-
if __name__ == "__main__":
117-
fire.Fire(predict)
118130

131+
if __name__ == "__main__":
132+
fire.Fire(predict_and_save)

0 commit comments

Comments
 (0)