Skip to content

Commit f8f21f6

Browse files
committed
Add Dockerfile; client.py; deploy.py; deploy_requirements.txt; docker-build-ec2.yml.j2; download.py; entrypoint.sh; requirements.txt
1 parent ba4e04f commit f8f21f6

File tree

9 files changed

+1044
-2
lines changed

9 files changed

+1044
-2
lines changed

.gitignore

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,14 @@ weights/icon_caption_florence
33
weights/icon_detect/
44
.gradio
55
__pycache__
6+
7+
# Swap files
8+
*.swp
9+
10+
# Environment files
11+
.env
12+
.env.*
13+
14+
# Environment
15+
venv/
16+
*.pem

Dockerfile

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
FROM nvidia/cuda:12.3.1-devel-ubuntu22.04
2+
3+
# Install system dependencies with explicit OpenGL libraries
4+
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
5+
git \
6+
git-lfs \
7+
wget \
8+
libgl1 \
9+
libglib2.0-0 \
10+
libsm6 \
11+
libxext6 \
12+
libxrender1 \
13+
libglu1-mesa \
14+
libglib2.0-0 \
15+
libsm6 \
16+
libxrender1 \
17+
libxext6 \
18+
python3-opencv \
19+
&& apt-get clean \
20+
&& rm -rf /var/lib/apt/lists/* \
21+
&& git lfs install
22+
23+
# Install Miniconda for Python 3.12
24+
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh && \
25+
bash miniconda.sh -b -p /opt/conda && \
26+
rm miniconda.sh
27+
ENV PATH="/opt/conda/bin:$PATH"
28+
29+
# Create and activate Conda environment with Python 3.12, and set it as the default
30+
RUN conda create -n omni python=3.12 && \
31+
echo "source activate omni" > ~/.bashrc
32+
ENV CONDA_DEFAULT_ENV=omni
33+
ENV PATH="/opt/conda/envs/omni/bin:$PATH"
34+
35+
# Set the working directory in the container
36+
WORKDIR /usr/src/app
37+
38+
# Copy project files and requirements
39+
COPY . .
40+
COPY requirements.txt /usr/src/app/requirements.txt
41+
42+
# Initialize Git LFS and pull LFS files
43+
RUN git lfs install && \
44+
git lfs pull
45+
46+
# Install dependencies from requirements.txt with specific opencv-python-headless version
47+
RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \
48+
pip uninstall -y opencv-python opencv-python-headless && \
49+
pip install --no-cache-dir opencv-python-headless==4.8.1.78 && \
50+
pip install -r requirements.txt && \
51+
pip install huggingface_hub
52+
53+
# Run download.py to fetch model weights and convert safetensors to .pt format
54+
RUN . /opt/conda/etc/profile.d/conda.sh && conda activate omni && \
55+
python download.py && \
56+
echo "Contents of weights directory:" && \
57+
ls -lR weights && \
58+
python weights/convert_safetensor_to_pt.py
59+
60+
# Expose the default Gradio port
61+
EXPOSE 7861
62+
63+
# Configure Gradio to be accessible externally
64+
ENV GRADIO_SERVER_NAME="0.0.0.0"
65+
66+
# Copy and set permissions for entrypoint script
67+
COPY entrypoint.sh /usr/src/app/entrypoint.sh
68+
RUN chmod +x /usr/src/app/entrypoint.sh
69+
70+
# To debug, keep the container running
71+
# CMD ["tail", "-f", "/dev/null"]
72+
73+
# Set the entrypoint
74+
ENTRYPOINT ["/usr/src/app/entrypoint.sh"]

client.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""
2+
This module provides a command-line interface to interact with the OmniParser Gradio server.
3+
4+
Usage:
5+
python client.py "http://<server_ip>:7861" "path/to/image.jpg"
6+
"""
7+
8+
import fire
9+
from gradio_client import Client
10+
from loguru import logger
11+
from PIL import Image
12+
import base64
13+
from io import BytesIO
14+
import os
15+
import shutil
16+
17+
def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_threshold: float = 0.1):
18+
"""
19+
Makes a prediction using the OmniParser Gradio client with the provided server URL and image.
20+
21+
Args:
22+
server_url (str): The URL of the OmniParser Gradio server.
23+
image_path (str): Path to the image file to be processed.
24+
box_threshold (float): Box threshold value (default: 0.05).
25+
iou_threshold (float): IOU threshold value (default: 0.1).
26+
"""
27+
client = Client(server_url)
28+
29+
# Load and encode the image
30+
with open(image_path, "rb") as image_file:
31+
encoded_image = base64.b64encode(image_file.read()).decode("utf-8")
32+
33+
# Prepare the image input in the format expected by the server
34+
image_input = {
35+
"path": None,
36+
"url": f"data:image/png;base64,{encoded_image}",
37+
"size": None,
38+
"orig_name": image_path,
39+
"mime_type": "image/png",
40+
"is_stream": False,
41+
"meta": {}
42+
}
43+
44+
# Make the prediction
45+
try:
46+
result = client.predict(
47+
image_input, # image input as dictionary
48+
box_threshold, # box_threshold
49+
iou_threshold, # iou_threshold
50+
api_name="/process"
51+
)
52+
53+
# Process and log the results
54+
output_image, parsed_content = result
55+
56+
logger.info("Prediction completed successfully")
57+
logger.info(f"Parsed content:\n{parsed_content}")
58+
59+
# Save the output image
60+
output_image_path = "output_image.png"
61+
if isinstance(output_image, dict) and 'url' in output_image:
62+
# Handle base64 encoded image
63+
img_data = base64.b64decode(output_image['url'].split(',')[1])
64+
with open(output_image_path, 'wb') as f:
65+
f.write(img_data)
66+
elif isinstance(output_image, str):
67+
if output_image.startswith('data:image'):
68+
# Handle base64 encoded image string
69+
img_data = base64.b64decode(output_image.split(',')[1])
70+
with open(output_image_path, 'wb') as f:
71+
f.write(img_data)
72+
elif os.path.exists(output_image):
73+
# Handle file path
74+
shutil.copy(output_image, output_image_path)
75+
else:
76+
logger.warning(f"Unexpected output_image format: {output_image}")
77+
elif isinstance(output_image, Image.Image):
78+
output_image.save(output_image_path)
79+
else:
80+
logger.warning(f"Unexpected output_image format: {type(output_image)}")
81+
logger.warning(f"Output image content: {output_image[:100]}...") # Log the first 100 characters
82+
83+
if os.path.exists(output_image_path):
84+
logger.info(f"Output image saved to: {output_image_path}")
85+
else:
86+
logger.warning(f"Failed to save output image to: {output_image_path}")
87+
88+
except Exception as e:
89+
logger.error(f"An error occurred: {str(e)}")
90+
logger.exception("Traceback:")
91+
92+
if __name__ == "__main__":
93+
fire.Fire(predict)
94+

0 commit comments

Comments
 (0)