|
| 1 | +"""A Sphinx extension to add a button to open a notebook in Google Colab.""" |
| 2 | + |
| 3 | +from docutils import nodes |
| 4 | +from sphinx.util.docutils import SphinxDirective |
| 5 | +from sphinx_gallery.notebook import add_code_cell, add_markdown_cell |
| 6 | + |
| 7 | +import os |
| 8 | +import json |
| 9 | + |
| 10 | + |
| 11 | +class ColabLinkNode(nodes.General, nodes.Element): |
| 12 | + """A custom docutils node to represent the Colab link.""" |
| 13 | + |
| 14 | + |
| 15 | +def visit_colab_link_node_html(self, node): |
| 16 | + self.body.append(node["html"]) |
| 17 | + |
| 18 | + |
| 19 | +def depart_colab_link_node_html(self, node): |
| 20 | + pass |
| 21 | + |
| 22 | + |
| 23 | +class ColabLinkDirective(SphinxDirective): |
| 24 | + """Directive to insert a link to open a notebook in Google Colab.""" |
| 25 | + |
| 26 | + has_content = True |
| 27 | + option_spec = { |
| 28 | + "needs_gpu": int, |
| 29 | + } |
| 30 | + |
| 31 | + def run(self): |
| 32 | + """Run the directive.""" |
| 33 | + # Determine the path of the current .rst file |
| 34 | + rst_file_path = self.env.doc2path(self.env.docname) |
| 35 | + rst_file_dir = os.path.dirname(rst_file_path) |
| 36 | + |
| 37 | + # Determine the notebook file path assuming it is in the same directory as the .rst file |
| 38 | + notebook_filename = os.path.basename(rst_file_path).replace(".rst", ".ipynb") |
| 39 | + |
| 40 | + # Full path to the notebook |
| 41 | + notebook_full_path = os.path.join(rst_file_dir, notebook_filename) |
| 42 | + |
| 43 | + # Convert the full path back to a relative path from the repo root |
| 44 | + # repo_root = self.config.project_root_dir |
| 45 | + notebook_repo_relative_path = os.path.relpath( |
| 46 | + notebook_full_path, os.path.join(os.getcwd(), "docs") |
| 47 | + ) |
| 48 | + |
| 49 | + config_ext = self.env.config["colab_notebook"] |
| 50 | + base_colab_url = config_ext.get( |
| 51 | + "base_colab_url", "https://colab.research.google.com" |
| 52 | + ) |
| 53 | + base_repo = config_ext["repo"] |
| 54 | + branch = config_ext["branch"] |
| 55 | + path = config_ext["path"] |
| 56 | + # Generate the Colab URL based on GitHub repo information |
| 57 | + self.colab_url = f"{base_colab_url}/{base_repo}/blob/{branch}/{path}/{notebook_repo_relative_path}" |
| 58 | + |
| 59 | + # Create the HTML button or link |
| 60 | + self.html = f"""<div class="colab-button"> |
| 61 | + <a href="{self.colab_url}" target="_blank"> |
| 62 | + <img src="https://colab.research.google.com/assets/colab-badge.svg" |
| 63 | + alt="Open In Colab"/> |
| 64 | + </a> |
| 65 | + </div> |
| 66 | + """ |
| 67 | + self.notebook_modifier(notebook_full_path, "\n".join(self.content)) |
| 68 | + |
| 69 | + # Create the node to insert the HTML |
| 70 | + node = ColabLinkNode(html=self.html) |
| 71 | + return [node] |
| 72 | + |
| 73 | + def notebook_modifier(self, notebook_path, commands): |
| 74 | + """Modify the notebook to add a warning about GPU requirement.""" |
| 75 | + with open(notebook_path) as f: |
| 76 | + notebook = json.load(f) |
| 77 | + if "cells" not in notebook: |
| 78 | + notebook["cells"] = [] |
| 79 | + |
| 80 | + # Add a cell to install the required libraries at the position where we have |
| 81 | + # colab link |
| 82 | + idx = self.find_index_of_colab_link(notebook) |
| 83 | + |
| 84 | + code_lines = ["# Install libraries"] |
| 85 | + code_lines.append(commands) |
| 86 | + code_lines.append( |
| 87 | + "!pip install" + " ".join(self.env.config["colab_notebook"]["dependencies"]) |
| 88 | + ) |
| 89 | + dummy_notebook_content = {"cells": []} |
| 90 | + add_code_cell( |
| 91 | + dummy_notebook_content, |
| 92 | + "\n".join(code_lines), |
| 93 | + ) |
| 94 | + notebook["cells"][idx] = dummy_notebook_content["cells"][0] |
| 95 | + |
| 96 | + needs_GPU = self.options.get("needs_gpu", False) |
| 97 | + if needs_GPU: |
| 98 | + # Add a warning cell at the top of the notebook |
| 99 | + warning_template = "\n".join( |
| 100 | + [ |
| 101 | + "<div class='alert alert-{message_class}'>", |
| 102 | + "", |
| 103 | + "# Need GPU warning", |
| 104 | + "", |
| 105 | + "{message}", |
| 106 | + "</div>", |
| 107 | + self.html, |
| 108 | + ] |
| 109 | + ) |
| 110 | + message_class = "warning" |
| 111 | + message = ( |
| 112 | + "Running this example requires a GPU, and hence is NOT " |
| 113 | + "possible on binder currently We request you to kindly run this notebook " |
| 114 | + "on Google Colab by clicking the link below. Additionally, please make " |
| 115 | + "sure to set the runtime on Colab to use a GPU and install the below " |
| 116 | + "libraries before running." |
| 117 | + ) |
| 118 | + idx = 0 |
| 119 | + else: |
| 120 | + # Add a warning cell at the top of the notebook |
| 121 | + warning_template = "\n".join( |
| 122 | + [ |
| 123 | + "<div class='alert alert-{message_class}'>", |
| 124 | + "", |
| 125 | + "# Install libraries needed for Colab", |
| 126 | + "", |
| 127 | + "{message}", |
| 128 | + "</div>", |
| 129 | + self.html, |
| 130 | + ] |
| 131 | + ) |
| 132 | + message_class = "info" |
| 133 | + message = ( |
| 134 | + "The below installation commands are needed to be run only on " |
| 135 | + "Google Colab." |
| 136 | + ) |
| 137 | + |
| 138 | + dummy_notebook_content = {"cells": []} |
| 139 | + add_markdown_cell( |
| 140 | + dummy_notebook_content, |
| 141 | + warning_template.format(message_class=message_class, message=message), |
| 142 | + ) |
| 143 | + notebook["cells"] = ( |
| 144 | + notebook["cells"][:idx] |
| 145 | + + dummy_notebook_content["cells"] |
| 146 | + + notebook["cells"][idx:] |
| 147 | + ) |
| 148 | + |
| 149 | + # Write back updated notebook |
| 150 | + with open(notebook_path, "w", encoding="utf-8") as f: |
| 151 | + json.dump(notebook, f, ensure_ascii=False, indent=2) |
| 152 | + |
| 153 | + def find_index_of_colab_link(self, notebook): |
| 154 | + """Find the index of the cell containing the Colab link.""" |
| 155 | + for idx, cell in enumerate(notebook["cells"]): |
| 156 | + if cell["cell_type"] == "markdown" and ".. colab-link::" in "".join( |
| 157 | + cell.get("source", "") |
| 158 | + ): |
| 159 | + return idx |
| 160 | + return 0 |
| 161 | + |
| 162 | + |
| 163 | +def setup(app): |
| 164 | + """Set up the Sphinx extension.""" |
| 165 | + app.add_config_value("colab_notebook", dict(), "html") |
| 166 | + app.add_node( |
| 167 | + ColabLinkNode, html=(visit_colab_link_node_html, depart_colab_link_node_html) |
| 168 | + ) |
| 169 | + app.add_directive("colab-link", ColabLinkDirective) |
| 170 | + |
| 171 | + return { |
| 172 | + "version": "0.1", |
| 173 | + "parallel_read_safe": True, |
| 174 | + "parallel_write_safe": True, |
| 175 | + } |
0 commit comments