Skip to content

Commit 6724a78

Browse files
committed
step three
Signed-off-by: Andrei Gherghescu <[email protected]>
1 parent e3ed39a commit 6724a78

File tree

2 files changed

+50
-28
lines changed

2 files changed

+50
-28
lines changed

plotly/src/plot.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,7 @@ mod tests {
10911091
assert!(file_size > 0,);
10921092
#[cfg(not(feature = "debug"))]
10931093
assert!(std::fs::remove_file(&dst).is_ok());
1094+
exporter.close();
10941095
}
10951096

10961097
#[test]
@@ -1110,6 +1111,7 @@ mod tests {
11101111
assert!(file_size > 0,);
11111112
#[cfg(not(feature = "debug"))]
11121113
assert!(std::fs::remove_file(&dst).is_ok());
1114+
exporter.close();
11131115
}
11141116

11151117
#[test]
@@ -1129,6 +1131,7 @@ mod tests {
11291131
assert!(file_size > 0,);
11301132
#[cfg(not(feature = "debug"))]
11311133
assert!(std::fs::remove_file(&dst).is_ok());
1134+
exporter.close();
11321135
}
11331136

11341137
#[test]
@@ -1156,6 +1159,7 @@ mod tests {
11561159
assert!(file_size > 0,);
11571160
#[cfg(not(feature = "debug"))]
11581161
assert!(std::fs::remove_file(&dst).is_ok());
1162+
exporter.close();
11591163
}
11601164

11611165
#[test]
@@ -1175,6 +1179,7 @@ mod tests {
11751179
assert!(file_size > 0,);
11761180
#[cfg(not(feature = "debug"))]
11771181
assert!(std::fs::remove_file(&dst).is_ok());
1182+
exporter.close();
11781183
}
11791184

11801185
#[test]
@@ -1200,6 +1205,7 @@ mod tests {
12001205
// Limit the comparison to the first characters;
12011206
// As image contents seem to be slightly inconsistent across platforms
12021207
assert_eq!(expected_decoded[..2], result_decoded[..2]);
1208+
exporter.close();
12031209
}
12041210

12051211
#[test]
@@ -1221,6 +1227,7 @@ mod tests {
12211227
// seem to contain uniquely generated IDs
12221228
const LEN: usize = 10;
12231229
assert_eq!(expected[..LEN], image_svg[..LEN]);
1230+
exporter.close();
12241231
}
12251232

12261233
#[test]
@@ -1261,5 +1268,6 @@ mod tests {
12611268
assert!(file_size > 0,);
12621269
#[cfg(not(feature = "debug"))]
12631270
assert!(std::fs::remove_file(&dst).is_ok());
1271+
exporter.close();
12641272
}
12651273
}

plotly_static/src/lib.rs

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -622,27 +622,36 @@ impl StaticExporterBuilder {
622622

623623
/// Create a new WebDriver instance based on the spawn_webdriver flag
624624
fn create_webdriver(&self) -> Result<WebDriver> {
625-
match self.spawn_webdriver {
626-
// Try to connect to existing WebDriver or spawn new if not available
627-
true => WebDriver::connect_or_spawn(self.webdriver_port),
628-
// Create the WebDriver instance without spawning
629-
false => WebDriver::new(self.webdriver_port),
630-
}
625+
let port = self.webdriver_port;
626+
let in_async = tokio::runtime::Handle::try_current().is_ok();
627+
628+
let run_create_fn = |spawn: bool| -> Result<WebDriver> {
629+
let work = move || {
630+
if spawn {
631+
WebDriver::connect_or_spawn(port)
632+
} else {
633+
WebDriver::new(port)
634+
}
635+
};
636+
if in_async {
637+
std::thread::spawn(work)
638+
.join()
639+
.map_err(|_| anyhow!("failed to join webdriver thread"))?
640+
} else {
641+
work()
642+
}
643+
};
644+
645+
run_create_fn(self.spawn_webdriver)
631646
}
632647
}
633648

634-
// Async builder for async-first exporter (added without reordering existing items)
649+
// Async builder for async-first exporter (added without reordering existing
650+
// items)
635651
impl StaticExporterBuilder {
636652
/// Build an async exporter for use within async contexts.
637653
pub fn build_async(&self) -> Result<AsyncStaticExporter> {
638-
let wd = if self.spawn_webdriver {
639-
let port = self.webdriver_port;
640-
std::thread::spawn(move || WebDriver::connect_or_spawn(port))
641-
.join()
642-
.map_err(|_| anyhow!("failed to join webdriver spawn thread"))??
643-
} else {
644-
WebDriver::new(self.webdriver_port)?
645-
};
654+
let wd = self.create_webdriver()?;
646655
Ok(AsyncStaticExporter {
647656
webdriver_port: self.webdriver_port,
648657
webdriver_url: self.webdriver_url.clone(),
@@ -760,6 +769,13 @@ impl StaticExporter {
760769
height: usize,
761770
scale: f64,
762771
) -> Result<(), Box<dyn std::error::Error>> {
772+
if tokio::runtime::Handle::try_current().is_ok() {
773+
return Err(anyhow!(
774+
"StaticExporter sync methods cannot be used inside an async context. \
775+
Use StaticExporterBuilder::build_async() and the associated AsyncStaticExporter::write_fig(...)."
776+
)
777+
.into());
778+
}
763779
let rt = self.runtime.clone();
764780
rt.block_on(
765781
self.inner
@@ -818,17 +834,20 @@ impl StaticExporter {
818834
height: usize,
819835
scale: f64,
820836
) -> Result<String, Box<dyn std::error::Error>> {
837+
if tokio::runtime::Handle::try_current().is_ok() {
838+
return Err(anyhow!(
839+
"StaticExporter sync methods cannot be used inside an async context. \
840+
Use StaticExporterBuilder::build_async() and the associated AsyncStaticExporter::write_to_string(...)."
841+
)
842+
.into());
843+
}
821844
let rt = self.runtime.clone();
822845
rt.block_on(
823846
self.inner
824847
.write_to_string(plot, format, width, height, scale),
825848
)
826849
}
827850

828-
/// Convert the Plotly graph to a static image using Kaleido and return the
829-
/// result as a String
830-
// Removed internal export/extract in favor of delegating to AsyncStaticExporter
831-
832851
fn extract_plain(payload: &str, format: &ImageFormat) -> Result<String> {
833852
match payload.split_once(",") {
834853
Some((type_info, data)) => {
@@ -930,16 +949,11 @@ impl StaticExporter {
930949
}
931950

932951
/// Explicitly close the WebDriver session and stop the driver.
933-
/// Prefer calling this in long-running applications to ensure deterministic cleanup.
952+
/// Prefer calling this in long-running applications to ensure deterministic
953+
/// cleanup.
934954
pub fn close(&mut self) {
935-
if let Some(client) = self.inner.webdriver_client.take() {
936-
let runtime = self.runtime.clone();
937-
runtime.block_on(async {
938-
if let Err(e) = client.close().await {
939-
error!("Failed to close WebDriver client: {e}");
940-
}
941-
});
942-
}
955+
let runtime = self.runtime.clone();
956+
runtime.block_on(self.inner.close());
943957
}
944958
}
945959

0 commit comments

Comments
 (0)