Skip to content

Commit 95bd195

Browse files
Merge nodes from jupyter-input and jupyter-output
1 parent f61dd8e commit 95bd195

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

jupyter_sphinx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
JupyterWidgetStateNode,
2424
WIDGET_VIEW_MIMETYPE,
2525
JupyterDownloadRole,
26+
CombineCellInputOutput,
2627
CellOutputsToNodes,
2728
)
2829
from .execute import JupyterKernel, ExecuteJupyterCells
@@ -275,6 +276,7 @@ def setup(app):
275276
app.add_role("jupyter-download:notebook", JupyterDownloadRole())
276277
app.add_role("jupyter-download:nb", JupyterDownloadRole())
277278
app.add_role("jupyter-download:script", JupyterDownloadRole())
279+
app.add_transform(CombineCellInputOutput)
278280
app.add_transform(ExecuteJupyterCells)
279281
app.add_transform(CellOutputsToNodes)
280282

jupyter_sphinx/ast.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sphinx.addnodes import download_reference
1313
from sphinx.transforms import SphinxTransform
1414
from sphinx.environment.collectors.asset import ImageCollector
15+
from sphinx.errors import ExtensionError
1516

1617
import ipywidgets.embed
1718
import nbconvert
@@ -271,7 +272,6 @@ def run(self):
271272
emphasize_lines=[],
272273
raises=False,
273274
stderr=False,
274-
classes=["jupyter_cell"],
275275
)
276276

277277
# Add a blank input and the given output to the cell
@@ -574,6 +574,39 @@ def get_widgets(notebook):
574574
return None
575575

576576

577+
class CombineCellInputOutput(SphinxTransform):
578+
"""Merge nodes from CellOutput with the preceding CellInput node."""
579+
580+
default_priority = 120
581+
582+
def apply(self):
583+
moved_outputs = set()
584+
585+
for cell_node in self.document.traverse(JupyterCellNode):
586+
if cell_node.attributes["execute"] == False:
587+
if cell_node.attributes["hide_code"] == False:
588+
# Cell came from jupyter-input
589+
sibling = cell_node.next_node(descend=False, siblings=True)
590+
if (
591+
isinstance(sibling, JupyterCellNode)
592+
and sibling.attributes["execute"] == False
593+
and sibling.attributes["hide_code"] == True
594+
):
595+
# Sibling came from jupyter-output, so we merge
596+
cell_node += sibling.children[1]
597+
cell_node.attributes["hide_output"] = False
598+
moved_outputs.update({sibling})
599+
else:
600+
# Call came from jupyter-output
601+
if cell_node not in moved_outputs:
602+
raise ExtensionError(
603+
"Found a jupyter-output node without a preceding jupyter-input"
604+
)
605+
606+
for output_node in moved_outputs:
607+
output_node.replace_self([])
608+
609+
577610
class CellOutputsToNodes(SphinxTransform):
578611
"""Use the builder context to transform a CellOutputNode into Sphinx nodes."""
579612

tests/test_execute.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -682,15 +682,29 @@ def test_input_cell_linenos(doctree):
682682

683683
def test_output_cell(doctree):
684684
source = """
685+
.. jupyter-input::
686+
687+
3 + 2
688+
685689
.. jupyter-output::
686690
687691
4
688692
"""
689693
tree = doctree(source)
690694
(cell,) = tree.traverse(JupyterCellNode)
691-
(celloutput,) = cell.children
695+
(cellinput, celloutput,) = cell.children
696+
assert cellinput.children[0].rawsource.strip() == "3 + 2"
692697
assert celloutput.children[0].rawsource.strip() == "4"
693698

699+
def test_output_only_error(doctree):
700+
source = """
701+
.. jupyter-output::
702+
703+
4
704+
"""
705+
with pytest.raises(ExtensionError):
706+
tree = doctree(source)
707+
694708
def test_multiple_directives(doctree):
695709
source = """
696710
.. jupyter-execute::
@@ -706,11 +720,10 @@ def test_multiple_directives(doctree):
706720
5
707721
"""
708722
tree = doctree(source)
709-
(ex, jin, jout) = tree.traverse(JupyterCellNode)
723+
(ex, jin) = tree.traverse(JupyterCellNode)
710724
(ex_in, ex_out) = ex.children
711-
(jin_in, _) = jin.children
712-
(jout_out,) = jout.children
725+
(jin_in, jin_out) = jin.children
713726
assert ex_in.children[0].rawsource.strip() == "2 + 2"
714727
assert ex_out.children[0].rawsource.strip() == "4"
715728
assert jin_in.children[0].rawsource.strip() == "3 + 3"
716-
assert jout_out.children[0].rawsource.strip() == "5"
729+
assert jin_out.children[0].rawsource.strip() == "5"

0 commit comments

Comments
 (0)