Skip to content

Commit b8d2c3a

Browse files
committed
Refactor Gradio Demo Script for Clarity and Functionality
- Removed unused absolute path definitions for model and class names. - Simplified the `draw` and `prepare_image` functions for better readability. - Enhanced error handling in the `load_model` function. - Updated the Gradio interface to streamline model path selection and improve user experience. - Cleaned up commented code and improved variable naming for clarity.
1 parent f70c9de commit b8d2c3a

File tree

1 file changed

+57
-144
lines changed

1 file changed

+57
-144
lines changed

scripts/gradio_demo.py

Lines changed: 57 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,12 @@
77
import numpy as np
88
import onnxruntime as ort
99
import pandas as pd
10-
from PIL import Image, ImageDraw
10+
from PIL import Image
1111
import cv2
1212

1313
ort.preload_dlls()
1414

1515

16-
# Use absolute paths instead of relative paths
17-
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18-
MODEL_PATH = os.path.join(BASE_DIR, "models/deim-blood-cell-detection_nano.onnx")
19-
CLASS_NAMES_PATH = os.path.join(BASE_DIR, "models/classes.txt")
20-
2116
def generate_colors(num_classes):
2217
"""Generate a list of distinct colors for different classes."""
2318
# Generate evenly spaced hues
@@ -52,18 +47,22 @@ def draw(images, labels, boxes, scores, scales, paddings, thrh=0.4, class_names=
5247
# Scale boxes from padded size to original image size
5348
scale = scales[i]
5449
x_offset, y_offset = paddings[i]
55-
56-
valid_boxes[:, [0, 2]] = (valid_boxes[:, [0, 2]] - x_offset) / scale # x coordinates
57-
valid_boxes[:, [1, 3]] = (valid_boxes[:, [1, 3]] - y_offset) / scale # y coordinates
50+
51+
valid_boxes[:, [0, 2]] = (
52+
valid_boxes[:, [0, 2]] - x_offset
53+
) / scale # x coordinates
54+
valid_boxes[:, [1, 3]] = (
55+
valid_boxes[:, [1, 3]] - y_offset
56+
) / scale # y coordinates
5857

5958
# Draw boxes
6059
for label, box, score in zip(valid_labels, valid_boxes, valid_scores):
6160
class_idx = int(label)
6261
color = colors[class_idx % len(colors)]
63-
62+
6463
# Convert coordinates to integers
6564
box = [int(coord) for coord in box]
66-
65+
6766
# Draw rectangle
6867
cv2.rectangle(im, (box[0], box[1]), (box[2], box[3]), color, 2)
6968

@@ -125,35 +124,14 @@ def load_model(model_path):
125124
print(f"Loading model from: {model_path}")
126125
if not os.path.exists(model_path):
127126
return None, f"Model file not found at: {model_path}"
128-
127+
129128
sess = ort.InferenceSession(model_path, providers=providers)
130129
print(f"Using device: {ort.get_device()}")
131130
return sess, None
132131
except Exception as e:
133132
return None, f"Error creating inference session: {e}"
134133

135134

136-
def get_classes_path(custom_path, default_path):
137-
"""
138-
Get class names file path.
139-
140-
Args:
141-
custom_path: Custom path to class names file
142-
default_path: Default path to class names file
143-
144-
Returns:
145-
Path to a class names file
146-
"""
147-
if not custom_path:
148-
return default_path
149-
150-
# Treat as a file path
151-
if os.path.exists(custom_path):
152-
return custom_path
153-
154-
return default_path
155-
156-
157135
def load_class_names(class_names_path):
158136
"""
159137
Load class names from a text file.
@@ -180,31 +158,23 @@ def load_class_names(class_names_path):
180158
def prepare_image(image, target_size=640):
181159
"""
182160
Prepare image for inference by converting to PIL and resizing with padding.
183-
184-
Args:
185-
image: Input image (PIL or numpy array)
186-
target_size: Target size for resizing (default: 640)
187-
188-
Returns:
189-
tuple: (model_input, original_image, scale, padding)
190161
"""
191162
# Convert to numpy array if PIL Image
192163
if isinstance(image, Image.Image):
193164
image = np.array(image)
194-
195-
# Calculate scaling and padding
165+
196166
height, width = image.shape[:2]
197167
scale = target_size / max(height, width)
198-
new_height = int(height * scale)
199-
new_width = int(width * scale)
168+
new_height, new_width = int(height * scale), int(width * scale)
200169

201170
# Calculate padding
202-
y_offset = (target_size - new_height) // 2
203-
x_offset = (target_size - new_width) // 2
171+
y_offset, x_offset = [(target_size - dim) // 2 for dim in (new_height, new_width)]
204172

205173
# Create model input with padding
206174
model_input = np.zeros((target_size, target_size, 3), dtype=np.uint8)
207-
model_input[y_offset:y_offset + new_height, x_offset:x_offset + new_width] = cv2.resize(image, (new_width, new_height))
175+
model_input[y_offset : y_offset + new_height, x_offset : x_offset + new_width] = (
176+
cv2.resize(image, (new_width, new_height))
177+
)
208178

209179
return model_input, image, scale, (x_offset, y_offset)
210180

@@ -230,7 +200,9 @@ def run_inference(session, image, target_size=640):
230200
dtype=np.float32,
231201
)
232202
im_data = np.expand_dims(im_data, axis=0) # Add batch dimension
233-
orig_size = np.array([[target_size, target_size]], dtype=np.int64) # Use padded size
203+
orig_size = np.array(
204+
[[target_size, target_size]], dtype=np.int64
205+
) # Use padded size
234206

235207
# Get input name and run inference
236208
input_name = session.get_inputs()[0].name
@@ -320,7 +292,7 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
320292
model_load_start = time.time()
321293
session, error = load_model(model_path)
322294
model_load_time = time.time() - model_load_start
323-
295+
324296
if error:
325297
return None, error, None
326298

@@ -337,10 +309,10 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
337309
inference_start = time.time()
338310
outputs = run_inference(session, model_input, image_size)
339311
inference_time = time.time() - inference_start
340-
312+
341313
if not outputs or len(outputs) < 3:
342314
return None, "Error: Model output is invalid", None
343-
315+
344316
labels, boxes, scores = outputs
345317

346318
# Draw detections
@@ -363,17 +335,18 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
363335
# Create status message with timing information
364336
status_message = create_status_message(object_counts)
365337
status_message += "\n\nLatency Information:"
366-
status_message += f"\n- Model Loading: {model_load_time*1000:.1f}ms"
367-
status_message += f"\n- Preprocessing: {preprocess_time*1000:.1f}ms"
368-
status_message += f"\n- Inference: {inference_time*1000:.1f}ms"
369-
status_message += f"\n- Postprocessing: {postprocess_time*1000:.1f}ms"
370-
status_message += f"\n- Total Time: {(model_load_time + preprocess_time + inference_time + postprocess_time)*1000:.1f}ms"
371-
338+
status_message += f"\n- Model Loading: {model_load_time * 1000:.1f}ms"
339+
status_message += f"\n- Preprocessing: {preprocess_time * 1000:.1f}ms"
340+
status_message += f"\n- Inference: {inference_time * 1000:.1f}ms"
341+
status_message += f"\n- Postprocessing: {postprocess_time * 1000:.1f}ms"
342+
status_message += f"\n- Total Time: {(model_load_time + preprocess_time + inference_time + postprocess_time) * 1000:.1f}ms"
343+
372344
bar_data = create_bar_data(object_counts)
373345

374346
return result_images[0], status_message, bar_data
375347
except Exception as e:
376348
import traceback
349+
377350
error_details = traceback.format_exc()
378351
print(f"Error during inference: {error_details}")
379352
return None, f"Error during inference: {str(e)}", None
@@ -382,37 +355,15 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
382355
def build_interface(model_path, class_names_path, example_images=None):
383356
"""
384357
Build the Gradio interface components.
385-
386-
Args:
387-
model_path: Path to the ONNX model
388-
class_names_path: Path to the class names file
389-
example_images: List of example image paths
390-
391-
Returns:
392-
gr.Blocks: The Gradio demo interface
393358
"""
394359
with gr.Blocks(title="DEIMKit Detection") as demo:
395360
gr.Markdown("# DEIMKit Detection")
396-
gr.Markdown("Configure the model and run inference on an image.")
397-
398-
# Add model selection
399-
with gr.Accordion("Model Settings", open=False):
400-
with gr.Row():
401-
custom_model_path = gr.File(
402-
label="Custom Model File (ONNX)",
403-
file_types=[".onnx"],
404-
file_count="single"
405-
)
406-
custom_classes_path = gr.File(
407-
label="Custom Classes File (TXT)",
408-
file_types=[".txt"],
409-
file_count="single"
410-
)
361+
gr.Markdown("Upload an image and run inference.")
411362

412363
with gr.Row():
413364
with gr.Column():
414365
input_image = gr.Image(type="pil", label="Input Image")
415-
366+
416367
with gr.Row():
417368
confidence = gr.Slider(
418369
minimum=0.1,
@@ -421,24 +372,23 @@ def build_interface(model_path, class_names_path, example_images=None):
421372
step=0.01,
422373
label="Confidence Threshold",
423374
)
424-
375+
425376
image_size = gr.Slider(
426377
minimum=32,
427378
maximum=1920,
428379
value=640,
429380
step=32,
430381
label="Image Size",
431-
info="Select image size for inference (larger = slower but potentially more accurate)"
382+
info="Select image size for inference (larger = slower but potentially more accurate)",
432383
)
433-
384+
434385
submit_btn = gr.Button("Run Inference", variant="primary")
435386

436387
with gr.Column():
437388
output_image = gr.Image(type="pil", label="Detection Result")
438389

439390
with gr.Row(equal_height=True):
440391
output_message = gr.Textbox(label="Status")
441-
442392
count_plot = gr.BarPlot(
443393
x="Class",
444394
y="Count",
@@ -448,71 +398,43 @@ def build_interface(model_path, class_names_path, example_images=None):
448398
orientation="h",
449399
label_title="Object Counts",
450400
)
451-
401+
452402
# Add examples component if example images are provided
453403
if example_images:
454404
gr.Examples(
455405
examples=example_images,
456406
inputs=input_image,
457407
)
458408

459-
# Function to handle model path selection
460-
def get_model_path(custom_file, default_path):
461-
if custom_file is not None:
462-
return custom_file.name
463-
return default_path
464-
465-
def get_classes_path(custom_file, default_path):
466-
if custom_file is not None:
467-
return custom_file.name
468-
return default_path
469-
470-
# Set up the click event inside the Blocks context
409+
# Set up the click event
471410
submit_btn.click(
472-
fn=lambda img, custom_model, custom_classes, conf, img_size: predict(
411+
fn=lambda img, conf, img_size: predict(
473412
img,
474-
get_model_path(custom_model, model_path),
475-
get_classes_path(custom_classes, class_names_path),
413+
model_path,
414+
class_names_path,
476415
conf,
477-
img_size
416+
img_size,
478417
),
479-
inputs=[
480-
input_image,
481-
custom_model_path,
482-
custom_classes_path,
483-
confidence,
484-
image_size,
485-
],
418+
inputs=[input_image, confidence, image_size],
486419
outputs=[output_image, output_message, count_plot],
487420
)
488421

489422
with gr.Row():
490423
with gr.Column():
491-
gr.HTML("<div style='text-align: center; margin: 0 auto;'>Created by <a href='https://dicksonneoh.com' target='_blank'>Dickson Neoh</a>.</div>")
424+
gr.HTML(
425+
"<div style='text-align: center; margin: 0 auto;'>Created by <a href='https://dicksonneoh.com' target='_blank'>Dickson Neoh</a>.</div>"
426+
)
492427

493428
return demo
494429

495430

496431
def parse_args():
497432
"""Parse command line arguments."""
498-
parser = argparse.ArgumentParser(description='DEIMKit Detection Demo')
499-
parser.add_argument(
500-
'--model',
501-
type=str,
502-
default=MODEL_PATH,
503-
help='Path to ONNX model file'
504-
)
505-
parser.add_argument(
506-
'--classes',
507-
type=str,
508-
default=CLASS_NAMES_PATH,
509-
help='Path to class names file'
510-
)
433+
parser = argparse.ArgumentParser(description="DEIMKit Detection Demo")
434+
parser.add_argument("--model", type=str, required=True, help="Path to ONNX model file")
435+
parser.add_argument("--classes", type=str, required=True, help="Path to class names file")
511436
parser.add_argument(
512-
'--examples',
513-
type=str,
514-
default=os.path.join(BASE_DIR, "examples"),
515-
help='Path to directory containing example images'
437+
"--examples", type=str, help="Path to directory containing example images (optional)"
516438
)
517439
return parser.parse_args()
518440

@@ -522,27 +444,18 @@ def launch_demo():
522444
Launch the Gradio demo with model and class names paths from command line arguments.
523445
"""
524446
args = parse_args()
525-
526-
# Create examples directory if it doesn't exist
527-
examples_dir = args.examples
528-
if not os.path.exists(examples_dir):
529-
os.makedirs(examples_dir)
530-
print(f"Created examples directory at {examples_dir}")
531-
532-
# Get list of example images
447+
533448
example_images = []
534-
if os.path.exists(examples_dir):
449+
if args.examples and os.path.exists(args.examples):
535450
example_images = [
536-
os.path.join(examples_dir, f)
537-
for f in os.listdir(examples_dir)
538-
if f.lower().endswith(('.png', '.jpg', '.jpeg'))
451+
os.path.join(args.examples, f)
452+
for f in os.listdir(args.examples)
453+
if f.lower().endswith((".png", ".jpg", ".jpeg"))
539454
]
540-
print(f"Found {len(example_images)} example images in {examples_dir}")
541-
455+
print(f"Found {len(example_images)} example images in {args.examples}")
456+
542457
demo = build_interface(args.model, args.classes, example_images)
543-
544-
# Launch the demo without the examples parameter
545-
demo.launch(share=False, inbrowser=True) # Set share=True if you want to create a shareable link
458+
demo.launch(share=False, inbrowser=True)
546459

547460

548461
if __name__ == "__main__":

0 commit comments

Comments
 (0)