diff --git a/benchmarks/attention/README.md b/benchmarks/attention/README.md new file mode 100644 index 000000000..5341b0bd5 --- /dev/null +++ b/benchmarks/attention/README.md @@ -0,0 +1,14 @@ +## JAX Fused-Attention Benchmarking +The benchmarking process is split into two stages: *generating* the timing data, and *visualizing* the timing data. The following steps assume you are located in `TransformerEngine/benchmarks/attention` (i.e. where this README is located). First, install the requirements via `pip install -r requirements.txt`. + +### Generate Timing Data +Run the following command to generate timing data. Use the `-h` flag for details on the available arguments. The output CSV (`times.csv`), which is later parsed to generate the interactive visualizations, is written to the same directory as the script, since that is where the visualization stage expects it. + +```bash +python benchmark_attention_jax.py --bench-bwd --fwd-v3 --bwd-v3 -v +``` + +Note that you can also select a target HIP device via `HIP_VISIBLE_DEVICES=<device id>`, which is useful for isolating the benchmarks to an unused GPU on a shared machine. + +### Generate Interactive Visualization +Simply run `panel serve panel_app.py`. This launches a web service on localhost that displays an interactive visualization app. If launching on a remote server, VS Code users will find that their IDE automatically forwards the correct port, so they can directly open the link that is printed after running the command. Other users must ensure that their `ssh` command into the remote server includes the appropriate port forwarding (the default port is `5006`). \ No newline at end of file diff --git a/benchmarks/attention/benchmark_attention_jax.py b/benchmarks/attention/benchmark_attention_jax.py new file mode 100644 index 000000000..93e66f526 --- /dev/null +++ b/benchmarks/attention/benchmark_attention_jax.py @@ -0,0 +1,406 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. 
# All rights reserved.
# See LICENSE for license information.

import os, sys
from pathlib import Path
import pandas as pd
import argparse
from functools import partial
from itertools import product
import jax
from jax import numpy as jnp
import csv
from transformer_engine.jax.attention import (
    AttnBiasType,
    AttnMaskType,
    QKVLayout,
)
from transformer_engine.jax import fp8_autocast

# Needed in order to dump timings properly.
# Append to any XLA_FLAGS the user already exported instead of discarding them.
_XLA_GRAPH_FLAG = "--xla_gpu_graph_level=0"
_existing_xla_flags = os.environ.get("XLA_FLAGS", "")
if _XLA_GRAPH_FLAG not in _existing_xla_flags:
    os.environ["XLA_FLAGS"] = f"{_existing_xla_flags} {_XLA_GRAPH_FLAG}".strip()

# Add test_fused_attn to the sys path
tests_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../../tests/jax/")
)
sys.path.append(tests_path)

from test_fused_attn import (
    FusedAttnRunner,
    FusedAttnHelper,
    SeqDescFormat,
    BiasShape,
    customcall_fused_dpa,
)


# "b, s_q, s_kv, h_q, h_kv, d_qk, d_v,"
SHAPES = ((2, 2048, 2048, 12, 12, 64, 64),)

# data type
DTYPES = [jnp.float16, jnp.bfloat16]

ATTN_MASK_TYPES = (
    AttnMaskType.NO_MASK,
    AttnMaskType.PADDING_MASK,
    AttnMaskType.CAUSAL_MASK,
    AttnMaskType.PADDING_CAUSAL_MASK,
)
QKV_LAYOUTS = (
    QKVLayout.BS3HD,
    QKVLayout.BSHD_BS2HD,
    QKVLayout.BSHD_BSHD_BSHD,
    QKVLayout.T3HD,
    QKVLayout.THD_T2HD,
    QKVLayout.THD_THD_THD,
)
SEQ_DESC_FORMATS = (SeqDescFormat.Mask, SeqDescFormat.Seqlens, SeqDescFormat.SegmentIDs)
SWA = (True, False)
IS_TRAINING = (True, False)
DROPOUT = (0.0, 0.1)
BIAS_CONFIGS = ((AttnBiasType.NO_BIAS, None), (AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS))

# Full cartesian product of every benchmark axis; unsupported combinations
# are removed later by _filter_configs.
CONFIGS = tuple(
    product(
        SHAPES,
        DTYPES,
        ATTN_MASK_TYPES,
        QKV_LAYOUTS,
        SEQ_DESC_FORMATS,
        SWA,
        IS_TRAINING,
        DROPOUT,
        BIAS_CONFIGS,
    )
)

# Header of the output CSV; must match the keys of every row written below.
COLUMNS = [
    "batch_size",
    "q_seq_len",
    "kv_seq_len",
    "q_heads",
    "kv_heads",
    "qk_dim",
    "v_dim",
    "attn_bias_type",
    "attn_mask_type",
    "dropout",
    "dtype",
    "is_training",
    "qkv_layout",
    "bias_shape",
    "swa",
    "seq_desc_format",
    "mode",
    "time",
]

CWD = os.getcwd()

# Name of the environment variable that is set to the timings directory
# while the measured iterations run.
_DUMP_ENV_VAR = "NVTE_DUMP_AITER_RT"


def _window_size(swa, s_kv):
    """Return the sliding-window tuple used by the benchmark, or None when SWA is off."""
    return (s_kv // 10, 0) if swa else None


class FusedAttnBenchRunner(FusedAttnRunner):
    """FusedAttnRunner variant that runs the fused-attention custom call in a
    warmup/measure loop instead of checking numerical results.

    The measured iterations run with ``NVTE_DUMP_AITER_RT`` pointing at the
    timings directory; ``read_timings`` later reads ``aiter-<mode>-timings.txt``
    from that directory (presumably written by the AITER runtime -- the writer
    is not visible in this file).
    """

    def _device_args(self):
        """Place q/k/v, bias, sequence descriptor and dropout RNG with their shardings."""
        return [
            jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
            jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
            jax.device_put(self.bias, self.bias_sharding),
            # NOTE: "sequence_desciptor" is the attribute name used by the parent class.
            jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
            jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
        ]

    def _customcall_kwargs(self):
        """Keyword arguments forwarded to ``customcall_fused_dpa``."""
        return {
            "attn_bias_type": self.attn_bias_type,
            "attn_mask_type": self.attn_mask_type,
            "scaling_factor": self.scaling_factor,
            "dropout_probability": self.dropout_prob,
            "is_training": self.is_training,
            "qkv_layout": self.qkv_layout,
            "max_segments_per_seq": self._get_max_segments_per_sequence(),
            "window_size": self.window_size,
            "context_parallel_strategy": self.cp_strategy,
            "context_parallel_causal_load_balanced": self.cp_load_balanced,
        }

    def _run_with_timing_dump(self, fn, fn_args, warmup, iters, timings_dir):
        """Call ``fn(*fn_args)`` ``warmup`` times untimed, then ``iters`` times
        with the timing-dump environment variable set.

        JAX dispatches work asynchronously, so the results are synchronised
        before the variable is set and again before it is removed; otherwise
        warmup iterations could still be executing inside the measured window,
        or measured iterations could execute after the variable is gone.
        The variable is removed even if an iteration raises.
        """
        with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource):
            last = None
            for _ in range(warmup):
                last = fn(*fn_args)
            if last is not None:
                jax.block_until_ready(last)

            os.environ[_DUMP_ENV_VAR] = str(timings_dir) + '/'
            try:
                last = None
                for _ in range(iters):
                    last = fn(*fn_args)
                if last is not None:
                    jax.block_until_ready(last)
            finally:
                os.environ.pop(_DUMP_ENV_VAR, None)

    def bench_forward(self, warmup, iters, timings_dir):
        """
        Benchmark the forward pass.

        Args:
            warmup: number of untimed iterations run first.
            iters: number of iterations run with timing dumps enabled.
            timings_dir: directory handed to the runtime for timing dumps.
        """
        self._setup_inputs()
        customcall_args = self._device_args()
        kwargs = self._customcall_kwargs()

        customcall_fused_dpa_jit = jax.jit(
            partial(customcall_fused_dpa, **kwargs),
            static_argnames=kwargs.keys(),
            in_shardings=[
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ],
        )
        self._run_with_timing_dump(
            customcall_fused_dpa_jit, customcall_args, warmup, iters, timings_dir
        )

    def bench_backward(self, warmup, iters, timings_dir):
        """
        Benchmark value_and_grad with JIT, which includes both forward and backward.

        Args:
            warmup: number of untimed iterations run first.
            iters: number of iterations run with timing dumps enabled.
            timings_dir: directory handed to the runtime for timing dumps.
        """
        self._setup_inputs()

        def grad_func(func, *args, cp_reverse_out=False, **kwargs):
            # Gradient is small, use a gradient multiplier to amplify the gradient
            gradient_multiplier = self.max_seqlen_q * self.num_heads_q
            if self.attn_mask_type.is_causal():
                gradient_multiplier /= 10
            # Keep only valid result for the gradient
            if not cp_reverse_out:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    func(*args, **kwargs),
                )
            else:
                ret_valid = jnp.where(
                    self.pad_q[..., jnp.newaxis, jnp.newaxis],
                    0,
                    self.cp_inverse_reorder_fn(func(*args, **kwargs)),
                )
            # Summing in FP16/BF16 may overflow, so the mean is taken in FP32.
            return (
                jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
            ).astype(self.dtype)

        customcall_args = self._device_args()
        kwargs = self._customcall_kwargs()

        # We can compute dBias only for the [1, h, s, s] layout
        if self.bias_shape == BiasShape._1HSS:
            arg_nums = (0, 1, 2, 3)
            grad_shardings = (
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
            )
        else:
            arg_nums = (0, 1, 2)
            grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding)

        jitted_primitive = jax.jit(
            jax.value_and_grad(
                lambda q, k, v, bias, *args: grad_func(
                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
                ),
                arg_nums,
            ),
            in_shardings=(
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.qkvo_sharding,
                self.bias_sharding,
                self.seq_desc_sharding,
                self.dropout_rng_sharding,
            ),
            out_shardings=(None, grad_shardings),
        )
        self._run_with_timing_dump(
            jitted_primitive, customcall_args, warmup, iters, timings_dir
        )


def _filter_configs(configs):
    """Yield only the configurations that the benchmark can run.

    A configuration is dropped when the layout/mask/sequence-descriptor
    combination is rejected by the checks below or when ``FusedAttnHelper``
    reports no fused-attention backend (``-1``) for it.
    """
    for config in configs:
        (
            shape,
            dtype,
            attn_mask_type,
            qkv_layout,
            seq_desc_format,
            swa,
            is_training,
            dropout_prob,
            bias_config,
        ) = config
        b, s_q, s_kv, h_q, h_kv, d_qk, d_v = shape
        attn_bias_type, bias_shape = bias_config
        window_size = _window_size(swa, s_kv)

        # THD layouts are only benchmarked with padding masks and without the Mask descriptor.
        if qkv_layout.is_thd():
            if not attn_mask_type.is_padding():
                continue
            if seq_desc_format == SeqDescFormat.Mask:
                continue
        # A fully packed QKV tensor needs identical sequence lengths and head counts.
        if qkv_layout.is_qkvpacked():
            if (s_q != s_kv) or h_q != h_kv:
                continue
        if s_q > s_kv and window_size is not None:
            continue
        if d_qk != d_v and not qkv_layout.is_separate():
            continue

        backend = FusedAttnHelper(
            dtype,
            dtype,
            qkv_layout,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            h_q, h_kv,
            s_q, s_kv,
            d_qk, d_v,
            (-1, -1) if window_size is None else window_size,
        ).get_fused_attn_backend()
        if backend == -1:
            continue
        if (
            attn_bias_type == AttnBiasType.POST_SCALE_BIAS
            and bias_shape != BiasShape._1HSS
        ):
            if attn_mask_type.is_padding():
                continue
        yield config


def read_timings(timings_dir, rows, output, mode):
    """Consume ``aiter-<mode>-timings.txt`` and append one row per timing.

    Args:
        timings_dir: ``Path`` of the directory holding the timing file.
        rows: list that receives the new row dictionaries (mutated in place).
        output: dictionary describing the configuration; copied into each row.
        mode: ``"fwd"`` or ``"bwd"``; selects the file and is stored in the row.

    The timing file is deleted after it has been read so the next
    configuration starts from an empty file.
    """
    timings_path = timings_dir / f'aiter-{mode}-timings.txt'
    times = pd.read_csv(timings_path, header=None, dtype=float)
    os.remove(timings_path)
    rows.extend([output | {"mode": mode, "time": t} for t in times[0].to_list()])


# Runs profiler and records timing information
def benchmark_dot_product_attention_profiler(args):
    """Benchmark every supported configuration and write ``times.csv``.

    Args:
        args: parsed command-line namespace; reads ``v``, ``bench_bwd``,
            ``warmup`` and ``iters``.

    The CSV is written next to this script. A temporary ``timings``
    directory next to the script is used for the raw dumps and removed at
    the end (``os.rmdir`` requires it to be empty by then).
    """
    rows = []
    src_dir = Path(__file__).parent
    timings_dir = src_dir / "timings"
    os.makedirs(timings_dir, exist_ok=True)
    for n, config in enumerate(_filter_configs(CONFIGS)):
        (
            shape,
            dtype,
            attn_mask_type,
            qkv_layout,
            seq_desc_format,
            swa,
            is_training,
            dropout_prob,
            bias_config,
        ) = config
        b, s_q, s_kv, h_q, h_kv, d_qk, d_v = shape
        attn_bias_type, bias_shape = bias_config
        window_size = _window_size(swa, s_kv)
        output = {
            "batch_size": b,
            "q_seq_len": s_q,
            "kv_seq_len": s_kv,
            "q_heads": h_q,
            "kv_heads": h_kv,
            "qk_dim": d_qk,
            "v_dim": d_v,
            "attn_bias_type": attn_bias_type,
            "attn_mask_type": attn_mask_type,
            "dropout": dropout_prob,
            "dtype": dtype,
            "is_training": is_training,
            "qkv_layout": qkv_layout,
            "bias_shape": bias_shape,
            "swa": swa,
            "seq_desc_format": seq_desc_format,
        }
        if args.v:
            print(f"Progress: {n+1}")
            if args.v > 1:
                print(output)
        runner = FusedAttnBenchRunner(
            b, s_q, s_kv,
            h_q, h_kv,
            d_qk, d_v,
            attn_bias_type,
            attn_mask_type,
            dropout_prob,
            True,  # positional flag defined by FusedAttnRunner; meaning not visible here -- TODO confirm
            dtype,
            is_training,
            qkv_layout,
            bias_shape,
            window_size,
            seq_desc_format,
        )
        # NOTE(review): with --bench-bwd the backward path also runs for
        # is_training=False configurations -- confirm that this is intended.
        bench_fn = runner.bench_backward if args.bench_bwd else runner.bench_forward
        bench_fn(args.warmup, args.iters, timings_dir)

        # The backward benchmark runs forward + backward, so a forward file is always expected.
        read_timings(timings_dir, rows, output, mode="fwd")
        if args.bench_bwd:
            read_timings(timings_dir, rows, output, mode="bwd")

    os.rmdir(timings_dir)
    with open(src_dir / "times.csv", "w", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=COLUMNS)
        writer.writeheader()
        writer.writerows(rows)


class env_manager:
    """Context manager that sets ``NVTE_CK_USES_FWD_V3`` / ``NVTE_CK_USES_BWD_V3``
    to ``"1"`` for the requested directions and restores the previous state on exit.

    Args:
        fwd: enable the forward variable when truthy.
        bwd: enable the backward variable when truthy.
    """

    def __init__(self, fwd, bwd):
        self.vals = {}
        self.config = ((fwd, "FWD"), (bwd, "BWD"))

    def __enter__(self):
        for flag, mode in self.config:
            if flag:
                name = f"NVTE_CK_USES_{mode}_V3"
                # None means "was not set"; an empty string is a real value.
                self.vals[mode] = os.environ.get(name)
                os.environ[name] = "1"
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        for flag, mode in self.config:
            if flag:
                name = f"NVTE_CK_USES_{mode}_V3"
                previous = self.vals.get(mode)
                if previous is None:
                    os.environ.pop(name, None)
                else:
                    # Restore any previous value, including an empty string.
                    os.environ[name] = previous


def main(args):
    """Run the benchmark with the requested V3 kernel variables applied."""
    with env_manager(args.fwd_v3, args.bwd_v3):
        benchmark_dot_product_attention_profiler(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bench-bwd", action="store_true", help="Whether to bench the backwards pass as well.")
    parser.add_argument("--fwd-v3", action="store_true", help="Use NVTE_CK_USES_FWD_V3=1 for AITER fwd kernels.")
    parser.add_argument("--bwd-v3", action="store_true", help="Use NVTE_CK_USES_BWD_V3=1 for AITER bwd kernels.")
    parser.add_argument("-v", action='count', default=0, help="Whether to include verbose debug outputs.")
    parser.add_argument("--warmup", type=int, default=10, help="The number of iterations to run the kernel before logging run time. (default 10)")
    parser.add_argument("--iters", type=int, default=50, help="The number of iterations to run the kernel while logging run time. (default 50)")
    args = parser.parse_args()
    main(args)

# The line below is the unmodified patch header of the next file in this diff,
# kept verbatim (commented out so that this block remains valid Python).
# diff --git a/benchmarks/attention/panel_app.py b/benchmarks/attention/panel_app.py new file mode 100644 index 000000000..df9b5c486 --- /dev/null +++ b/benchmarks/attention/panel_app.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# See LICENSE for license information.

import pandas as pd
import panel as pn
import seaborn as sns
from matplotlib.figure import Figure
from jax import numpy as jnp

pn.extension(design="material", sizing_mode="stretch_width")

# Categorical columns that can be filtered on or used as plot axis / hue.
ATTRIBUTES = [
    "bias_config",
    "attn_mask_type",
    "qkv_layout",
    "is_training",
    "swa",
    "dropout",
    "mode",
    "dtype",
    "seq_desc_format",
]
# Per-column translation from the raw values found in times.csv to display labels.
CONVERTERS = {
    "attn_mask_type": {
        "AttnMaskType.NO_MASK": "None",
        "AttnMaskType.CAUSAL_MASK": "Causal",
        "AttnMaskType.PADDING_MASK": "Padding",
        "AttnMaskType.PADDING_CAUSAL_MASK": "Padding Causal",
        "AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK": "Causal Bottom-Right",
        "AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK": "Padding Causal Bottom-Right",
    },
    "attn_bias_type": {
        "AttnBiasType.NO_BIAS": "None",
        "AttnBiasType.POST_SCALE_BIAS": "Post-Scale Bias",
    },
    "dropout": {
        0: False,
        0.1: True,
    },
    "dtype": {
        str(jnp.float16): "FP16",
        str(jnp.bfloat16): "BF16",
    },
    "qkv_layout": {
        "QKVLayout.BS3HD": "BSHD-Packed",
        "QKVLayout.BSHD_BS2HD": "BSHD-KV-Packed",
        "QKVLayout.BSHD_BSHD_BSHD": "BSHD-Separate",
        "QKVLayout.T3HD": "THD-Packed",
        "QKVLayout.THD_T2HD": "THD-KV-Packed",
        "QKVLayout.THD_THD_THD": "THD-Separate",
    },
    "bias_shape": {
        "BiasShape._1HSS": "_1HSS",
        # Missing values are replaced by the string "NaN" in get_data().
        "NaN": "None",
    },
    "seq_desc_format": {
        "SeqDescFormat.Mask": "Mask",
        "SeqDescFormat.SegmentIDs": "SegmentIDs",
        "SeqDescFormat.Seqlens": "Seqlens",
    },
}
# (bias type label, bias shape label) -> single combined label.
BIAS_CONFIGS = {
    ("None", "None"): "None",
    ("Post-Scale Bias", "_1HSS"): "Post-Scale _1HSS",
}


@pn.cache
def get_data():
    """Load ``times.csv`` from the working directory and convert it for plotting.

    Returns:
        DataFrame in which ``time`` has been multiplied by 1000 (the plot labels
        the axis "Time (ms)", so the file presumably stores seconds -- TODO
        confirm), every column listed in ``CONVERTERS`` holds its display label,
        and ``attn_bias_type`` / ``bias_shape`` are merged into ``bias_config``.

    Raises:
        KeyError: if a column contains a value that has no entry in ``CONVERTERS``
            or ``BIAS_CONFIGS``.
    """
    df = pd.read_csv("times.csv").fillna("NaN")
    df["time"] *= 1000
    for key, labels in CONVERTERS.items():
        # Explicit lookup (rather than Series.map(dict)) so unknown values raise
        # instead of silently becoming NaN.
        df[key] = df[key].map(lambda raw, labels=labels: labels[raw])
    df["bias_config"] = df.apply(
        lambda row: BIAS_CONFIGS[(row["attn_bias_type"], row["bias_shape"])],
        axis=1,
    )
    df = df.drop(columns=["attn_bias_type", "bias_shape"])
    return df


def _selector_widgets():
    """Build one drop-down per attribute, offering the values present in the data."""
    df = get_data()
    return {
        cat: pn.widgets.Select(name=cat, options=list(df[cat].unique()))
        for cat in ATTRIBUTES
    }


selector_widgets = _selector_widgets()


def make_plot(hue, indep, percentile, **kwargs):
    """Draw a swarm plot of kernel time against ``indep``, coloured by ``hue``.

    Args:
        hue: attribute used for colouring.
        indep: attribute shown on the x axis.
        percentile: quantile in [0, 1]; within every (indep, hue) group only
            samples strictly below that quantile of the group are kept.
        **kwargs: selected value for every name in ``ATTRIBUTES``; the values
            for ``hue`` and ``indep`` are ignored.

    Returns:
        The matplotlib ``Figure``; it has empty axes when no sample is left.
    """
    fig = Figure(figsize=(8, 8))
    ax = fig.add_subplot(111)
    df = get_data()
    for attr in ATTRIBUTES:
        if attr not in {hue, indep}:
            df = df[df[attr] == kwargs[attr]]

    if not df.empty:
        # Trim outliers per group by filtering rows out. The thresholds are
        # computed on the untrimmed data of each group. De-duplicate the keys
        # because hue and indep may be the same attribute.
        group_keys = list(dict.fromkeys((indep, hue)))
        cutoff = df.groupby(group_keys)["time"].transform(
            lambda times: times.quantile(percentile)
        )
        df = df[df["time"] < cutoff]

    if not df.empty:
        ax.set(xlabel=indep, ylabel='Time (ms)')
        sns.swarmplot(ax=ax, data=df, x=indep, y="time", hue=hue, dodge=True)
    return fig


hue_selector = pn.widgets.Select(name="Hue", options=ATTRIBUTES, value="dtype")
indep_selector = pn.widgets.Select(name="Independent Variable", options=ATTRIBUTES, value="attn_mask_type")
percentile_trim = pn.widgets.FloatSlider(value=.95, start=0, end=1, step=.01, name="Percentile Trim")
# Re-render the plot whenever any widget changes.
bound_make_plot = pn.bind(
    make_plot,
    hue=hue_selector,
    indep=indep_selector,
    percentile=percentile_trim,
    **selector_widgets,
)

template = pn.template.BootstrapTemplate(
    title='JAX Fused Attention Benchmarks',
    sidebar=pn.Row(
        pn.Column(hue_selector, indep_selector, percentile_trim),
        pn.Column(*[selector_widgets[k] for k in selector_widgets]),
    )
)
template.main.append(pn.pane.Matplotlib(bound_make_plot, dpi=144, height=600))
template.servable()

# The line below is the unmodified remainder of this chunk of the diff (end-of-file
# marker of panel_app.py, patch header and opening JSON of plotting.ipynb), kept
# verbatim and commented out so that this block remains valid Python.
# \ No newline at end of file diff --git a/benchmarks/attention/plotting.ipynb b/benchmarks/attention/plotting.ipynb new file mode 100644 index 000000000..cc9fd5086 --- /dev/null +++ b/benchmarks/attention/plotting.ipynb @@ -0,0 +1,898 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 50, + "id": "729dd513", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { +
"application/javascript": [ + "(function(root) {\n", + " function now() {\n", + " return new Date();\n", + " }\n", + "\n", + " const force = true;\n", + " const py_version = '3.8.0'.replace('rc', '-rc.').replace('.dev', '-dev.');\n", + " const reloading = false;\n", + " const Bokeh = root.Bokeh;\n", + "\n", + " // Set a timeout for this load but only if we are not already initializing\n", + " if (typeof (root._bokeh_timeout) === \"undefined\" || (force || !root._bokeh_is_initializing)) {\n", + " root._bokeh_timeout = Date.now() + 5000;\n", + " root._bokeh_failed_load = false;\n", + " }\n", + "\n", + " function run_callbacks() {\n", + " try {\n", + " root._bokeh_onload_callbacks.forEach(function(callback) {\n", + " if (callback != null)\n", + " callback();\n", + " });\n", + " } finally {\n", + " delete root._bokeh_onload_callbacks;\n", + " }\n", + " console.debug(\"Bokeh: all callbacks have finished\");\n", + " }\n", + "\n", + " function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n", + " if (css_urls == null) css_urls = [];\n", + " if (js_urls == null) js_urls = [];\n", + " if (js_modules == null) js_modules = [];\n", + " if (js_exports == null) js_exports = {};\n", + "\n", + " root._bokeh_onload_callbacks.push(callback);\n", + "\n", + " if (root._bokeh_is_loading > 0) {\n", + " // Don't load bokeh if it is still initializing\n", + " console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n", + " return null;\n", + " } else if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n", + " // There is nothing to load\n", + " run_callbacks();\n", + " return null;\n", + " }\n", + "\n", + " function on_load() {\n", + " root._bokeh_is_loading--;\n", + " if (root._bokeh_is_loading === 0) {\n", + " console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n", + " run_callbacks()\n", + " }\n", + " }\n", + " window._bokeh_on_load = on_load\n", + "\n", + " function on_error(e) {\n", 
+ " const src_el = e.srcElement\n", + " console.error(\"failed to load \" + (src_el.href || src_el.src));\n", + " }\n", + "\n", + " const skip = [];\n", + " if (window.requirejs) {\n", + " window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n", + " root._bokeh_is_loading = css_urls.length + 0;\n", + " } else {\n", + " root._bokeh_is_loading = css_urls.length + js_urls.length + js_modules.length + Object.keys(js_exports).length;\n", + " }\n", + "\n", + " const existing_stylesheets = []\n", + " const links = document.getElementsByTagName('link')\n", + " for (let i = 0; i < links.length; i++) {\n", + " const link = links[i]\n", + " if (link.href != null) {\n", + " existing_stylesheets.push(link.href)\n", + " }\n", + " }\n", + " for (let i = 0; i < css_urls.length; i++) {\n", + " const url = css_urls[i];\n", + " const escaped = encodeURI(url)\n", + " if (existing_stylesheets.indexOf(escaped) !== -1) {\n", + " on_load()\n", + " continue;\n", + " }\n", + " const element = document.createElement(\"link\");\n", + " element.onload = on_load;\n", + " element.onerror = on_error;\n", + " element.rel = \"stylesheet\";\n", + " element.type = \"text/css\";\n", + " element.href = url;\n", + " console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n", + " document.body.appendChild(element);\n", + " } var existing_scripts = []\n", + " const scripts = document.getElementsByTagName('script')\n", + " for (let i = 0; i < scripts.length; i++) {\n", + " var script = scripts[i]\n", + " if (script.src != null) {\n", + " existing_scripts.push(script.src)\n", + " }\n", + " }\n", + " for (let i = 0; i < js_urls.length; i++) {\n", + " const url = js_urls[i];\n", + " const escaped = encodeURI(url)\n", + " if (skip.indexOf(escaped) !== -1 || existing_scripts.indexOf(escaped) !== -1) {\n", + " if (!window.requirejs) {\n", + " on_load();\n", + " }\n", + " continue;\n", + " }\n", + " const element = document.createElement('script');\n", + " element.onload = 
on_load;\n", + " element.onerror = on_error;\n", + " element.async = false;\n", + " element.src = url;\n", + " console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", + " document.head.appendChild(element);\n", + " }\n", + " for (let i = 0; i < js_modules.length; i++) {\n", + " const url = js_modules[i];\n", + " const escaped = encodeURI(url)\n", + " if (skip.indexOf(escaped) !== -1 || existing_scripts.indexOf(escaped) !== -1) {\n", + " if (!window.requirejs) {\n", + " on_load();\n", + " }\n", + " continue;\n", + " }\n", + " var element = document.createElement('script');\n", + " element.onload = on_load;\n", + " element.onerror = on_error;\n", + " element.async = false;\n", + " element.src = url;\n", + " element.type = \"module\";\n", + " console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", + " document.head.appendChild(element);\n", + " }\n", + " for (const name in js_exports) {\n", + " const url = js_exports[name];\n", + " const escaped = encodeURI(url)\n", + " if (skip.indexOf(escaped) >= 0 || root[name] != null) {\n", + " if (!window.requirejs) {\n", + " on_load();\n", + " }\n", + " continue;\n", + " }\n", + " var element = document.createElement('script');\n", + " element.onerror = on_error;\n", + " element.async = false;\n", + " element.type = \"module\";\n", + " console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n", + " element.textContent = `\n", + " import ${name} from \"${url}\"\n", + " window.${name} = ${name}\n", + " window._bokeh_on_load()\n", + " `\n", + " document.head.appendChild(element);\n", + " }\n", + " if (!js_urls.length && !js_modules.length) {\n", + " on_load()\n", + " }\n", + " };\n", + "\n", + " function inject_raw_css(css) {\n", + " const element = document.createElement(\"style\");\n", + " element.appendChild(document.createTextNode(css));\n", + " document.body.appendChild(element);\n", + " }\n", + "\n", + " const js_urls = 
[\"https://cdn.holoviz.org/panel/1.8.1/dist/bundled/reactiveesm/es-module-shims@^1.10.0/dist/es-module-shims.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-3.8.0.min.js\", \"https://cdn.holoviz.org/panel/1.8.1/dist/panel.min.js\"];\n", + " const js_modules = [];\n", + " const js_exports = {};\n", + " const css_urls = [];\n", + " const inline_js = [ function(Bokeh) {\n", + " Bokeh.set_log_level(\"info\");\n", + " },\n", + "function(Bokeh) {} // ensure no trailing comma for IE\n", + " ];\n", + "\n", + " function run_inline_js() {\n", + " if ((root.Bokeh !== undefined) || (force === true)) {\n", + " for (let i = 0; i < inline_js.length; i++) {\n", + " try {\n", + " inline_js[i].call(root, root.Bokeh);\n", + " } catch(e) {\n", + " if (!reloading) {\n", + " throw e;\n", + " }\n", + " }\n", + " }\n", + " // Cache old bokeh versions\n", + " if (Bokeh != undefined && !reloading) {\n", + " var NewBokeh = root.Bokeh;\n", + " if (Bokeh.versions === undefined) {\n", + " Bokeh.versions = new Map();\n", + " }\n", + " if (NewBokeh.version !== Bokeh.version) {\n", + " Bokeh.versions.set(NewBokeh.version, NewBokeh)\n", + " }\n", + " root.Bokeh = Bokeh;\n", + " }\n", + " } else if (Date.now() < root._bokeh_timeout) {\n", + " setTimeout(run_inline_js, 100);\n", + " } else if (!root._bokeh_failed_load) {\n", + " console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n", + " root._bokeh_failed_load = true;\n", + " }\n", + " root._bokeh_is_initializing = false\n", + " }\n", + "\n", + " function load_or_wait() {\n", + " // Implement a backoff loop that tries to ensure we do not load multiple\n", + " // versions of Bokeh and its dependencies at the same time.\n", + " // In recent versions we use the root._bokeh_is_initializing flag\n", + " // to determine whether there 
is an ongoing attempt to initialize\n", + " // bokeh, however for backward compatibility we also try to ensure\n", + " // that we do not start loading a newer (Panel>=1.0 and Bokeh>3) version\n", + " // before older versions are fully initialized.\n", + " if (root._bokeh_is_initializing && Date.now() > root._bokeh_timeout) {\n", + " // If the timeout and bokeh was not successfully loaded we reset\n", + " // everything and try loading again\n", + " root._bokeh_timeout = Date.now() + 5000;\n", + " root._bokeh_is_initializing = false;\n", + " root._bokeh_onload_callbacks = undefined;\n", + " root._bokeh_is_loading = 0\n", + " console.log(\"Bokeh: BokehJS was loaded multiple times but one version failed to initialize.\");\n", + " load_or_wait();\n", + " } else if (root._bokeh_is_initializing || (typeof root._bokeh_is_initializing === \"undefined\" && root._bokeh_onload_callbacks !== undefined)) {\n", + " setTimeout(load_or_wait, 100);\n", + " } else {\n", + " root._bokeh_is_initializing = true\n", + " root._bokeh_onload_callbacks = []\n", + " const bokeh_loaded = root.Bokeh != null && (root.Bokeh.version === py_version || (root.Bokeh.versions !== undefined && root.Bokeh.versions.has(py_version)));\n", + " if (!reloading && !bokeh_loaded) {\n", + " if (root.Bokeh) {\n", + " root.Bokeh = undefined;\n", + " }\n", + " console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n", + " }\n", + " load_libs(css_urls, js_urls, js_modules, js_exports, function() {\n", + " console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n", + " run_inline_js();\n", + " });\n", + " }\n", + " }\n", + " // Give older versions of the autoload script a head-start to ensure\n", + " // they initialize before we start loading newer version.\n", + " setTimeout(load_or_wait, 100)\n", + "}(window));" + ], + "application/vnd.holoviews_load.v0+json": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n const py_version = 
'3.8.0'.replace('rc', '-rc.').replace('.dev', '-dev.');\n const reloading = false;\n const Bokeh = root.Bokeh;\n\n // Set a timeout for this load but only if we are not already initializing\n if (typeof (root._bokeh_timeout) === \"undefined\" || (force || !root._bokeh_is_initializing)) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks;\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, js_modules, js_exports, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n if (js_modules == null) js_modules = [];\n if (js_exports == null) js_exports = {};\n\n root._bokeh_onload_callbacks.push(callback);\n\n if (root._bokeh_is_loading > 0) {\n // Don't load bokeh if it is still initializing\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n } else if (js_urls.length === 0 && js_modules.length === 0 && Object.keys(js_exports).length === 0) {\n // There is nothing to load\n run_callbacks();\n return null;\n }\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n window._bokeh_on_load = on_load\n\n function on_error(e) {\n const src_el = e.srcElement\n console.error(\"failed to load \" + (src_el.href || src_el.src));\n }\n\n const skip = [];\n if (window.requirejs) {\n window.requirejs.config({'packages': {}, 'paths': {}, 'shim': {}});\n root._bokeh_is_loading = css_urls.length + 0;\n } else {\n root._bokeh_is_loading = css_urls.length + js_urls.length + js_modules.length + Object.keys(js_exports).length;\n }\n\n const existing_stylesheets = []\n const links = 
document.getElementsByTagName('link')\n for (let i = 0; i < links.length; i++) {\n const link = links[i]\n if (link.href != null) {\n existing_stylesheets.push(link.href)\n }\n }\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const escaped = encodeURI(url)\n if (existing_stylesheets.indexOf(escaped) !== -1) {\n on_load()\n continue;\n }\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error;\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n } var existing_scripts = []\n const scripts = document.getElementsByTagName('script')\n for (let i = 0; i < scripts.length; i++) {\n var script = scripts[i]\n if (script.src != null) {\n existing_scripts.push(script.src)\n }\n }\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const escaped = encodeURI(url)\n if (skip.indexOf(escaped) !== -1 || existing_scripts.indexOf(escaped) !== -1) {\n if (!window.requirejs) {\n on_load();\n }\n continue;\n }\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (let i = 0; i < js_modules.length; i++) {\n const url = js_modules[i];\n const escaped = encodeURI(url)\n if (skip.indexOf(escaped) !== -1 || existing_scripts.indexOf(escaped) !== -1) {\n if (!window.requirejs) {\n on_load();\n }\n continue;\n }\n var element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error;\n element.async = false;\n element.src = url;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n for (const 
name in js_exports) {\n const url = js_exports[name];\n const escaped = encodeURI(url)\n if (skip.indexOf(escaped) >= 0 || root[name] != null) {\n if (!window.requirejs) {\n on_load();\n }\n continue;\n }\n var element = document.createElement('script');\n element.onerror = on_error;\n element.async = false;\n element.type = \"module\";\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n element.textContent = `\n import ${name} from \"${url}\"\n window.${name} = ${name}\n window._bokeh_on_load()\n `\n document.head.appendChild(element);\n }\n if (!js_urls.length && !js_modules.length) {\n on_load()\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.holoviz.org/panel/1.8.1/dist/bundled/reactiveesm/es-module-shims@^1.10.0/dist/es-module-shims.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-3.8.0.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-3.8.0.min.js\", \"https://cdn.holoviz.org/panel/1.8.1/dist/panel.min.js\"];\n const js_modules = [];\n const js_exports = {};\n const css_urls = [];\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {} // ensure no trailing comma for IE\n ];\n\n function run_inline_js() {\n if ((root.Bokeh !== undefined) || (force === true)) {\n for (let i = 0; i < inline_js.length; i++) {\n try {\n inline_js[i].call(root, root.Bokeh);\n } catch(e) {\n if (!reloading) {\n throw e;\n }\n }\n }\n // Cache old bokeh versions\n if (Bokeh != undefined && !reloading) {\n var NewBokeh = root.Bokeh;\n if (Bokeh.versions === undefined) {\n Bokeh.versions = new Map();\n }\n if (NewBokeh.version !== Bokeh.version) {\n Bokeh.versions.set(NewBokeh.version, NewBokeh)\n 
}\n root.Bokeh = Bokeh;\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n }\n root._bokeh_is_initializing = false\n }\n\n function load_or_wait() {\n // Implement a backoff loop that tries to ensure we do not load multiple\n // versions of Bokeh and its dependencies at the same time.\n // In recent versions we use the root._bokeh_is_initializing flag\n // to determine whether there is an ongoing attempt to initialize\n // bokeh, however for backward compatibility we also try to ensure\n // that we do not start loading a newer (Panel>=1.0 and Bokeh>3) version\n // before older versions are fully initialized.\n if (root._bokeh_is_initializing && Date.now() > root._bokeh_timeout) {\n // If the timeout and bokeh was not successfully loaded we reset\n // everything and try loading again\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_is_initializing = false;\n root._bokeh_onload_callbacks = undefined;\n root._bokeh_is_loading = 0\n console.log(\"Bokeh: BokehJS was loaded multiple times but one version failed to initialize.\");\n load_or_wait();\n } else if (root._bokeh_is_initializing || (typeof root._bokeh_is_initializing === \"undefined\" && root._bokeh_onload_callbacks !== undefined)) {\n setTimeout(load_or_wait, 100);\n } else {\n root._bokeh_is_initializing = true\n root._bokeh_onload_callbacks = []\n const bokeh_loaded = root.Bokeh != null && (root.Bokeh.version === py_version || (root.Bokeh.versions !== undefined && root.Bokeh.versions.has(py_version)));\n if (!reloading && !bokeh_loaded) {\n if (root.Bokeh) {\n root.Bokeh = undefined;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n }\n load_libs(css_urls, js_urls, js_modules, js_exports, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n 
run_inline_js();\n });\n }\n }\n // Give older versions of the autoload script a head-start to ensure\n // they initialize before we start loading newer version.\n setTimeout(load_or_wait, 100)\n}(window));" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/javascript": [ + "\n", + "if ((window.PyViz === undefined) || (window.PyViz instanceof HTMLElement)) {\n", + " window.PyViz = {comms: {}, comm_status:{}, kernels:{}, receivers: {}, plot_index: []}\n", + "}\n", + "\n", + "\n", + " function JupyterCommManager() {\n", + " }\n", + "\n", + " JupyterCommManager.prototype.register_target = function(plot_id, comm_id, msg_handler) {\n", + " if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n", + " var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n", + " comm_manager.register_target(comm_id, function(comm) {\n", + " comm.on_msg(msg_handler);\n", + " });\n", + " } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n", + " window.PyViz.kernels[plot_id].registerCommTarget(comm_id, function(comm) {\n", + " comm.onMsg = msg_handler;\n", + " });\n", + " } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n", + " google.colab.kernel.comms.registerTarget(comm_id, (comm) => {\n", + " var messages = comm.messages[Symbol.asyncIterator]();\n", + " function processIteratorResult(result) {\n", + " var message = result.value;\n", + " var content = {data: message.data, comm_id};\n", + " var buffers = []\n", + " for (var buffer of message.buffers || []) {\n", + " buffers.push(new DataView(buffer))\n", + " }\n", + " var metadata = message.metadata || {};\n", + " var msg = {content, buffers, metadata}\n", + " msg_handler(msg);\n", + " return messages.next().then(processIteratorResult);\n", + " }\n", + " return messages.next().then(processIteratorResult);\n", + " })\n", + " }\n", + " }\n", + "\n", + " 
JupyterCommManager.prototype.get_client_comm = function(plot_id, comm_id, msg_handler) {\n", + " if (comm_id in window.PyViz.comms) {\n", + " return window.PyViz.comms[comm_id];\n", + " } else if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n", + " var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n", + " var comm = comm_manager.new_comm(comm_id, {}, {}, {}, comm_id);\n", + " if (msg_handler) {\n", + " comm.on_msg(msg_handler);\n", + " }\n", + " } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n", + " var comm = window.PyViz.kernels[plot_id].connectToComm(comm_id);\n", + " let retries = 0;\n", + " const open = () => {\n", + " if (comm.active) {\n", + " comm.open();\n", + " } else if (retries > 3) {\n", + " console.warn('Comm target never activated')\n", + " } else {\n", + " retries += 1\n", + " setTimeout(open, 500)\n", + " }\n", + " }\n", + " if (comm.active) {\n", + " comm.open();\n", + " } else {\n", + " setTimeout(open, 500)\n", + " }\n", + " if (msg_handler) {\n", + " comm.onMsg = msg_handler;\n", + " }\n", + " } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n", + " var comm_promise = google.colab.kernel.comms.open(comm_id)\n", + " comm_promise.then((comm) => {\n", + " window.PyViz.comms[comm_id] = comm;\n", + " if (msg_handler) {\n", + " var messages = comm.messages[Symbol.asyncIterator]();\n", + " function processIteratorResult(result) {\n", + " var message = result.value;\n", + " var content = {data: message.data};\n", + " var metadata = message.metadata || {comm_id};\n", + " var msg = {content, metadata}\n", + " msg_handler(msg);\n", + " return messages.next().then(processIteratorResult);\n", + " }\n", + " return messages.next().then(processIteratorResult);\n", + " }\n", + " })\n", + " var sendClosure = (data, metadata, buffers, disposeOnDone) => {\n", + " return comm_promise.then((comm) => {\n", + " comm.send(data, 
metadata, buffers, disposeOnDone);\n", + " });\n", + " };\n", + " var comm = {\n", + " send: sendClosure\n", + " };\n", + " }\n", + " window.PyViz.comms[comm_id] = comm;\n", + " return comm;\n", + " }\n", + " window.PyViz.comm_manager = new JupyterCommManager();\n", + " \n", + "\n", + "\n", + "var JS_MIME_TYPE = 'application/javascript';\n", + "var HTML_MIME_TYPE = 'text/html';\n", + "var EXEC_MIME_TYPE = 'application/vnd.holoviews_exec.v0+json';\n", + "var CLASS_NAME = 'output';\n", + "\n", + "/**\n", + " * Render data to the DOM node\n", + " */\n", + "function render(props, node) {\n", + " var div = document.createElement(\"div\");\n", + " var script = document.createElement(\"script\");\n", + " node.appendChild(div);\n", + " node.appendChild(script);\n", + "}\n", + "\n", + "/**\n", + " * Handle when a new output is added\n", + " */\n", + "function handle_add_output(event, handle) {\n", + " var output_area = handle.output_area;\n", + " var output = handle.output;\n", + " if ((output.data == undefined) || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n", + " return\n", + " }\n", + " var id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n", + " var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n", + " if (id !== undefined) {\n", + " var nchildren = toinsert.length;\n", + " var html_node = toinsert[nchildren-1].children[0];\n", + " html_node.innerHTML = output.data[HTML_MIME_TYPE];\n", + " var scripts = [];\n", + " var nodelist = html_node.querySelectorAll(\"script\");\n", + " for (var i in nodelist) {\n", + " if (nodelist.hasOwnProperty(i)) {\n", + " scripts.push(nodelist[i])\n", + " }\n", + " }\n", + "\n", + " scripts.forEach( function (oldScript) {\n", + " var newScript = document.createElement(\"script\");\n", + " var attrs = [];\n", + " var nodemap = oldScript.attributes;\n", + " for (var j in nodemap) {\n", + " if (nodemap.hasOwnProperty(j)) {\n", + " attrs.push(nodemap[j])\n", + " }\n", + " }\n", + " attrs.forEach(function(attr) { 
newScript.setAttribute(attr.name, attr.value) });\n", + " newScript.appendChild(document.createTextNode(oldScript.innerHTML));\n", + " oldScript.parentNode.replaceChild(newScript, oldScript);\n", + " });\n", + " if (JS_MIME_TYPE in output.data) {\n", + " toinsert[nchildren-1].children[1].textContent = output.data[JS_MIME_TYPE];\n", + " }\n", + " output_area._hv_plot_id = id;\n", + " if ((window.Bokeh !== undefined) && (id in Bokeh.index)) {\n", + " window.PyViz.plot_index[id] = Bokeh.index[id];\n", + " } else {\n", + " window.PyViz.plot_index[id] = null;\n", + " }\n", + " } else if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n", + " var bk_div = document.createElement(\"div\");\n", + " bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n", + " var script_attrs = bk_div.children[0].attributes;\n", + " for (var i = 0; i < script_attrs.length; i++) {\n", + " toinsert[toinsert.length - 1].childNodes[1].setAttribute(script_attrs[i].name, script_attrs[i].value);\n", + " }\n", + " // store reference to server id on output_area\n", + " output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n", + " }\n", + "}\n", + "\n", + "/**\n", + " * Handle when an output is cleared or removed\n", + " */\n", + "function handle_clear_output(event, handle) {\n", + " var id = handle.cell.output_area._hv_plot_id;\n", + " var server_id = handle.cell.output_area._bokeh_server_id;\n", + " if (((id === undefined) || !(id in PyViz.plot_index)) && (server_id !== undefined)) { return; }\n", + " var comm = window.PyViz.comm_manager.get_client_comm(\"hv-extension-comm\", \"hv-extension-comm\", function () {});\n", + " if (server_id !== null) {\n", + " comm.send({event_type: 'server_delete', 'id': server_id});\n", + " return;\n", + " } else if (comm !== null) {\n", + " comm.send({event_type: 'delete', 'id': id});\n", + " }\n", + " delete PyViz.plot_index[id];\n", + " if ((window.Bokeh !== undefined) & (id in window.Bokeh.index)) {\n", + " var doc = 
window.Bokeh.index[id].model.document\n", + " doc.clear();\n", + " const i = window.Bokeh.documents.indexOf(doc);\n", + " if (i > -1) {\n", + " window.Bokeh.documents.splice(i, 1);\n", + " }\n", + " }\n", + "}\n", + "\n", + "/**\n", + " * Handle kernel restart event\n", + " */\n", + "function handle_kernel_cleanup(event, handle) {\n", + " delete PyViz.comms[\"hv-extension-comm\"];\n", + " window.PyViz.plot_index = {}\n", + "}\n", + "\n", + "/**\n", + " * Handle update_display_data messages\n", + " */\n", + "function handle_update_output(event, handle) {\n", + " handle_clear_output(event, {cell: {output_area: handle.output_area}})\n", + " handle_add_output(event, handle)\n", + "}\n", + "\n", + "function register_renderer(events, OutputArea) {\n", + " function append_mime(data, metadata, element) {\n", + " // create a DOM node to render to\n", + " var toinsert = this.create_output_subarea(\n", + " metadata,\n", + " CLASS_NAME,\n", + " EXEC_MIME_TYPE\n", + " );\n", + " this.keyboard_manager.register_events(toinsert);\n", + " // Render to node\n", + " var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n", + " render(props, toinsert[0]);\n", + " element.append(toinsert);\n", + " return toinsert\n", + " }\n", + "\n", + " events.on('output_added.OutputArea', handle_add_output);\n", + " events.on('output_updated.OutputArea', handle_update_output);\n", + " events.on('clear_output.CodeCell', handle_clear_output);\n", + " events.on('delete.Cell', handle_clear_output);\n", + " events.on('kernel_ready.Kernel', handle_kernel_cleanup);\n", + "\n", + " OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n", + " safe: true,\n", + " index: 0\n", + " });\n", + "}\n", + "\n", + "if (window.Jupyter !== undefined) {\n", + " try {\n", + " var events = require('base/js/events');\n", + " var OutputArea = require('notebook/js/outputarea').OutputArea;\n", + " if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n", + " register_renderer(events, 
OutputArea);\n", + " }\n", + " } catch(err) {\n", + " }\n", + "}\n" + ], + "application/vnd.holoviews_load.v0+json": "\nif ((window.PyViz === undefined) || (window.PyViz instanceof HTMLElement)) {\n window.PyViz = {comms: {}, comm_status:{}, kernels:{}, receivers: {}, plot_index: []}\n}\n\n\n function JupyterCommManager() {\n }\n\n JupyterCommManager.prototype.register_target = function(plot_id, comm_id, msg_handler) {\n if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n comm_manager.register_target(comm_id, function(comm) {\n comm.on_msg(msg_handler);\n });\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n window.PyViz.kernels[plot_id].registerCommTarget(comm_id, function(comm) {\n comm.onMsg = msg_handler;\n });\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n google.colab.kernel.comms.registerTarget(comm_id, (comm) => {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n var content = {data: message.data, comm_id};\n var buffers = []\n for (var buffer of message.buffers || []) {\n buffers.push(new DataView(buffer))\n }\n var metadata = message.metadata || {};\n var msg = {content, buffers, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n })\n }\n }\n\n JupyterCommManager.prototype.get_client_comm = function(plot_id, comm_id, msg_handler) {\n if (comm_id in window.PyViz.comms) {\n return window.PyViz.comms[comm_id];\n } else if (window.comm_manager || ((window.Jupyter !== undefined) && (Jupyter.notebook.kernel != null))) {\n var comm_manager = window.comm_manager || Jupyter.notebook.kernel.comm_manager;\n var comm = comm_manager.new_comm(comm_id, {}, {}, {}, comm_id);\n if (msg_handler) {\n 
comm.on_msg(msg_handler);\n }\n } else if ((plot_id in window.PyViz.kernels) && (window.PyViz.kernels[plot_id])) {\n var comm = window.PyViz.kernels[plot_id].connectToComm(comm_id);\n let retries = 0;\n const open = () => {\n if (comm.active) {\n comm.open();\n } else if (retries > 3) {\n console.warn('Comm target never activated')\n } else {\n retries += 1\n setTimeout(open, 500)\n }\n }\n if (comm.active) {\n comm.open();\n } else {\n setTimeout(open, 500)\n }\n if (msg_handler) {\n comm.onMsg = msg_handler;\n }\n } else if (typeof google != 'undefined' && google.colab.kernel != null) {\n var comm_promise = google.colab.kernel.comms.open(comm_id)\n comm_promise.then((comm) => {\n window.PyViz.comms[comm_id] = comm;\n if (msg_handler) {\n var messages = comm.messages[Symbol.asyncIterator]();\n function processIteratorResult(result) {\n var message = result.value;\n var content = {data: message.data};\n var metadata = message.metadata || {comm_id};\n var msg = {content, metadata}\n msg_handler(msg);\n return messages.next().then(processIteratorResult);\n }\n return messages.next().then(processIteratorResult);\n }\n })\n var sendClosure = (data, metadata, buffers, disposeOnDone) => {\n return comm_promise.then((comm) => {\n comm.send(data, metadata, buffers, disposeOnDone);\n });\n };\n var comm = {\n send: sendClosure\n };\n }\n window.PyViz.comms[comm_id] = comm;\n return comm;\n }\n window.PyViz.comm_manager = new JupyterCommManager();\n \n\n\nvar JS_MIME_TYPE = 'application/javascript';\nvar HTML_MIME_TYPE = 'text/html';\nvar EXEC_MIME_TYPE = 'application/vnd.holoviews_exec.v0+json';\nvar CLASS_NAME = 'output';\n\n/**\n * Render data to the DOM node\n */\nfunction render(props, node) {\n var div = document.createElement(\"div\");\n var script = document.createElement(\"script\");\n node.appendChild(div);\n node.appendChild(script);\n}\n\n/**\n * Handle when a new output is added\n */\nfunction handle_add_output(event, handle) {\n var output_area = 
handle.output_area;\n var output = handle.output;\n if ((output.data == undefined) || (!output.data.hasOwnProperty(EXEC_MIME_TYPE))) {\n return\n }\n var id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n var toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n if (id !== undefined) {\n var nchildren = toinsert.length;\n var html_node = toinsert[nchildren-1].children[0];\n html_node.innerHTML = output.data[HTML_MIME_TYPE];\n var scripts = [];\n var nodelist = html_node.querySelectorAll(\"script\");\n for (var i in nodelist) {\n if (nodelist.hasOwnProperty(i)) {\n scripts.push(nodelist[i])\n }\n }\n\n scripts.forEach( function (oldScript) {\n var newScript = document.createElement(\"script\");\n var attrs = [];\n var nodemap = oldScript.attributes;\n for (var j in nodemap) {\n if (nodemap.hasOwnProperty(j)) {\n attrs.push(nodemap[j])\n }\n }\n attrs.forEach(function(attr) { newScript.setAttribute(attr.name, attr.value) });\n newScript.appendChild(document.createTextNode(oldScript.innerHTML));\n oldScript.parentNode.replaceChild(newScript, oldScript);\n });\n if (JS_MIME_TYPE in output.data) {\n toinsert[nchildren-1].children[1].textContent = output.data[JS_MIME_TYPE];\n }\n output_area._hv_plot_id = id;\n if ((window.Bokeh !== undefined) && (id in Bokeh.index)) {\n window.PyViz.plot_index[id] = Bokeh.index[id];\n } else {\n window.PyViz.plot_index[id] = null;\n }\n } else if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n var bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n var script_attrs = bk_div.children[0].attributes;\n for (var i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].childNodes[1].setAttribute(script_attrs[i].name, script_attrs[i].value);\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n}\n\n/**\n * Handle when an output is cleared or removed\n */\nfunction 
handle_clear_output(event, handle) {\n var id = handle.cell.output_area._hv_plot_id;\n var server_id = handle.cell.output_area._bokeh_server_id;\n if (((id === undefined) || !(id in PyViz.plot_index)) && (server_id !== undefined)) { return; }\n var comm = window.PyViz.comm_manager.get_client_comm(\"hv-extension-comm\", \"hv-extension-comm\", function () {});\n if (server_id !== null) {\n comm.send({event_type: 'server_delete', 'id': server_id});\n return;\n } else if (comm !== null) {\n comm.send({event_type: 'delete', 'id': id});\n }\n delete PyViz.plot_index[id];\n if ((window.Bokeh !== undefined) & (id in window.Bokeh.index)) {\n var doc = window.Bokeh.index[id].model.document\n doc.clear();\n const i = window.Bokeh.documents.indexOf(doc);\n if (i > -1) {\n window.Bokeh.documents.splice(i, 1);\n }\n }\n}\n\n/**\n * Handle kernel restart event\n */\nfunction handle_kernel_cleanup(event, handle) {\n delete PyViz.comms[\"hv-extension-comm\"];\n window.PyViz.plot_index = {}\n}\n\n/**\n * Handle update_display_data messages\n */\nfunction handle_update_output(event, handle) {\n handle_clear_output(event, {cell: {output_area: handle.output_area}})\n handle_add_output(event, handle)\n}\n\nfunction register_renderer(events, OutputArea) {\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n var toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n var props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[0]);\n element.append(toinsert);\n return toinsert\n }\n\n events.on('output_added.OutputArea', handle_add_output);\n events.on('output_updated.OutputArea', handle_update_output);\n events.on('clear_output.CodeCell', handle_clear_output);\n events.on('delete.Cell', handle_clear_output);\n events.on('kernel_ready.Kernel', handle_kernel_cleanup);\n\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, 
append_mime, {\n safe: true,\n index: 0\n });\n}\n\nif (window.Jupyter !== undefined) {\n try {\n var events = require('base/js/events');\n var OutputArea = require('notebook/js/outputarea').OutputArea;\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n } catch(err) {\n }\n}\n" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "
\n", + "
\n", + "" + ] + }, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "b0d583bf-d55a-4b5d-bd24-719874acb0e3" + } + }, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import panel as pn\n", + "pn.extension()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "9546da1e-b82c-4748-985b-a9a92ecb1dc7", + "metadata": {}, + "outputs": [], + "source": [ + "@pn.cache\n", + "def get_data():\n", + " return pd.read_csv(\"output_main.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "bba1ffa5-851f-41cd-9c1c-391b5b9b9485", + "metadata": {}, + "outputs": [], + "source": [ + "df=get_data().fillna(\"NaN\")" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "11557b76-3156-45f5-8241-019f4fa0933a", + "metadata": {}, + "outputs": [], + "source": [ + "VARIABLES = [\"dtype\", \"seq_desc_format\"]\n", + "def _selector_widgets():\n", + " for cat in (\"attn_bias_type\", \"attn_mask_type\", \"qkv_layout\", \"bias_shape\", \"is_training\",\"swa\", \"dropout\"):\n", + " yield pn.widgets.Select(name=cat, options=list(df[cat].unique()))\n", + "\n", + "\n", + "selector_widgets = list(_selector_widgets())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "9bcaaf1c-2759-4ecf-a400-1f39897d341b", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": 113, + "id": "558f449f-dd9c-47b3-bef2-c5a1793b8777", + "metadata": {}, + "outputs": [], + "source": [ + "def make_plot(attn_bias_type, attn_mask_type, qkv_layout, bias_shape, is_training, swa, dropout):\n", + " subset = df[\n", + " (\n", + " (df[\"attn_bias_type\"]==attn_bias_type) &\n", + " (df[\"attn_mask_type\"]==attn_mask_type) &\n", + " (df[\"qkv_layout\"]==qkv_layout) &\n", + " (df[\"bias_shape\"]==bias_shape) &\n", + " (df[\"is_training\"]==is_training) &\n", + " (df[\"swa\"]==swa) &\n", + " 
(df[\"dropout\"]==dropout)\n", + " )\n", + " ]\n", + " subset = subset[subset.time < subset.time.quantile(.95)]\n", + " subset = subset[subset.time > subset.time.quantile(.05)]\n", + " if not subset.empty:\n", + " return sns.catplot(data=subset, kind=\"violin\", x=\"seq_desc_format\", y=\"time\", hue=\"dtype\", split=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "921e4c12-f4b3-4393-a569-3e88fda9c0fe", + "metadata": {}, + "outputs": [ + { + "data": {}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.holoviews_exec.v0+json": "", + "text/html": [ + "
\n", + "
\n", + "
\n", + "" + ], + "text/plain": [ + "Row\n", + " [0] Column\n", + " [0] Select(name='attn_bias_type', options=['AttnBiasType.NO_BIAS', ...], value='AttnBiasType.NO_BIAS')\n", + " [1] Select(name='attn_mask_type', options=['AttnMaskType.NO_MASK', ...], value='AttnMaskType.NO_MASK')\n", + " [2] Select(name='qkv_layout', options=['QKVLayout.BS3HD', ...], value='QKVLayout.BS3HD')\n", + " [3] Select(name='bias_shape', options=['NaN', 'BiasShape._1HSS']...], value='NaN')\n", + " [4] Select(name='is_training', options=[True, False], value=True)\n", + " [5] Select(name='swa', options=[True, False], value=True)\n", + " [6] Select(name='dropout', options=[0.0, 0.1], value=0.0)\n", + " [1] ParamFunction(function, _pane=Str, defer_load=False)" + ] + }, + "execution_count": 115, + "metadata": { + "application/vnd.holoviews_exec.v0+json": { + "id": "a3c3ac47-984b-4a04-ab5f-442bb10f1822" + } + }, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "bound_make_plot = pn.bind(make_plot, \n", + " attn_bias_type=selector_widgets[0],\n", + " attn_mask_type=selector_widgets[1],\n", + " qkv_layout=selector_widgets[2],\n", + " bias_shape=selector_widgets[3],\n", + " is_training=selector_widgets[4],\n", + " swa=selector_widgets[5],\n", + " dropout=selector_widgets[6],\n", + ")\n", + "pn.Row(\n", + " pn.Column(*selector_widgets),\n", + " bound_make_plot\n", + ").servable()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60eba2e4-45c7-4d05-a090-802aaf0d494a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1299c7d5-725c-4965-bceb-4ba7a2932322", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06914d05-ec6f-4bff-82ed-6ec2dbdefd51", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/benchmarks/attention/requirements.txt b/benchmarks/attention/requirements.txt new file mode 100644 index 000000000..48bec5c1d --- /dev/null +++ b/benchmarks/attention/requirements.txt @@ -0,0 +1,4 @@ +seaborn +panel +watchfiles +pandas \ No newline at end of file diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp index 2b717ace0..b6d95d542 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp +++ 
b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp @@ -4,6 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ +#include #include #include #include @@ -415,6 +416,12 @@ void log_bwd_config(const char* func_name, } +void dump_bwd_timings(const char* dump_path, float average_runtime, hipStream_t stream){ + std::ofstream file; + file.open(std::string(dump_path) + "aiter-bwd-timings.txt", std::ios_base::app); + file << average_runtime << "\n"; +} + hipError_t ck_attn_bwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, @@ -489,9 +496,10 @@ hipError_t ck_attn_bwd( if (env_p != nullptr && std::string(env_p) == "1") ck_fused_attn_log_config = true; } + const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, false, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; ck_tile::index_t shape_seqlen_q = seqlen_q; ck_tile::index_t shape_seqlen_k = seqlen_k; @@ -651,6 +659,9 @@ hipError_t ck_attn_bwd( uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt); + if(dump_path){ + dump_bwd_timings(dump_path, average_runtime, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); @@ -841,8 +852,9 @@ hipError_t ck_attn_varlen_bwd( if (env_p != nullptr && std::string(env_p) == "1") ck_fused_attn_log_config = true; } + const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, false, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; std::string data_type_str = 
get_data_type_str(dtype); @@ -996,6 +1008,9 @@ hipError_t ck_attn_varlen_bwd( uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt); + if(dump_path){ + dump_bwd_timings(dump_path, average_runtime, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass."); diff --git a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp index c87a3db6c..b50c50098 100644 --- a/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp +++ b/transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp @@ -4,6 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ +#include #include #include #include @@ -107,6 +108,12 @@ void log_fwd_config(const char* func_name, } } +void dump_fwd_timings(const char* dump_path, float average_runtime, hipStream_t stream){ + std::ofstream file; + file.open(std::string(dump_path) + "aiter-fwd-timings.txt", std::ios_base::app); + file << average_runtime << "\n"; +} + hipError_t ck_attn_fwd( DType dtype, uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d_qk, uint64_t d_v, uint64_t bias_b, uint64_t bias_h, @@ -166,9 +173,9 @@ hipError_t ck_attn_fwd( if (env_p != nullptr && std::string(env_p) == "1") ck_fused_attn_log_config = true; } - + const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, false, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; std::string data_type_str = get_data_type_str(dtype); @@ -272,6 +279,9 @@ hipError_t ck_attn_fwd( bias_type, has_lse, uses_fwd_v3); + if(dump_path){ + dump_fwd_timings(dump_path, average_runtime, stream); + } if(average_runtime < 0){ //TODO: better 
error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass."); @@ -338,9 +348,9 @@ hipError_t ck_attn_varlen_fwd( if (env_p != nullptr && std::string(env_p) == "1") ck_fused_attn_log_config = true; } - + const char* dump_path = std::getenv("NVTE_DUMP_AITER_RT"); // print kernel name on verbose mode - ck_tile::stream_config stream_config{stream, false, ck_fused_attn_log_config}; + ck_tile::stream_config stream_config{stream, dump_path!=nullptr, ck_fused_attn_log_config}; std::string data_type_str = get_data_type_str(dtype); @@ -447,6 +457,9 @@ hipError_t ck_attn_varlen_fwd( bias_type, has_lse, uses_fwd_v3); + if(dump_path){ + dump_fwd_timings(dump_path, average_runtime, stream); + } if(average_runtime < 0){ //TODO: better error out system throw std::runtime_error("fused attn configs not supported in ck_fused_attn fwd pass.");