|
| 1 | +import base64 |
1 | 2 | import io
|
2 | 3 | import os
|
3 | 4 | import shutil
|
@@ -138,28 +139,15 @@ def execute(self):
|
138 | 139 | kernel_name=nb.metadata.kernelspec["name"], store_widget_state=True, cwd=staging_dir
|
139 | 140 | )
|
140 | 141 |
|
141 |
| - mlflow.set_tracking_uri(MLFLOW_SERVER_URI) |
142 |
| - with mlflow.start_run(run_id=job.mlflow_run_id): |
143 |
| - try: |
144 |
| - ep.preprocess(nb, {"metadata": {"path": staging_dir}}) |
145 |
| - if job.parameters: |
146 |
| - mlflow.log_params(job.parameters) |
147 |
| - |
148 |
| - for idx, cell in enumerate(nb.cells): |
149 |
| - if "tags" in cell.metadata and "mlflow_log" in cell.metadata["tags"]: |
150 |
| - mlflow.log_text(cell.source, f"source_cell_{idx}.txt") |
151 |
| - if cell.cell_type == "code" and cell.outputs: |
152 |
| - for output in cell.outputs: |
153 |
| - if "text/plain" in output.data: |
154 |
| - mlflow.log_text( |
155 |
| - output.data["text/plain"], f"output_cell_{idx}.txt" |
156 |
| - ) |
157 |
| - |
158 |
| - except CellExecutionError as e: |
159 |
| - raise e |
160 |
| - finally: |
161 |
| - self.add_side_effects_files(staging_dir) |
162 |
| - self.create_output_files(job, nb) |
| 142 | + try: |
| 143 | + ep.preprocess(nb, {"metadata": {"path": staging_dir}}) |
| 144 | + except CellExecutionError as e: |
| 145 | + raise e |
| 146 | + finally: |
| 147 | + self.add_side_effects_files(staging_dir) |
| 148 | + self.create_output_files(job, nb) |
| 149 | + if getattr(job, "mlflow_logging", False): |
| 150 | + self.log_to_mlflow(job, nb) |
163 | 151 |
|
164 | 152 | def add_side_effects_files(self, staging_dir: str):
|
165 | 153 | """Scan for side effect files potentially created after input file execution and update the job's packaged_files with these files"""
|
@@ -187,10 +175,105 @@ def create_output_files(self, job: DescribeJob, notebook_node):
|
187 | 175 | for output_format in job.output_formats:
|
188 | 176 | cls = nbconvert.get_exporter(output_format)
|
189 | 177 | output, _ = cls().from_notebook_node(notebook_node)
|
190 |
| - output_path = self.staging_paths[output_format] |
191 |
| - with fsspec.open(output_path, "w", encoding="utf-8") as f: |
| 178 | + with fsspec.open(self.staging_paths[output_format], "w", encoding="utf-8") as f: |
192 | 179 | f.write(output)
|
193 |
| - mlflow.log_artifact(output_path) |
| 180 | + |
def log_to_mlflow(self, job, nb):
    """Log a job's parameters, tagged cells and output files to MLflow.

    Everything is logged into the existing run identified by
    ``job.mlflow_run_id`` on the server at ``MLFLOW_SERVER_URI``.

    Cell tags are honoured as follows:

    * ``mlflow_log``        -> ``self.mlflow_log`` (source and outputs)
    * ``mlflow_log_input``  -> ``self.mlflow_log_input`` (source only)
    * ``mlflow_log_output`` -> ``self.mlflow_log_output`` (outputs only)

    A cell carrying both ``mlflow_log_input`` and ``mlflow_log_output``
    gets both logged; the tags are not mutually exclusive.

    Args:
        job: Job description; ``mlflow_run_id``, ``parameters`` and
            ``output_formats`` are read from it.
        nb: Executed notebook node whose ``cells`` are inspected.
    """
    # Function-scope import: tempfile is only needed for the renamed copies.
    import tempfile

    mlflow.set_tracking_uri(MLFLOW_SERVER_URI)
    with mlflow.start_run(run_id=job.mlflow_run_id):
        if job.parameters:
            mlflow.log_params(job.parameters)

        for cell_idx, cell in enumerate(nb.cells):
            tags = cell.metadata.get("tags") or ()
            if "mlflow_log" in tags:
                self.mlflow_log(cell, cell_idx)
                continue
            # Independent checks so that a cell tagged with both the input
            # and the output tag is not silently reduced to input only.
            if "mlflow_log_input" in tags:
                self.mlflow_log_input(cell, cell_idx)
            if "mlflow_log_output" in tags:
                self.mlflow_log_output(cell, cell_idx)

        job_created_logged = False
        # The renamed copies live in a throw-away directory: writing them
        # next to the staged files could overwrite (and then delete) a
        # staged file with the same name, and a failing upload would leave
        # the copy behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for output_format in job.output_formats:
                output_path = self.staging_paths[output_format]
                file_name, file_extension = os.path.splitext(os.path.basename(output_path))
                name_parts = file_name.split("-")
                # NOTE(review): the last seven "-"-separated parts are
                # presumably the creation timestamp appended to the output
                # file name - confirm against the code that builds it.
                artifact_stem = "-".join(name_parts[:-7])
                if artifact_stem:
                    if not job_created_logged:
                        timestamp = "-".join(name_parts[-7:]).split(".")[0]
                        # Logged once; re-logging it for every format is redundant.
                        mlflow.log_param("job_created", timestamp)
                        job_created_logged = True
                else:
                    # No recognisable suffix: keep the full name instead of
                    # producing an artifact called just ".<ext>".
                    artifact_stem = file_name

                artifact_copy = os.path.join(tmp_dir, f"{artifact_stem}{file_extension}")
                shutil.copy(output_path, artifact_copy)
                mlflow.log_artifact(artifact_copy, "")
| 209 | + |
def mlflow_log(self, cell, cell_idx):
    """Log a cell's source first, then its outputs, to the active MLflow run."""
    call_args = (cell, cell_idx)
    self.mlflow_log_input(*call_args)
    self.mlflow_log_output(*call_args)
| 213 | + |
def mlflow_log_input(self, cell, cell_idx):
    """Store the cell's source text as the artifact ``cell_<idx>_input.txt``."""
    cell_source = cell.source
    artifact_name = f"cell_{cell_idx}_input.txt"
    mlflow.log_text(cell_source, artifact_name)
| 216 | + |
def mlflow_log_output(self, cell, cell_idx):
    """Dispatch output logging on the cell type.

    Code cells that have an ``outputs`` attribute go to
    ``_log_code_output``; markdown cells go to ``_log_markdown_output``.
    Any other cell is ignored.
    """
    kind = cell.cell_type
    if kind == "code":
        if hasattr(cell, "outputs"):
            self._log_code_output(cell_idx, cell.outputs)
        return
    if kind == "markdown":
        self._log_markdown_output(cell, cell_idx)
| 222 | + |
def _log_code_output(self, cell_idx, outputs):
    """Log every supported output of a code cell as its own artifact.

    Stream outputs and the ``text/plain``, ``text/html``,
    ``application/pdf`` and ``image/*`` entries of data-carrying outputs
    are logged; anything else (for example outputs without ``data``) is
    skipped.

    Artifacts are numbered with one running counter per cell, so two
    outputs of the same kind in one cell can never share a file name and
    overwrite each other.

    Args:
        cell_idx: Index of the cell within the notebook.
        outputs: The cell's list of output nodes.
    """
    artifact_idx = 0
    for output in outputs:
        if output.output_type == "stream":
            self._log_stream_output(cell_idx, artifact_idx, output)
            artifact_idx += 1
            continue

        data = getattr(output, "data", None)
        if not data:
            continue

        for mime_type in data:
            if mime_type == "text/plain":
                mlflow.log_text(
                    data[mime_type],
                    f"cell_{cell_idx}_output_{artifact_idx}.txt",
                )
            elif mime_type == "text/html":
                self._log_html_output(output, cell_idx, artifact_idx)
            elif mime_type == "application/pdf":
                self._log_pdf_output(output, cell_idx, artifact_idx)
            elif mime_type.startswith("image"):
                self._log_image_output(output, cell_idx, artifact_idx, mime_type)
            else:
                # Unsupported MIME type: nothing logged, index not consumed.
                continue
            artifact_idx += 1
| 240 | + |
def _log_stream_output(self, cell_idx, output_idx, output):
    """Log a stream output's text as ``cell_<cell>_output_<idx>.txt``."""
    # join() leaves a plain string unchanged and concatenates a list of chunks.
    stream_text = "".join(output.text)
    artifact_name = f"cell_{cell_idx}_output_{output_idx}.txt"
    mlflow.log_text(stream_text, artifact_name)
| 243 | + |
def _log_html_output(self, output, cell_idx, output_idx):
    """Log an output's ``text/html`` payload as ``cell_<cell>_output_<idx>.html``.

    Does nothing when the output carries no ``text/html`` entry.
    """
    if "text/html" not in output.data:
        return

    markup = output.data["text/html"]
    # The payload may arrive as a list of string fragments; flatten it.
    if isinstance(markup, list):
        markup = "".join(markup)
    mlflow.log_text(markup, f"cell_{cell_idx}_output_{output_idx}.html")
| 250 | + |
def _log_pdf_output(self, output, cell_idx, output_idx):
    """Decode an ``application/pdf`` payload and log it as a PDF artifact.

    The payload is accepted either as bare base64 or as a data URI
    (``data:application/pdf;base64,<payload>``); previously a payload
    without a comma raised ``IndexError``. The file is written to a
    temporary directory that is removed afterwards, so nothing is left
    behind in the current working directory.

    Args:
        output: Output node whose ``data["application/pdf"]`` is logged.
        cell_idx: Index of the cell within the notebook.
        output_idx: Index used in the artifact file name.
    """
    # Function-scope import: tempfile is only needed here.
    import tempfile

    pdf_payload = output.data["application/pdf"]
    if isinstance(pdf_payload, list):
        pdf_payload = "".join(pdf_payload)

    # Strip an optional data-URI header; base64 itself never contains ",".
    _, separator, remainder = pdf_payload.partition(",")
    encoded = remainder if separator else pdf_payload
    pdf_data = base64.b64decode(encoded)

    filename = f"cell_{cell_idx}_output_{output_idx}.pdf"
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = os.path.join(tmp_dir, filename)
        with open(pdf_path, "wb") as pdf_file:
            pdf_file.write(pdf_data)
        mlflow.log_artifact(pdf_path)
| 256 | + |
def _log_image_output(self, output, cell_idx, output_idx, mime_type):
    """Log an ``image/*`` payload of an output as an image artifact.

    Binary image types are base64-decoded (an optional data-URI header
    is stripped first) and uploaded from a temporary directory, so no
    file is left in the current working directory even when the upload
    fails. ``image/svg+xml`` markup is plain text rather than base64, so
    it is logged as text with a ``.svg`` extension instead of being
    pushed through the base64 decoder.

    Logging is best effort: any failure is reported with ``print`` and
    swallowed so one bad image does not abort the rest of the logging.

    Args:
        output: Output node whose ``data[mime_type]`` is logged.
        cell_idx: Index of the cell within the notebook.
        output_idx: Index used in the artifact file name.
        mime_type: The ``image/...`` key to read from ``output.data``.
    """
    # Function-scope import: tempfile is only needed here.
    import tempfile

    try:
        image_payload = output.data[mime_type]
        if isinstance(image_payload, list):
            image_payload = "".join(image_payload)

        # "image/svg+xml" -> "svg", "image/png" -> "png"
        subtype = mime_type.split("/")[1]
        image_extension = subtype.split("+")[0]
        filename = f"cell_{cell_idx}_output_{output_idx}.{image_extension}"

        if subtype == "svg+xml" and image_payload.lstrip().startswith("<"):
            # Raw SVG markup: store verbatim.
            mlflow.log_text(image_payload, filename)
            return

        # Strip an optional data-URI header; base64 itself never contains ",".
        _, separator, remainder = image_payload.partition(",")
        image_data = base64.b64decode(remainder if separator else image_payload)

        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, filename)
            with open(image_path, "wb") as image_file:
                image_file.write(image_data)
            mlflow.log_artifact(image_path)
    except Exception as e:  # deliberate best effort, see docstring
        print(f"Error logging image output in cell {cell_idx}, output {output_idx}: {e}")
| 274 | + |
def _log_markdown_output(self, cell, cell_idx):
    """Log a markdown cell's source as ``cell_<idx>_output_0.md``."""
    markdown_source = cell.source
    artifact_name = f"cell_{cell_idx}_output_0.md"
    mlflow.log_text(markdown_source, artifact_name)
194 | 277 |
|
195 | 278 | def supported_features(cls) -> Dict[JobFeature, bool]:
|
196 | 279 | return {
|
|
0 commit comments