77import numpy as np
88import onnxruntime as ort
99import pandas as pd
10- from PIL import Image , ImageDraw
10+ from PIL import Image
1111import cv2
1212
1313ort .preload_dlls ()
1414
1515
16- # Use absolute paths instead of relative paths
17- BASE_DIR = os .path .dirname (os .path .abspath (__file__ ))
18- MODEL_PATH = os .path .join (BASE_DIR , "models/deim-blood-cell-detection_nano.onnx" )
19- CLASS_NAMES_PATH = os .path .join (BASE_DIR , "models/classes.txt" )
20-
2116def generate_colors (num_classes ):
2217 """Generate a list of distinct colors for different classes."""
2318 # Generate evenly spaced hues
@@ -52,18 +47,22 @@ def draw(images, labels, boxes, scores, scales, paddings, thrh=0.4, class_names=
5247 # Scale boxes from padded size to original image size
5348 scale = scales [i ]
5449 x_offset , y_offset = paddings [i ]
55-
56- valid_boxes [:, [0 , 2 ]] = (valid_boxes [:, [0 , 2 ]] - x_offset ) / scale # x coordinates
57- valid_boxes [:, [1 , 3 ]] = (valid_boxes [:, [1 , 3 ]] - y_offset ) / scale # y coordinates
50+
51+ valid_boxes [:, [0 , 2 ]] = (
52+ valid_boxes [:, [0 , 2 ]] - x_offset
53+ ) / scale # x coordinates
54+ valid_boxes [:, [1 , 3 ]] = (
55+ valid_boxes [:, [1 , 3 ]] - y_offset
56+ ) / scale # y coordinates
5857
5958 # Draw boxes
6059 for label , box , score in zip (valid_labels , valid_boxes , valid_scores ):
6160 class_idx = int (label )
6261 color = colors [class_idx % len (colors )]
63-
62+
6463 # Convert coordinates to integers
6564 box = [int (coord ) for coord in box ]
66-
65+
6766 # Draw rectangle
6867 cv2 .rectangle (im , (box [0 ], box [1 ]), (box [2 ], box [3 ]), color , 2 )
6968
@@ -125,35 +124,14 @@ def load_model(model_path):
125124 print (f"Loading model from: { model_path } " )
126125 if not os .path .exists (model_path ):
127126 return None , f"Model file not found at: { model_path } "
128-
127+
129128 sess = ort .InferenceSession (model_path , providers = providers )
130129 print (f"Using device: { ort .get_device ()} " )
131130 return sess , None
132131 except Exception as e :
133132 return None , f"Error creating inference session: { e } "
134133
135134
def get_classes_path(custom_path, default_path):
    """
    Resolve which class-names file to use.

    Falls back to *default_path* unless *custom_path* is non-empty and
    points to an existing file on disk.

    Args:
        custom_path: Custom path to a class names file (may be empty/None).
        default_path: Default path to a class names file.

    Returns:
        Path to a class names file.
    """
    # Use the custom path only when it is provided and actually exists;
    # anything else (empty, None, missing file) falls back to the default.
    if custom_path and os.path.exists(custom_path):
        return custom_path
    return default_path
156-
157135def load_class_names (class_names_path ):
158136 """
159137 Load class names from a text file.
@@ -180,31 +158,23 @@ def load_class_names(class_names_path):
def prepare_image(image, target_size=640):
    """
    Prepare an image for inference: letterbox-resize onto a square canvas.

    The image is scaled with its aspect ratio preserved so the longest side
    equals *target_size*, then centered on a black square canvas.

    Args:
        image: Input image (PIL Image or HxW / HxWxC numpy array).
        target_size: Side length of the square model input (default: 640).

    Returns:
        tuple: (model_input, original_image, scale, (x_offset, y_offset))
            model_input: uint8 array of shape (target_size, target_size, 3).
            original_image: the input as a numpy array (after channel fixup).
            scale: resize factor applied to the original image.
            (x_offset, y_offset): top-left padding of the image in the canvas.
    """
    # Convert to numpy array if PIL Image
    if isinstance(image, Image.Image):
        image = np.array(image)

    # Normalize the channel layout so the canvas assignment below always
    # works: grayscale -> replicate to 3 channels, RGBA -> drop alpha.
    # (Previously 2-D or 4-channel arrays raised a broadcasting error.)
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    height, width = image.shape[:2]
    scale = target_size / max(height, width)
    new_height, new_width = int(height * scale), int(width * scale)

    # Center the resized image on the square canvas.
    y_offset = (target_size - new_height) // 2
    x_offset = (target_size - new_width) // 2

    # Black canvas; the unpainted border acts as letterbox padding.
    # Note: cv2.resize takes dsize as (width, height).
    model_input = np.zeros((target_size, target_size, 3), dtype=np.uint8)
    model_input[y_offset : y_offset + new_height, x_offset : x_offset + new_width] = (
        cv2.resize(image, (new_width, new_height))
    )

    return model_input, image, scale, (x_offset, y_offset)
210180
@@ -230,7 +200,9 @@ def run_inference(session, image, target_size=640):
230200 dtype = np .float32 ,
231201 )
232202 im_data = np .expand_dims (im_data , axis = 0 ) # Add batch dimension
233- orig_size = np .array ([[target_size , target_size ]], dtype = np .int64 ) # Use padded size
203+ orig_size = np .array (
204+ [[target_size , target_size ]], dtype = np .int64
205+ ) # Use padded size
234206
235207 # Get input name and run inference
236208 input_name = session .get_inputs ()[0 ].name
@@ -320,7 +292,7 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
320292 model_load_start = time .time ()
321293 session , error = load_model (model_path )
322294 model_load_time = time .time () - model_load_start
323-
295+
324296 if error :
325297 return None , error , None
326298
@@ -337,10 +309,10 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
337309 inference_start = time .time ()
338310 outputs = run_inference (session , model_input , image_size )
339311 inference_time = time .time () - inference_start
340-
312+
341313 if not outputs or len (outputs ) < 3 :
342314 return None , "Error: Model output is invalid" , None
343-
315+
344316 labels , boxes , scores = outputs
345317
346318 # Draw detections
@@ -363,17 +335,18 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
363335 # Create status message with timing information
364336 status_message = create_status_message (object_counts )
365337 status_message += "\n \n Latency Information:"
366- status_message += f"\n - Model Loading: { model_load_time * 1000 :.1f} ms"
367- status_message += f"\n - Preprocessing: { preprocess_time * 1000 :.1f} ms"
368- status_message += f"\n - Inference: { inference_time * 1000 :.1f} ms"
369- status_message += f"\n - Postprocessing: { postprocess_time * 1000 :.1f} ms"
370- status_message += f"\n - Total Time: { (model_load_time + preprocess_time + inference_time + postprocess_time )* 1000 :.1f} ms"
371-
338+ status_message += f"\n - Model Loading: { model_load_time * 1000 :.1f} ms"
339+ status_message += f"\n - Preprocessing: { preprocess_time * 1000 :.1f} ms"
340+ status_message += f"\n - Inference: { inference_time * 1000 :.1f} ms"
341+ status_message += f"\n - Postprocessing: { postprocess_time * 1000 :.1f} ms"
342+ status_message += f"\n - Total Time: { (model_load_time + preprocess_time + inference_time + postprocess_time ) * 1000 :.1f} ms"
343+
372344 bar_data = create_bar_data (object_counts )
373345
374346 return result_images [0 ], status_message , bar_data
375347 except Exception as e :
376348 import traceback
349+
377350 error_details = traceback .format_exc ()
378351 print (f"Error during inference: { error_details } " )
379352 return None , f"Error during inference: { str (e )} " , None
@@ -382,37 +355,15 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
382355def build_interface (model_path , class_names_path , example_images = None ):
383356 """
384357 Build the Gradio interface components.
385-
386- Args:
387- model_path: Path to the ONNX model
388- class_names_path: Path to the class names file
389- example_images: List of example image paths
390-
391- Returns:
392- gr.Blocks: The Gradio demo interface
393358 """
394359 with gr .Blocks (title = "DEIMKit Detection" ) as demo :
395360 gr .Markdown ("# DEIMKit Detection" )
396- gr .Markdown ("Configure the model and run inference on an image." )
397-
398- # Add model selection
399- with gr .Accordion ("Model Settings" , open = False ):
400- with gr .Row ():
401- custom_model_path = gr .File (
402- label = "Custom Model File (ONNX)" ,
403- file_types = [".onnx" ],
404- file_count = "single"
405- )
406- custom_classes_path = gr .File (
407- label = "Custom Classes File (TXT)" ,
408- file_types = [".txt" ],
409- file_count = "single"
410- )
361+ gr .Markdown ("Upload an image and run inference." )
411362
412363 with gr .Row ():
413364 with gr .Column ():
414365 input_image = gr .Image (type = "pil" , label = "Input Image" )
415-
366+
416367 with gr .Row ():
417368 confidence = gr .Slider (
418369 minimum = 0.1 ,
@@ -421,24 +372,23 @@ def build_interface(model_path, class_names_path, example_images=None):
421372 step = 0.01 ,
422373 label = "Confidence Threshold" ,
423374 )
424-
375+
425376 image_size = gr .Slider (
426377 minimum = 32 ,
427378 maximum = 1920 ,
428379 value = 640 ,
429380 step = 32 ,
430381 label = "Image Size" ,
431- info = "Select image size for inference (larger = slower but potentially more accurate)"
382+ info = "Select image size for inference (larger = slower but potentially more accurate)" ,
432383 )
433-
384+
434385 submit_btn = gr .Button ("Run Inference" , variant = "primary" )
435386
436387 with gr .Column ():
437388 output_image = gr .Image (type = "pil" , label = "Detection Result" )
438389
439390 with gr .Row (equal_height = True ):
440391 output_message = gr .Textbox (label = "Status" )
441-
442392 count_plot = gr .BarPlot (
443393 x = "Class" ,
444394 y = "Count" ,
@@ -448,71 +398,43 @@ def build_interface(model_path, class_names_path, example_images=None):
448398 orientation = "h" ,
449399 label_title = "Object Counts" ,
450400 )
451-
401+
452402 # Add examples component if example images are provided
453403 if example_images :
454404 gr .Examples (
455405 examples = example_images ,
456406 inputs = input_image ,
457407 )
458408
459- # Function to handle model path selection
460- def get_model_path (custom_file , default_path ):
461- if custom_file is not None :
462- return custom_file .name
463- return default_path
464-
465- def get_classes_path (custom_file , default_path ):
466- if custom_file is not None :
467- return custom_file .name
468- return default_path
469-
470- # Set up the click event inside the Blocks context
409+ # Set up the click event
471410 submit_btn .click (
472- fn = lambda img , custom_model , custom_classes , conf , img_size : predict (
411+ fn = lambda img , conf , img_size : predict (
473412 img ,
474- get_model_path ( custom_model , model_path ) ,
475- get_classes_path ( custom_classes , class_names_path ) ,
413+ model_path ,
414+ class_names_path ,
476415 conf ,
477- img_size
416+ img_size ,
478417 ),
479- inputs = [
480- input_image ,
481- custom_model_path ,
482- custom_classes_path ,
483- confidence ,
484- image_size ,
485- ],
418+ inputs = [input_image , confidence , image_size ],
486419 outputs = [output_image , output_message , count_plot ],
487420 )
488421
489422 with gr .Row ():
490423 with gr .Column ():
491- gr .HTML ("<div style='text-align: center; margin: 0 auto;'>Created by <a href='https://dicksonneoh.com' target='_blank'>Dickson Neoh</a>.</div>" )
424+ gr .HTML (
425+ "<div style='text-align: center; margin: 0 auto;'>Created by <a href='https://dicksonneoh.com' target='_blank'>Dickson Neoh</a>.</div>"
426+ )
492427
493428 return demo
494429
495430
def parse_args(argv=None):
    """
    Parse command line arguments.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads ``sys.argv[1:]`` — existing callers that
            pass nothing are unaffected.

    Returns:
        argparse.Namespace with ``model``, ``classes`` and ``examples``
        attributes.
    """
    parser = argparse.ArgumentParser(description="DEIMKit Detection Demo")
    parser.add_argument(
        "--model", type=str, required=True, help="Path to ONNX model file"
    )
    parser.add_argument(
        "--classes", type=str, required=True, help="Path to class names file"
    )
    parser.add_argument(
        "--examples",
        type=str,
        help="Path to directory containing example images (optional)",
    )
    return parser.parse_args(argv)
518440
def launch_demo():
    """
    Launch the Gradio demo using model and class-names paths taken from the
    command line.
    """
    args = parse_args()

    # Collect example images from the optional examples directory, if it
    # was supplied and actually exists.
    example_images = []
    if args.examples and os.path.exists(args.examples):
        valid_suffixes = (".png", ".jpg", ".jpeg")
        for fname in os.listdir(args.examples):
            if fname.lower().endswith(valid_suffixes):
                example_images.append(os.path.join(args.examples, fname))
        print(f"Found {len(example_images)} example images in {args.examples}")

    demo = build_interface(args.model, args.classes, example_images)
    demo.launch(share=False, inbrowser=True)
546459
547460
548461if __name__ == "__main__" :
0 commit comments