|
6 | 6 | import json |
7 | 7 |
|
8 | 8 |
|
9 | | -class ColabLinkNode(nodes.General, nodes.Element): |
10 | | - """A custom docutils node to represent the Colab link.""" |
| 9 | +def setup_colab_link_getter(app, pagename, templatename, context, doctree): |
| 10 | + """Add a function to the HTML context to get the Colab link for a notebook.""" |
11 | 11 |
|
| 12 | + def get_colab_link() -> str: |
| 13 | + """Assume that the notebook path is the same as the pagename""" |
| 14 | + return f"https://colab.research.google.com/github/mind-inria/mri-nufft/blob/colab-examples/examples/{pagename}.ipynb" |
12 | 15 |
|
13 | | -def visit_colab_link_node_html(self, node): |
14 | | - self.body.append(node["html"]) |
15 | | - |
16 | | - |
17 | | -def depart_colab_link_node_html(self, node): |
18 | | - pass |
19 | | - |
20 | | - |
21 | | -class ColabLinkDirective(SphinxDirective): |
22 | | - """Directive to insert a link to open a notebook in Google Colab.""" |
23 | | - |
24 | | - has_content = True |
25 | | - option_spec = { |
26 | | - "needs_gpu": int, |
27 | | - } |
28 | | - |
29 | | - def run(self): |
30 | | - """Run the directive.""" |
31 | | - # Determine the path of the current .rst file |
32 | | - rst_file_path = self.env.doc2path(self.env.docname) |
33 | | - rst_file_dir = os.path.dirname(rst_file_path) |
34 | | - |
35 | | - # Determine the notebook file path assuming it is in the same directory as the .rst file |
36 | | - notebook_filename = os.path.basename(rst_file_path).replace(".rst", ".ipynb") |
37 | | - |
38 | | - # Full path to the notebook |
39 | | - notebook_full_path = os.path.join(rst_file_dir, notebook_filename) |
40 | | - |
41 | | - # Convert the full path back to a relative path from the repo root |
42 | | - # repo_root = self.config.project_root_dir |
43 | | - notebook_repo_relative_path = os.path.relpath( |
44 | | - notebook_full_path, os.path.join(os.getcwd(), "docs") |
45 | | - ) |
46 | | - |
47 | | - # Generate the Colab URL based on GitHub repo information |
48 | | - self.colab_url = f"https://colab.research.google.com/github/mind-inria/mri-nufft/blob/colab-examples/examples/{notebook_repo_relative_path}" |
49 | | - |
50 | | - # Create the HTML button or link |
51 | | - self.html = f"""<div class="colab-button"> |
52 | | - <a href="{self.colab_url}" target="_blank"> |
53 | | - <img src="https://colab.research.google.com/assets/colab-badge.svg" |
54 | | - alt="Open In Colab"/> |
55 | | - </a> |
56 | | - </div> |
57 | | - """ |
58 | | - self.notebook_modifier(notebook_full_path, "\n".join(self.content)) |
59 | | - |
60 | | - # Create the node to insert the HTML |
61 | | - node = ColabLinkNode(html=self.html) |
62 | | - return [node] |
63 | | - |
64 | | - def notebook_modifier(self, notebook_path, commands): |
65 | | - """Modify the notebook to add a warning about GPU requirement.""" |
66 | | - with open(notebook_path) as f: |
67 | | - notebook = json.load(f) |
68 | | - if "cells" not in notebook: |
69 | | - notebook["cells"] = [] |
70 | | - |
71 | | - # Add a cell to install the required libraries at the position where we have |
72 | | - # colab link |
73 | | - idx = self.find_index_of_colab_link(notebook) |
74 | | - code_lines = ["# Install libraries"] |
75 | | - code_lines.append(commands) |
76 | | - code_lines.append("!pip install brainweb-dl # Required for data") |
77 | | - dummy_notebook_content = {"cells": []} |
78 | | - add_code_cell( |
79 | | - dummy_notebook_content, |
80 | | - "\n".join(code_lines), |
81 | | - ) |
82 | | - notebook["cells"][idx] = dummy_notebook_content["cells"][0] |
83 | | - |
84 | | - needs_GPU = self.options.get("needs_gpu", False) |
85 | | - if needs_GPU: |
86 | | - # Add a warning cell at the top of the notebook |
87 | | - warning_template = "\n".join( |
88 | | - [ |
89 | | - "<div class='alert alert-{message_class}'>", |
90 | | - "", |
91 | | - "# Need GPU warning", |
92 | | - "", |
93 | | - "{message}", |
94 | | - "</div>", |
95 | | - self.html, |
96 | | - ] |
97 | | - ) |
98 | | - message_class = "warning" |
99 | | - message = ( |
100 | | - "Running this mri-nufft example requires a GPU, and hence is NOT " |
101 | | - "possible on binder currently We request you to kindly run this notebook " |
102 | | - "on Google Colab by clicking the link below. Additionally, please make " |
103 | | - "sure to set the runtime on Colab to use a GPU and install the below " |
104 | | - "libraries before running." |
105 | | - ) |
106 | | - idx = 0 |
107 | | - else: |
108 | | - # Add a warning cell at the top of the notebook |
109 | | - warning_template = "\n".join( |
110 | | - [ |
111 | | - "<div class='alert alert-{message_class}'>", |
112 | | - "", |
113 | | - "# Install libraries needed for Colab", |
114 | | - "", |
115 | | - "{message}", |
116 | | - "</div>", |
117 | | - self.html, |
118 | | - ] |
119 | | - ) |
120 | | - message_class = "info" |
121 | | - message = ( |
122 | | - "The below installation commands are needed to be run only on " |
123 | | - "Google Colab." |
124 | | - ) |
125 | | - |
126 | | - dummy_notebook_content = {"cells": []} |
127 | | - add_markdown_cell( |
128 | | - dummy_notebook_content, |
129 | | - warning_template.format(message_class=message_class, message=message), |
130 | | - ) |
131 | | - notebook["cells"] = ( |
132 | | - notebook["cells"][:idx] |
133 | | - + dummy_notebook_content["cells"] |
134 | | - + notebook["cells"][idx:] |
135 | | - ) |
136 | | - |
137 | | - # Write back updated notebook |
138 | | - with open(notebook_path, "w", encoding="utf-8") as f: |
139 | | - json.dump(notebook, f, ensure_ascii=False, indent=2) |
140 | | - |
141 | | - def find_index_of_colab_link(self, notebook): |
142 | | - """Find the index of the cell containing the Colab link.""" |
143 | | - for idx, cell in enumerate(notebook["cells"]): |
144 | | - if cell["cell_type"] == "markdown" and ".. colab-link::" in "".join( |
145 | | - cell.get("source", "") |
146 | | - ): |
147 | | - return idx |
148 | | - return 0 |
| 16 | + context["get_colab_link"] = get_colab_link |
149 | 17 |
|
150 | 18 |
|
151 | 19 | def setup(app): |
152 | | - """Set up the Sphinx extension.""" |
153 | | - app.add_node( |
154 | | - ColabLinkNode, html=(visit_colab_link_node_html, depart_colab_link_node_html) |
155 | | - ) |
156 | | - app.add_directive("colab-link", ColabLinkDirective) |
| 20 | + app.connect("html-page-context", setup_colab_link_getter) |
157 | 21 |
|
158 | 22 | return { |
159 | | - "version": "0.1", |
| 23 | + "version": "0.4", |
160 | 24 | "parallel_read_safe": True, |
161 | 25 | "parallel_write_safe": True, |
162 | 26 | } |
0 commit comments