Skip to content

Commit a78e767

Browse files
committed
Refactor Gradio Demo Script for Improved Model Loading and Prediction
- Updated the `predict` function to accept a pre-loaded session instead of a model path, enhancing efficiency by loading the model once at startup. - Removed model loading timing from the status message, focusing on preprocessing, inference, and postprocessing times. - Improved error handling by raising a RuntimeError if model loading fails, ensuring robust feedback for users. - Streamlined the click event in the Gradio interface to pass the session directly, simplifying the function signature.
1 parent 72863c0 commit a78e767

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

scripts/gradio_demo.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,19 +283,11 @@ def create_bar_data(object_counts):
283283
return pd.DataFrame({"Class": ["No objects detected"], "Count": [0]})
284284

285285

286-
def predict(image, model_path, class_names_path, confidence_threshold, image_size):
286+
def predict(image, session, class_names_path, confidence_threshold, image_size):
287287
"""Main prediction function."""
288288
if image is None:
289289
return None, "Error: No image provided", None
290290

291-
# Load model
292-
model_load_start = time.time()
293-
session, error = load_model(model_path)
294-
model_load_time = time.time() - model_load_start
295-
296-
if error:
297-
return None, error, None
298-
299291
# Load class names
300292
class_names = load_class_names(class_names_path)
301293

@@ -335,11 +327,10 @@ def predict(image, model_path, class_names_path, confidence_threshold, image_siz
335327
# Create status message with timing information
336328
status_message = create_status_message(object_counts)
337329
status_message += "\n\nLatency Information:"
338-
status_message += f"\n- Model Loading: {model_load_time * 1000:.1f}ms"
339330
status_message += f"\n- Preprocessing: {preprocess_time * 1000:.1f}ms"
340331
status_message += f"\n- Inference: {inference_time * 1000:.1f}ms"
341332
status_message += f"\n- Postprocessing: {postprocess_time * 1000:.1f}ms"
342-
status_message += f"\n- Total Time: {(model_load_time + preprocess_time + inference_time + postprocess_time) * 1000:.1f}ms"
333+
status_message += f"\n- Total Time: {(preprocess_time + inference_time + postprocess_time) * 1000:.1f}ms"
343334

344335
bar_data = create_bar_data(object_counts)
345336

@@ -356,6 +347,11 @@ def build_interface(model_path, class_names_path, example_images=None):
356347
"""
357348
Build the Gradio interface components.
358349
"""
350+
# Load model once at startup
351+
session, error = load_model(model_path)
352+
if error:
353+
raise RuntimeError(f"Failed to load model: {error}")
354+
359355
with gr.Blocks(title="DEIMKit Detection") as demo:
360356
gr.Markdown("# DEIMKit Detection")
361357
gr.Markdown("Upload an image and run inference.")
@@ -406,11 +402,11 @@ def build_interface(model_path, class_names_path, example_images=None):
406402
inputs=input_image,
407403
)
408404

409-
# Set up the click event
405+
# Modify the click event to pass the session
410406
submit_btn.click(
411407
fn=lambda img, conf, img_size: predict(
412408
img,
413-
model_path,
409+
session, # Pass session instead of model_path
414410
class_names_path,
415411
conf,
416412
img_size,

0 commit comments

Comments
 (0)