diff --git a/onnxruntime-sys/Cargo.toml b/onnxruntime-sys/Cargo.toml index 92cf0911..e9a89de0 100644 --- a/onnxruntime-sys/Cargo.toml +++ b/onnxruntime-sys/Cargo.toml @@ -17,8 +17,8 @@ keywords = ["neuralnetworks", "onnx", "bindings"] [dependencies] [build-dependencies] -bindgen = {version = "0.55", optional = true} -ureq = "1.5.1" +bindgen = { version = "0.56", optional = true } +ureq = "2" # Used on Windows zip = "0.5" diff --git a/onnxruntime-sys/build.rs b/onnxruntime-sys/build.rs index a3eafb7b..47d0573e 100644 --- a/onnxruntime-sys/build.rs +++ b/onnxruntime-sys/build.rs @@ -11,7 +11,7 @@ use std::{ /// WARNING: If version is changed, bindings for all platforms will have to be re-generated. /// To do so, run this: /// cargo build --package onnxruntime-sys --features generate-bindings -const ORT_VERSION: &str = "1.5.2"; +const ORT_VERSION: &str = "1.6.0"; /// Base Url from which to download pre-built releases/ const ORT_RELEASE_BASE_URL: &str = "https://github.com/microsoft/onnxruntime/releases/download"; @@ -106,15 +106,19 @@ fn generate_bindings(include_dir: &Path) { } fn download>(source_url: &str, target_file: P) { - let resp = ureq::get(source_url) - .timeout_connect(1_000) // 1 second + let agent = ureq::AgentBuilder::new() + .timeout_read(std::time::Duration::from_secs(1)) // 1 second .timeout(std::time::Duration::from_secs(300)) - .call(); + .build(); - if resp.error() { + let resp = agent.get(source_url).call(); + + if resp.is_err() { panic!("ERROR: Failed to download {}: {:#?}", source_url, resp); } + let resp = resp.unwrap(); + let len = resp .header("Content-Length") .and_then(|s| s.parse::().ok()) diff --git a/onnxruntime/Cargo.toml b/onnxruntime/Cargo.toml index 8f8837d5..50abf6b5 100644 --- a/onnxruntime/Cargo.toml +++ b/onnxruntime/Cargo.toml @@ -19,21 +19,21 @@ name = "integration_tests" required-features = ["model-fetching"] [dependencies] -onnxruntime-sys = {version = "0.0.10", path = "../onnxruntime-sys"} +onnxruntime-sys = { version = "0.0.10", path = "../onnxruntime-sys" } lazy_static = "1.4" -ndarray = "0.13" +ndarray = "0.14" thiserror = "1.0" tracing = "0.1" # Enabled with 'model-fetching' feature -ureq = {version = "1.5.1", optional = true} +ureq = { version = "2", optional = true } [dev-dependencies] image = "0.23" -test-env-log = {version = "0.2", default-features = false, features = ["trace"]} +test-env-log = { version = "0.2", default-features = false, features = ["trace"] } tracing-subscriber = "0.2" -ureq = "1.5.1" +ureq = "2.0" [features] # Fetch model from ONNX Model Zoo (https://github.com/onnx/models) diff --git a/onnxruntime/src/download.rs b/onnxruntime/src/download.rs index 9b3cd50f..00e2be22 100644 --- a/onnxruntime/src/download.rs +++ b/onnxruntime/src/download.rs @@ -78,10 +78,15 @@ impl AvailableOnnxModel { "Downloading file, please wait....", ); - let resp = ureq::get(url) - .timeout_connect(1_000) // 1 second - .timeout(Duration::from_secs(180)) // 3 minutes - .call(); + let agent = ureq::AgentBuilder::new() + .timeout_connect(Duration::from_secs(1)) // 1 second .timeout_read(std::time::Duration::from_secs(1)) // 1 second + .timeout(Duration::from_secs(180)) // 3 minutes .timeout(std::time::Duration::from_secs(180))// 3 minutes + .build(); + + let resp = agent + .get(url) + .call() + .map_err(OrtDownloadError::DownloadError)?; assert!(resp.has("Content-Length")); let len = resp diff --git a/onnxruntime/src/error.rs b/onnxruntime/src/error.rs index bc98a862..0e919c48 100644 --- a/onnxruntime/src/error.rs +++ b/onnxruntime/src/error.rs @@ -131,6 +131,10 @@ pub enum OrtApiError { #[non_exhaustive] #[derive(Error, Debug)] pub enum OrtDownloadError { + /// Generic download error + #[cfg(feature = "model-fetching")] + #[error("Error downloading data")] + DownloadError(#[from] ureq::Error), /// Generic input/output error #[error("Error downloading data to file: {0}")] IoError(#[from] io::Error),