Skip to content

Commit 8d65e99

Browse files
committed
Instead of trying to detect notebook vs. console, collect once whenever we are in any kind of IPython environment.
1 parent 76fcdb7 commit 8d65e99

File tree

2 files changed

+51
-13
lines changed

2 files changed

+51
-13
lines changed

src/dataframe.rs

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use crate::physical_plan::PyExecutionPlan;
4949
use crate::record_batch::PyRecordBatchStream;
5050
use crate::sql::logical::PyLogicalPlan;
5151
use crate::utils::{
52-
get_tokio_runtime, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
52+
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
5353
};
5454
use crate::{
5555
errors::PyDataFusionResult,
@@ -192,12 +192,18 @@ fn build_formatter_config_from_python(formatter: &Bound<'_, PyAny>) -> PyResult<
192192
#[derive(Clone)]
193193
pub struct PyDataFrame {
194194
df: Arc<DataFrame>,
195+
196+
// In an IPython environment, cache the collected batches between `__repr__` and `_repr_html_` calls.
197+
batches: Option<(Vec<RecordBatch>, bool)>,
195198
}
196199

197200
impl PyDataFrame {
198201
/// creates a new PyDataFrame
199202
pub fn new(df: DataFrame) -> Self {
200-
Self { df: Arc::new(df) }
203+
Self {
204+
df: Arc::new(df),
205+
batches: None,
206+
}
201207
}
202208
}
203209

@@ -224,16 +230,22 @@ impl PyDataFrame {
224230
}
225231
}
226232

227-
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
233+
fn __repr__(&mut self, py: Python) -> PyDataFusionResult<String> {
228234
// Get the Python formatter config
229235
let PythonFormatter {
230236
formatter: _,
231237
config,
232238
} = get_python_formatter_with_config(py)?;
233-
let (batches, has_more) = wait_for_future(
234-
py,
235-
collect_record_batches_to_display(self.df.as_ref().clone(), config),
236-
)??;
239+
240+
let should_cache = *is_ipython_env(py) && self.batches.is_none();
241+
let (batches, has_more) = match self.batches.take() {
242+
Some(b) => b,
243+
None => wait_for_future(
244+
py,
245+
collect_record_batches_to_display(self.df.as_ref().clone(), config),
246+
)??,
247+
};
248+
237249
if batches.is_empty() {
238250
// This should not be reached, but do it for safety since we index into the vector below
239251
return Ok("No data to display".to_string());
@@ -247,16 +259,27 @@ impl PyDataFrame {
247259
false => "",
248260
};
249261

262+
if should_cache {
263+
self.batches = Some((batches, has_more));
264+
}
265+
250266
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
251267
}
252268

253-
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
269+
fn _repr_html_(&mut self, py: Python) -> PyDataFusionResult<String> {
254270
// Get the Python formatter and config
255271
let PythonFormatter { formatter, config } = get_python_formatter_with_config(py)?;
256-
let (batches, has_more) = wait_for_future(
257-
py,
258-
collect_record_batches_to_display(self.df.as_ref().clone(), config),
259-
)??;
272+
273+
let should_cache = *is_ipython_env(py) && self.batches.is_none();
274+
275+
let (batches, has_more) = match self.batches.take() {
276+
Some(b) => b,
277+
None => wait_for_future(
278+
py,
279+
collect_record_batches_to_display(self.df.as_ref().clone(), config),
280+
)??,
281+
};
282+
260283
if batches.is_empty() {
261284
// This should not be reached, but do it for safety since we index into the vector below
262285
return Ok("No data to display".to_string());
@@ -266,7 +289,7 @@ impl PyDataFrame {
266289

267290
// Convert record batches to PyObject list
268291
let py_batches = batches
269-
.into_iter()
292+
.iter()
270293
.map(|rb| rb.to_pyarrow(py))
271294
.collect::<PyResult<Vec<PyObject>>>()?;
272295

@@ -282,6 +305,10 @@ impl PyDataFrame {
282305
let html_result = formatter.call_method("format_html", (), Some(&kwargs))?;
283306
let html_str: String = html_result.extract()?;
284307

308+
if should_cache {
309+
self.batches = Some((batches, has_more));
310+
}
311+
285312
Ok(html_str)
286313
}
287314

src/utils.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
3939
RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
4040
}
4141

42+
#[inline]
43+
pub(crate) fn is_ipython_env(py: Python) -> &'static bool {
44+
static IS_IPYTHON_ENV: OnceLock<bool> = OnceLock::new();
45+
IS_IPYTHON_ENV.get_or_init(|| {
46+
py.import("IPython")
47+
.and_then(|ipython| ipython.call_method0("get_ipython"))
48+
.map(|ipython| !ipython.is_none())
49+
.unwrap_or(false)
50+
})
51+
}
52+
4253
/// Utility to get the global DataFusion context
4354
#[inline]
4455
pub(crate) fn get_global_ctx() -> &'static SessionContext {

0 commit comments

Comments
 (0)