diff --git a/apps/evaluations/models.py b/apps/evaluations/models.py index d84dac6e8b..fc28aad92a 100644 --- a/apps/evaluations/models.py +++ b/apps/evaluations/models.py @@ -1,7 +1,7 @@ from __future__ import annotations import importlib -from collections import defaultdict +from collections import OrderedDict, defaultdict from functools import cached_property from typing import TYPE_CHECKING, Literal @@ -350,24 +350,28 @@ def get_table_data(self, include_ids: bool = False): for key, value in result.message_context.items() if key != "current_datetime" } - if include_ids is True: - table_by_message[result.message.id].update({"id": result.message.id}) - - table_by_message[result.message.id].update( - { - "Dataset Input": result.input_message, - "Dataset Output": result.output_message, - "Generated Response": result.output.get("generated_response", ""), - **{ - f"{key} ({result.evaluator.name})": value - for key, value in result.output.get("result", {}).items() - }, - **context_columns, - "session": result.session.external_id if result.session_id else "", - } + # Build row data in order + row_data = OrderedDict() + row_data["session"] = result.session.external_id if result.session_id else "" + row_data["message_id"] = result.message.id + row_data["Dataset Input"] = result.input_message + row_data["Dataset Output"] = result.output_message + row_data["Generated Response"] = result.output.get("generated_response", "") + + row_data.update( + {f"{key} ({result.evaluator.name})": value for key, value in result.output.get("result", {}).items()} ) + + row_data.update(context_columns) + if result.output.get("error"): - table_by_message[result.message.id]["error"] = result.output.get("error") + row_data["error"] = result.output.get("error") + + if include_ids is True: + row_data["id"] = result.message.id + + table_by_message[result.message.id] = row_data + return [{"#": index, **row} for index, row in enumerate(table_by_message.values())] diff --git a/apps/evaluations/tables.py b/apps/evaluations/tables.py index bc3db2afb8..022b3042da 100644 --- a/apps/evaluations/tables.py +++ b/apps/evaluations/tables.py @@ -284,6 +284,17 @@ class Meta: empty_text = "No sessions available for selection." +def _row_class_factory(table, record): + class_defaults = settings.DJANGO_TABLES2_ROW_ATTRS["class"] + if ( + hasattr(table, "highlight_message_id") + and table.highlight_message_id + and record.id == table.highlight_message_id + ): + return f"{class_defaults} bg-yellow-100 dark:bg-yellow-900/20" + return class_defaults + + class DatasetMessagesTable(tables.Table): human_message_content = TemplateColumn( template_name="evaluations/dataset_message_human_column.html", @@ -339,6 +350,10 @@ class DatasetMessagesTable(tables.Table): ] ) + def __init__(self, *args, highlight_message_id=None, **kwargs): + super().__init__(*args, **kwargs) + self.highlight_message_id = highlight_message_id + class Meta: model = EvaluationMessage fields = ( @@ -351,6 +366,9 @@ class Meta: "session_state", "actions", ) - row_attrs = settings.DJANGO_TABLES2_ROW_ATTRS + row_attrs = { + **settings.DJANGO_TABLES2_ROW_ATTRS, + "class": _row_class_factory, + } orderable = False empty_text = "No messages in this dataset yet." diff --git a/apps/evaluations/views/dataset_views.py b/apps/evaluations/views/dataset_views.py index 48a0e69834..3e77aaada9 100644 --- a/apps/evaluations/views/dataset_views.py +++ b/apps/evaluations/views/dataset_views.py @@ -283,7 +283,7 @@ class DatasetMessagesTableView(LoginAndTeamRequiredMixin, SingleTableView, Permi model = EvaluationMessage table_class = DatasetMessagesTable table_pagination = {"per_page": 10} - template_name = "table/single_table.html" + template_name = "evaluations/dataset_messages_table.html" permission_required = "evaluations.view_evaluationdataset" def get_queryset(self): @@ -295,6 +295,40 @@ def get_queryset(self): evaluationdataset__id=dataset_id, evaluationdataset__team=self.request.team ).order_by("id") + def get_highlight_message_id(self): + """Extract and validate the message_id query parameter for highlighting.""" + try: + return int(self.request.GET.get("message_id")) + except (ValueError, TypeError): + return None + + def get_table_pagination(self, table): + """Configure pagination and calculate page for highlighted message.""" + highlight_message_id = self.get_highlight_message_id() + page_size = self.table_pagination.get("per_page", 10) + pagination_config = dict(self.table_pagination) + + # On first load with highlight, calculate which page contains the message + if highlight_message_id and not self.request.GET.get("page"): + queryset = self.get_queryset() + messages_before = queryset.filter(id__lt=highlight_message_id).count() + + # Calculate which page contains this message and add to pagination config + calculated_page = (messages_before // page_size) + 1 + pagination_config["page"] = calculated_page + + return pagination_config + + def get_context_data(self, **kwargs): + context = super().get_context_data(**kwargs) + context["highlight_message_id"] = self.get_highlight_message_id() + return context + + def get_table_kwargs(self): + kwargs = super().get_table_kwargs() + kwargs["highlight_message_id"] = self.get_highlight_message_id() + return kwargs + @login_and_team_required @require_POST diff --git a/apps/evaluations/views/evaluation_config_views.py b/apps/evaluations/views/evaluation_config_views.py index 58858c22a1..3f41ae5297 100644 --- a/apps/evaluations/views/evaluation_config_views.py +++ b/apps/evaluations/views/evaluation_config_views.py @@ -11,6 +11,7 @@ from django.shortcuts import get_object_or_404, render from django.urls import reverse from django.utils import timezone +from django.utils.safestring import mark_safe from django.views.decorators.http import require_http_methods, require_POST from django.views.generic import CreateView, TemplateView, UpdateView from django_tables2 import SingleTableView, columns, tables @@ -337,6 +338,18 @@ def session_enabled_condition(_, record): # Check if session value exists (not empty string) return bool(record.get("session")) + def dataset_url_factory(_, __, record, value): + if not value: + return "#" + dataset_id = self.evaluation_run.config.dataset_id + message_id = record.get("message_id") + + url = reverse("evaluations:dataset_edit", args=[self.kwargs["team_slug"], dataset_id]) + return f"{url}?message_id={message_id}" + + def dataset_enabled_condition(_, record): + return bool(record.get("message_id")) + header = key.replace("_", " ").title() match key: case "#": @@ -362,9 +375,18 @@ def session_enabled_condition(_, record): url_factory=session_url_factory, enabled_condition=session_enabled_condition, ), + actions.chip_action( + label=mark_safe(''), + url_factory=dataset_url_factory, + enabled_condition=dataset_enabled_condition, + open_url_in_new_tab=True, + ), ], align="right", ) + case "message_id": + # Skip rendering message_id as a separate column since it's now in session column + return None return columns.Column(verbose_name=header) diff --git a/templates/evaluations/dataset_edit.html b/templates/evaluations/dataset_edit.html index 4f9fe1eb48..8da206bb7a 100644 --- a/templates/evaluations/dataset_edit.html +++ b/templates/evaluations/dataset_edit.html @@ -56,7 +56,7 @@

{% translate "Dataset Creation Failed" %}

{% endif %}

{% translate "Dataset Messages" %}

diff --git a/templates/evaluations/dataset_messages_table.html b/templates/evaluations/dataset_messages_table.html new file mode 100644 index 0000000000..afdcddd527 --- /dev/null +++ b/templates/evaluations/dataset_messages_table.html @@ -0,0 +1,28 @@ +{% load static %} +{% load render_table from django_tables2 %} + +{% render_table table %} + +{% if highlight_message_id %} + +{% endif %} + +{% block modal %} +{% endblock modal %}