diff --git a/demo_gradio.py b/demo_gradio.py index 466a5ff4..0c1d119e 100644 --- a/demo_gradio.py +++ b/demo_gradio.py @@ -212,6 +212,14 @@ def gradio_demo( with torch.no_grad(): predictions = run_model(target_dir, model) + + # Convert CUDA tensors to CPU before saving + for key, value in predictions.items(): + if isinstance(value, list): + predictions[key] = [item.cpu() if hasattr(item, 'is_cuda') and item.is_cuda else item for item in value] + elif hasattr(value, 'is_cuda') and value.is_cuda: + predictions[key] = value.cpu() + # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions)