diff --git a/nbs/dataset.ipynb b/nbs/dataset.ipynb index 7eadb5d..c0c8a04 100644 --- a/nbs/dataset.ipynb +++ b/nbs/dataset.ipynb @@ -696,7 +696,7 @@ "def to_pandas(self: Dataset) -> \"pd.DataFrame\":\n", " \"\"\"Convert dataset to pandas DataFrame.\"\"\"\n", " import pandas as pd\n", - " \n", + "\n", " # Make sure we have data\n", " if not self._entries:\n", " self.load()\n", diff --git a/nbs/project/experiments.ipynb b/nbs/project/experiments.ipynb index 06e43b0..c93b742 100644 --- a/nbs/project/experiments.ipynb +++ b/nbs/project/experiments.ipynb @@ -28,6 +28,7 @@ "from tqdm import tqdm\n", "from functools import wraps\n", "import asyncio\n", + "from tqdm import tqdm\n", "\n", "import typing as t\n", "\n", @@ -120,7 +121,7 @@ { "data": { "text/plain": [ - "Project(name='SuperMe')" + "Project(name='yann-lecun-wisdom')" ] }, "execution_count": null, @@ -131,13 +132,13 @@ "source": [ "import os\n", "\n", - "RAGAS_APP_TOKEN = \"apt.47bd-c55e4a45b27c-02f8-8446-1441f09b-651a8\"\n", + "RAGAS_APP_TOKEN = \"apt.4e81-99ed3e6efdfe-bb9c-88e3-f88438f5-5dfef\"\n", "RAGAS_API_BASE_URL = \"https://api.dev.app.ragas.io\"\n", "\n", "os.environ[\"RAGAS_APP_TOKEN\"] = RAGAS_APP_TOKEN\n", "os.environ[\"RAGAS_API_BASE_URL\"] = RAGAS_API_BASE_URL\n", "\n", - "PROJECT_ID = \"a6ccabe0-7b8d-4866-98af-f167a36b94ff\"\n", + "PROJECT_ID = \"919a4d42-aaf2-45cd-badd-152249788bfa\"\n", "p = Project(project_id=PROJECT_ID)\n", "p" ] @@ -252,13 +253,13 @@ "source": [ "# | export\n", "@patch\n", - "def get_experiment(self: Project, dataset_name: str, model) -> Dataset:\n", + "def get_experiment(self: Project, experiment_name: str, model) -> Dataset:\n", " \"\"\"Get an existing dataset by name.\"\"\"\n", " # Search for dataset with given name\n", " sync_version = async_to_sync(self._ragas_api_client.get_experiment_by_name)\n", " exp_info = sync_version(\n", " project_id=self.project_id,\n", - " experiment_name=dataset_name\n", + " experiment_name=experiment_name\n", " )\n", "\n", " # Return Dataset instance\n", @@ -426,7 +427,8 @@ " wrapped_experiment.__setattr__(\"run_async\", run_async)\n", " return t.cast(ExperimentProtocol, wrapped_experiment)\n", "\n", - " return decorator" + " return decorator\n", + "\n" ] }, { @@ -540,6 +542,996 @@ " return decorator" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "\n", + "import logging\n", + "from ragas_experimental.utils import plot_experiments_as_subplots\n", + "\n", + "@patch\n", + "def compare_and_plot(self: Project, experiment_names: t.List[str], model: t.Type[BaseModel], metric_names: t.List[str]):\n", + " \"\"\"Compare multiple experiments and generate a plot.\n", + "\n", + " Args:\n", + " experiment_names: List of experiment IDs to compare\n", + " model: Model class defining the experiment structure\n", + " \"\"\"\n", + " results = {}\n", + " for experiment_name in tqdm(experiment_names, desc=\"Fetching experiments\"):\n", + " experiment = self.get_experiment(experiment_name, model)\n", + " experiment.load()\n", + " results[experiment_name] = {}\n", + " for row in experiment:\n", + " for metric in metric_names:\n", + " if metric not in results[experiment_name]:\n", + " results[experiment_name][metric] = []\n", + " if hasattr(row, metric):\n", + " results[experiment_name][metric].append(getattr(row, metric))\n", + " else:\n", + " results[metric].append(None)\n", + " logging.warning(f\"Metric {metric} not found in row: {row}\")\n", + " \n", + " \n", + " \n", + " fig = plot_experiments_as_subplots(results,experiment_ids=experiment_names)\n", + " fig.show()\n", + " \n", + " \n", + " \n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from ragas_experimental import BaseModel\n", + "\n", + "class TestDataset(BaseModel):\n", + " question: str\n", + " citations: list[str]\n", + " grading_notes: str\n", + " \n", + "\n", + "class ExperimentModel(TestDataset):\n", + " response: str\n", + " score: str\n", + " score_reason: str\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fetching experiments: 100%|██████████| 2/2 [00:06<00:00, 3.01s/it]\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hoverinfo": "text", + "hovertext": [ + "Fail: 30.0%", + "Fail: 33.3%" + ], + "marker": { + "color": "#e11185" + }, + "name": "Fail", + "showlegend": false, + "type": "bar", + "width": 0.5, + "x": [ + "Exp 1", + "Exp 2" + ], + "xaxis": "x", + "y": [ + 30, + 33.33333333333333 + ], + "yaxis": "y" + }, + { + "hoverinfo": "text", + "hovertext": [ + "Pass: 70.0%", + "Pass: 66.7%" + ], + "marker": { + "color": "#1a1dc9" + }, + "name": "Pass", + "showlegend": false, + "type": "bar", + "width": 0.5, + "x": [ + "Exp 1", + "Exp 2" + ], + "xaxis": "x", + "y": [ + 70, + 66.66666666666666 + ], + "yaxis": "y" + } + ], + "layout": { + "annotations": [ + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "Score Comparison", + "x": 0.5, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + } + ], + "barmode": "stack", + "height": 400, + "hovermode": "closest", + "margin": { + "b": 50, + "l": 50, + "r": 50, + "t": 80 + }, + "plot_bgcolor": "white", + "showlegend": false, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Experiment Comparison by Metrics" + }, + "width": 400, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "showgrid": false, + "showline": true, + "tickangle": 0, + "title": { + "text": "Experiments" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "gridcolor": "lightgray", + "linecolor": "black", + "linewidth": 1, + "range": [ + 0, + 105 + ], + "showgrid": true, + "showline": true, + "ticksuffix": "%", + "title": { + "text": "Percentage (%)" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "p.compare_and_plot(\n", + " experiment_names=[\"xenodochial_hoare\",\"confident_liskov\"],\n", + " model=ExperimentModel,\n", + " metric_names=[\"score\"]\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nbs/utils.ipynb b/nbs/utils.ipynb index 8a2a2b8..b71a12d 100644 --- a/nbs/utils.ipynb +++ b/nbs/utils.ipynb @@ -71,6 +71,1187 @@ " return sync_wrapper" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import numpy as np\n", + "import plotly.graph_objects as go\n", + "from plotly.subplots import make_subplots\n", + "from collections import Counter\n", + "\n", + "def plot_experiments_as_subplots(data, experiment_names=None):\n", + " \"\"\"\n", + " Plot metrics comparison across experiments.\n", + " \n", + " Parameters:\n", + " - data: Dictionary with experiment_names as keys and metrics as nested dictionaries\n", + " - experiment_names: List of experiment IDs in the order they should be plotted\n", + " \n", + " Returns:\n", + " - Plotly figure object with horizontal subplots\n", + " \"\"\"\n", + " if experiment_names is None:\n", + " experiment_names = list(data.keys())\n", + " \n", + " exp_short_names = [f\"{name[:10]}..\"for name in experiment_names]\n", + " #TODO: need better solution to identify what type of metric it is\n", + " # this is a temporary solution\n", + " # Identify metrics and their types\n", + " metrics = {}\n", + " for exp_id in experiment_names:\n", + " for metric_name, values in data[exp_id].items():\n", + " # Classify metric type (discrete or numerical)\n", + " if metric_name not in metrics:\n", + " # Check first value to determine type\n", + " is_discrete = isinstance(values[0], str)\n", + " metrics[metric_name] = {\"type\": \"discrete\" if is_discrete else \"numerical\"}\n", + " \n", + " # Create horizontal subplots (one for each metric)\n", + " fig = make_subplots(\n", + " rows=1, \n", + " cols=len(metrics),\n", + " subplot_titles=[f\"{metric.capitalize()} Comparison\" for metric in metrics.keys()],\n", + " horizontal_spacing=0.1\n", + " )\n", + " \n", + " # Process metrics and add traces\n", + " col_idx = 1\n", + " for metric_name, metric_info in metrics.items():\n", + " if metric_info[\"type\"] == \"discrete\":\n", + " # For discrete metrics (like pass/fail)\n", + " categories = set()\n", + " for exp_id in experiment_names:\n", + " count = Counter(data[exp_id][metric_name])\n", + " categories.update(count.keys())\n", + " \n", + " categories = sorted(list(categories))\n", + " \n", + " for category in categories:\n", + " y_values = []\n", + " for exp_id in experiment_names:\n", + " count = Counter(data[exp_id][metric_name])\n", + " total = sum(count.values())\n", + " percentage = (count.get(category, 0) / total) * 100\n", + " y_values.append(percentage)\n", + " \n", + " # Assign colors based on category\n", + " \n", + " # Generate consistent color for other categories\n", + " import hashlib\n", + " hash_obj = hashlib.md5(category.encode())\n", + " hash_hex = hash_obj.hexdigest()\n", + " color = f\"#{hash_hex[:6]}\"\n", + " \n", + " fig.add_trace(\n", + " go.Bar(\n", + " x=exp_short_names,\n", + " y=y_values,\n", + " name=category.capitalize(),\n", + " marker_color=color,\n", + " width=0.5, # Narrower bars\n", + " hoverinfo='text',\n", + " hovertext=[f\"{category.capitalize()}: {x:.1f}%\" for x in y_values],\n", + " showlegend=False # Remove legend\n", + " ),\n", + " row=1, col=col_idx\n", + " )\n", + " \n", + " else: # Numerical metrics\n", + " normalized_values = []\n", + " original_values = []\n", + " \n", + " for exp_id in experiment_names:\n", + " values = data[exp_id][metric_name]\n", + " mean_val = np.mean(values)\n", + " original_values.append(mean_val)\n", + " \n", + " # Normalize to 0-100 scale\n", + " min_val = np.min(values)\n", + " max_val = np.max(values)\n", + " normalized = ((mean_val - min_val) / (max_val - min_val)) * 100\n", + " normalized_values.append(normalized)\n", + " \n", + " # Add bar chart for numerical data\n", + " fig.add_trace(\n", + " go.Bar(\n", + " x=exp_short_names,\n", + " y=normalized_values,\n", + " name=metric_name.capitalize(),\n", + " marker_color='#2E8B57', # Sea green\n", + " width=0.5, # Narrower bars\n", + " hoverinfo='text',\n", + " hovertext=[f\"{metric_name.capitalize()} Mean: {val:.2f} (Normalized: {norm:.1f}%)\" \n", + " for val, norm in zip(original_values, normalized_values)],\n", + " showlegend=False # Remove legend\n", + " ),\n", + " row=1, col=col_idx\n", + " )\n", + " \n", + " # Update axes for each subplot\n", + " fig.update_yaxes(\n", + " title_text=\"Percentage (%)\" if metric_info[\"type\"] == \"discrete\" else \"Normalized Value\",\n", + " range=[0, 105], # Leave room for labels at the top\n", + " ticksuffix=\"%\",\n", + " showgrid=True,\n", + " gridcolor='lightgray',\n", + " showline=True,\n", + " linewidth=1,\n", + " linecolor='black',\n", + " row=1, col=col_idx\n", + " )\n", + " \n", + " fig.update_xaxes(\n", + " title_text=\"Experiments\",\n", + " tickangle=-45,\n", + " showgrid=False,\n", + " showline=True,\n", + " linewidth=1,\n", + " linecolor='black',\n", + " row=1, col=col_idx\n", + " )\n", + " \n", + " col_idx += 1\n", + " \n", + " # Update layout for the entire figure\n", + " fig.update_layout(\n", + " title='Experiment Comparison by Metrics',\n", + " barmode='stack' if any(metric_info[\"type\"] == \"discrete\" for metric_info in metrics.values()) else 'group',\n", + " height=400, # Reduced height\n", + " width=250 * len(metrics) + 150, # Adjust width based on number of metrics\n", + " showlegend=False, # Remove legend\n", + " margin=dict(t=80, b=50, l=50, r=50),\n", + " plot_bgcolor='white',\n", + " hovermode='closest'\n", + " )\n", + " \n", + " return fig\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hoverinfo": "text", + "hovertext": [ + "Fail: 50.0%", + "Fail: 33.3%" + ], + "marker": { + "color": "#e11185" + }, + "name": "Fail", + "showlegend": false, + "type": "bar", + "width": 0.5, + "x": [ + "my-first-e..", + "my-second-.." + ], + "xaxis": "x", + "y": [ + 50, + 33.33333333333333 + ], + "yaxis": "y" + }, + { + "hoverinfo": "text", + "hovertext": [ + "Pass: 50.0%", + "Pass: 66.7%" + ], + "marker": { + "color": "#1a1dc9" + }, + "name": "Pass", + "showlegend": false, + "type": "bar", + "width": 0.5, + "x": [ + "my-first-e..", + "my-second-.." + ], + "xaxis": "x", + "y": [ + 50, + 66.66666666666666 + ], + "yaxis": "y" + }, + { + "hoverinfo": "text", + "hovertext": [ + "Positivity Mean: 5.67 (Normalized: 51.9%)", + "Positivity Mean: 6.23 (Normalized: 52.9%)" + ], + "marker": { + "color": "#2E8B57" + }, + "name": "Positivity", + "showlegend": false, + "type": "bar", + "width": 0.5, + "x": [ + "my-first-e..", + "my-second-.." + ], + "xaxis": "x2", + "y": [ + 51.85185185185186, + 52.916666666666664 + ], + "yaxis": "y2" + } + ], + "layout": { + "annotations": [ + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "Correctness Comparison", + "x": 0.225, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": { + "size": 16 + }, + "showarrow": false, + "text": "Positivity Comparison", + "x": 0.775, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + } + ], + "barmode": "stack", + "height": 400, + "hovermode": "closest", + "margin": { + "b": 50, + "l": 50, + "r": 50, + "t": 80 + }, + "plot_bgcolor": "white", + "showlegend": false, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermap": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermap" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Experiment Comparison by Metrics" + }, + "width": 650, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 0.45 + ], + "linecolor": "black", + "linewidth": 1, + "showgrid": false, + "showline": true, + "tickangle": -45, + "title": { + "text": "Experiments" + } + }, + "xaxis2": { + "anchor": "y2", + "domain": [ + 0.55, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "showgrid": false, + "showline": true, + "tickangle": -45, + "title": { + "text": "Experiments" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "gridcolor": "lightgray", + "linecolor": "black", + "linewidth": 1, + "range": [ + 0, + 105 + ], + "showgrid": true, + "showline": true, + "ticksuffix": "%", + "title": { + "text": "Percentage (%)" + } + }, + "yaxis2": { + "anchor": "x2", + "domain": [ + 0, + 1 + ], + "gridcolor": "lightgray", + "linecolor": "black", + "linewidth": 1, + "range": [ + 0, + 105 + ], + "showgrid": true, + "showline": true, + "ticksuffix": "%", + "title": { + "text": "Normalized Value" + } + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Provided sample data\n", + "data = {\n", + " 'my-first-experiment': {\n", + " 'correctness': ['pass', 'fail', 'fail', 'fail', 'fail', 'pass', 'fail', \n", + " 'pass', 'fail', 'fail', 'fail', 'pass', 'pass', 'pass', \n", + " 'pass', 'fail', 'pass', 'fail', 'pass', 'pass', 'pass', \n", + " 'fail', 'fail', 'pass', 'pass', 'pass', 'pass', 'fail', \n", + " 'fail', 'fail'],\n", + " 'positivity': [\n", + " 7, 3, 8, 2, 4, 9, 3, 8, 7, 6, \n", + " 9, 7, 8, 10, 1, 8, 9, 4, 8, 1, \n", + " 9, 3, 2, 1, 1, 9, 8, 4, 3, 8\n", + " ]\n", + " },\n", + " 'my-second-experiment': {\n", + " 'correctness': ['pass', 'pass', 'pass', 'fail', 'pass', 'pass', 'pass', \n", + " 'pass', 'fail', 'pass', 'pass', 'pass', 'fail', 'pass', \n", + " 'pass', 'pass', 'pass', 'pass', 'pass', 'pass', 'fail', \n", + " 'pass', 'fail', 'fail', 'pass', 'fail', 'pass', 'fail', \n", + " 'fail', 'fail'],\n", + " 'positivity': [\n", + " 6, 8, 7, 3, 8, 7, 9, 8, 2, 7, \n", + " 6, 8, 4, 9, 8, 7, 10, 9, 8, 9, \n", + " 3, 8, 4, 2, 7, 3, 8, 4, 2, 3\n", + " ]\n", + " }\n", + "}\n", + "\n", + "\n", + "# Plot the comparison\n", + "experiment_names = ['my-first-experiment', 'my-second-experiment',]\n", + "fig = plot_experiments_as_subplots(data, experiment_names)\n", + "\n", + "# Show the figure\n", + "fig.show()\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/ragas_experimental/_modidx.py b/ragas_experimental/_modidx.py index 0ab2f12..54ae041 100644 --- a/ragas_experimental/_modidx.py +++ b/ragas_experimental/_modidx.py @@ -599,6 +599,8 @@ 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.ExperimentProtocol.run_async': ( 'project/experiments.html#experimentprotocol.run_async', 'ragas_experimental/project/experiments.py'), + 'ragas_experimental.project.experiments.Project.compare_and_plot': ( 'project/experiments.html#project.compare_and_plot', + 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.Project.create_experiment': ( 'project/experiments.html#project.create_experiment', 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.Project.experiment': ( 'project/experiments.html#project.experiment', @@ -719,4 +721,6 @@ 'ragas_experimental.utils': { 'ragas_experimental.utils.async_to_sync': ( 'utils.html#async_to_sync', 'ragas_experimental/utils.py'), 'ragas_experimental.utils.create_nano_id': ( 'utils.html#create_nano_id', - 'ragas_experimental/utils.py')}}} + 'ragas_experimental/utils.py'), + 'ragas_experimental.utils.plot_experiments_as_subplots': ( 'utils.html#plot_experiments_as_subplots', + 'ragas_experimental/utils.py')}}} diff --git a/ragas_experimental/dataset.py b/ragas_experimental/dataset.py index 7f6df1c..920b578 100644 --- a/ragas_experimental/dataset.py +++ b/ragas_experimental/dataset.py @@ -218,7 +218,7 @@ def load_as_dicts(self: Dataset) -> t.List[t.Dict]: def to_pandas(self: Dataset) -> "pd.DataFrame": """Convert dataset to pandas DataFrame.""" import pandas as pd - + # Make sure we have data if not self._entries: self.load() diff --git a/ragas_experimental/project/experiments.py b/ragas_experimental/project/experiments.py index 83b0169..24cd178 100644 --- a/ragas_experimental/project/experiments.py +++ b/ragas_experimental/project/experiments.py @@ -9,6 +9,7 @@ from tqdm import tqdm from functools import wraps import asyncio +from tqdm import tqdm import typing as t @@ -99,13 +100,13 @@ def get_experiment_by_id(self: Project, experiment_id: str, model: t.Type[BaseMo # %% ../../nbs/project/experiments.ipynb 11 @patch -def get_experiment(self: Project, dataset_name: str, model) -> Dataset: +def get_experiment(self: Project, experiment_name: str, model) -> Dataset: """Get an existing dataset by name.""" # Search for dataset with given name sync_version = async_to_sync(self._ragas_api_client.get_experiment_by_name) exp_info = sync_version( project_id=self.project_id, - experiment_name=dataset_name + experiment_name=experiment_name ) # Return Dataset instance @@ -212,6 +213,8 @@ async def run_async(dataset: Dataset, name: t.Optional[str] = None): return decorator + + # %% ../../nbs/project/experiments.ipynb 22 @patch def langfuse_experiment( @@ -257,3 +260,40 @@ async def run_async_with_langfuse( return t.cast(ExperimentProtocol, base_experiment) return decorator + +# %% ../../nbs/project/experiments.ipynb 23 +import logging +from ..utils import plot_experiments_as_subplots + +@patch +def compare_and_plot(self: Project, experiment_names: t.List[str], model: t.Type[BaseModel], metric_names: t.List[str]): + """Compare multiple experiments and generate a plot. + + Args: + experiment_names: List of experiment IDs to compare + model: Model class defining the experiment structure + """ + results = {} + for experiment_name in tqdm(experiment_names, desc="Fetching experiments"): + experiment = self.get_experiment(experiment_name, model) + experiment.load() + results[experiment_name] = {} + for row in experiment: + for metric in metric_names: + if metric not in results[experiment_name]: + results[experiment_name][metric] = [] + if hasattr(row, metric): + results[experiment_name][metric].append(getattr(row, metric)) + else: + results[metric].append(None) + logging.warning(f"Metric {metric} not found in row: {row}") + + + + fig = plot_experiments_as_subplots(results,experiment_ids=experiment_names) + fig.show() + + + + + diff --git a/ragas_experimental/utils.py b/ragas_experimental/utils.py index d330081..03da4fd 100644 --- a/ragas_experimental/utils.py +++ b/ragas_experimental/utils.py @@ -1,7 +1,7 @@ # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb. # %% auto 0 -__all__ = ['create_nano_id', 'async_to_sync'] +__all__ = ['create_nano_id', 'async_to_sync', 'plot_experiments_as_subplots'] # %% ../nbs/utils.ipynb 1 import string @@ -43,3 +43,157 @@ def sync_wrapper(*args, **kwargs): except RuntimeError: return asyncio.run(async_func(*args, **kwargs)) return sync_wrapper + +# %% ../nbs/utils.ipynb 4 +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from collections import Counter + +def plot_experiments_as_subplots(data, experiment_names=None): + """ + Plot metrics comparison across experiments. + + Parameters: + - data: Dictionary with experiment_names as keys and metrics as nested dictionaries + - experiment_names: List of experiment IDs in the order they should be plotted + + Returns: + - Plotly figure object with horizontal subplots + """ + if experiment_names is None: + experiment_names = list(data.keys()) + + exp_short_names = [f"{name[:10]}.."for name in experiment_names] + #TODO: need better solution to identify what type of metric it is + # this is a temporary solution + # Identify metrics and their types + metrics = {} + for exp_id in experiment_names: + for metric_name, values in data[exp_id].items(): + # Classify metric type (discrete or numerical) + if metric_name not in metrics: + # Check first value to determine type + is_discrete = isinstance(values[0], str) + metrics[metric_name] = {"type": "discrete" if is_discrete else "numerical"} + + # Create horizontal subplots (one for each metric) + fig = make_subplots( + rows=1, + cols=len(metrics), + subplot_titles=[f"{metric.capitalize()} Comparison" for metric in metrics.keys()], + horizontal_spacing=0.1 + ) + + # Process metrics and add traces + col_idx = 1 + for metric_name, metric_info in metrics.items(): + if metric_info["type"] == "discrete": + # For discrete metrics (like pass/fail) + categories = set() + for exp_id in experiment_names: + count = Counter(data[exp_id][metric_name]) + categories.update(count.keys()) + + categories = sorted(list(categories)) + + for category in categories: + y_values = [] + for exp_id in experiment_names: + count = Counter(data[exp_id][metric_name]) + total = sum(count.values()) + percentage = (count.get(category, 0) / total) * 100 + y_values.append(percentage) + + # Assign colors based on category + + # Generate consistent color for other categories + import hashlib + hash_obj = hashlib.md5(category.encode()) + hash_hex = hash_obj.hexdigest() + color = f"#{hash_hex[:6]}" + + fig.add_trace( + go.Bar( + x=exp_short_names, + y=y_values, + name=category.capitalize(), + marker_color=color, + width=0.5, # Narrower bars + hoverinfo='text', + hovertext=[f"{category.capitalize()}: {x:.1f}%" for x in y_values], + showlegend=False # Remove legend + ), + row=1, col=col_idx + ) + + else: # Numerical metrics + normalized_values = [] + original_values = [] + + for exp_id in experiment_names: + values = data[exp_id][metric_name] + mean_val = np.mean(values) + original_values.append(mean_val) + + # Normalize to 0-100 scale + min_val = np.min(values) + max_val = np.max(values) + normalized = ((mean_val - min_val) / (max_val - min_val)) * 100 + normalized_values.append(normalized) + + # Add bar chart for numerical data + fig.add_trace( + go.Bar( + x=exp_short_names, + y=normalized_values, + name=metric_name.capitalize(), + marker_color='#2E8B57', # Sea green + width=0.5, # Narrower bars + hoverinfo='text', + hovertext=[f"{metric_name.capitalize()} Mean: {val:.2f} (Normalized: {norm:.1f}%)" + for val, norm in zip(original_values, normalized_values)], + showlegend=False # Remove legend + ), + row=1, col=col_idx + ) + + # Update axes for each subplot + fig.update_yaxes( + title_text="Percentage (%)" if metric_info["type"] == "discrete" else "Normalized Value", + range=[0, 105], # Leave room for labels at the top + ticksuffix="%", + showgrid=True, + gridcolor='lightgray', + showline=True, + linewidth=1, + linecolor='black', + row=1, col=col_idx + ) + + fig.update_xaxes( + title_text="Experiments", + tickangle=-45, + showgrid=False, + showline=True, + linewidth=1, + linecolor='black', + row=1, col=col_idx + ) + + col_idx += 1 + + # Update layout for the entire figure + fig.update_layout( + title='Experiment Comparison by Metrics', + barmode='stack' if any(metric_info["type"] == "discrete" for metric_info in metrics.values()) else 'group', + height=400, # Reduced height + width=250 * len(metrics) + 150, # Adjust width based on number of metrics + showlegend=False, # Remove legend + margin=dict(t=80, b=50, l=50, r=50), + plot_bgcolor='white', + hovermode='closest' + ) + + return fig + diff --git a/settings.ini b/settings.ini index 07695a6..2f69d3f 100644 --- a/settings.ini +++ b/settings.ini @@ -39,7 +39,7 @@ status = 3 user = explodinggradients ### Dependencies ### -requirements = fastcore tqdm langfuse instructor pydantic numpy +requirements = fastcore tqdm langfuse instructor pydantic numpy plotly dev_requirements = pytest # console_scripts = # conda_user =