diff --git a/nbs/backends/ragas_api_client.ipynb b/nbs/backends/ragas_api_client.ipynb index 0834e13..d2c8d90 100644 --- a/nbs/backends/ragas_api_client.ipynb +++ b/nbs/backends/ragas_api_client.ipynb @@ -9,13 +9,6 @@ "> Python client to api.ragas.io" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -49,6 +42,20 @@ "from fastcore.utils import patch" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "from ragas_experimental.exceptions import (\n", + " DatasetNotFoundError, DuplicateDatasetError,\n", + " ProjectNotFoundError, DuplicateProjectError,\n", + " ExperimentNotFoundError, DuplicateExperimentError\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -128,6 +135,108 @@ " return await self._request(\"DELETE\", path)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "@patch\n", + "async def _get_resource_by_name(\n", + " self: RagasApiClient,\n", + " list_method: t.Callable,\n", + " get_method: t.Callable,\n", + " resource_name: str,\n", + " name_field: str,\n", + " not_found_error: t.Type[Exception],\n", + " duplicate_error: t.Type[Exception],\n", + " resource_type_name: str,\n", + " **list_method_kwargs\n", + ") -> t.Dict:\n", + " \"\"\"Generic method to get a resource by name.\n", + " \n", + " Args:\n", + " list_method: Method to list resources\n", + " get_method: Method to get a specific resource\n", + " resource_name: Name to search for\n", + " name_field: Field name that contains the resource name\n", + " not_found_error: Exception to raise when resource is not found\n", + " duplicate_error: Exception to raise when multiple resources are found\n", + " resource_type_name: Human-readable name of the resource type\n", + " **list_method_kwargs: Additional arguments to pass to list_method\n", + " \n", + " Returns:\n", + " The resource information dictionary\n", + " \n", + " Raises:\n", + " Exception: If resource is not found or multiple resources are found\n", + " \"\"\"\n", + " # Initial pagination parameters\n", + " limit = 50 # Number of items per page\n", + " offset = 0 # Starting position\n", + " matching_resources = []\n", + " \n", + " while True:\n", + " # Get a page of resources\n", + " response = await list_method(\n", + " limit=limit,\n", + " offset=offset,\n", + " **list_method_kwargs\n", + " )\n", + " \n", + " items = response.get(\"items\", [])\n", + " \n", + " # If no items returned, we've reached the end\n", + " if not items:\n", + " break\n", + " \n", + " # Collect all resources with the matching name in this page\n", + " for resource in items:\n", + " if resource.get(name_field) == resource_name:\n", + " matching_resources.append(resource)\n", + " \n", + " # Update offset for the next page\n", + " offset += limit\n", + " \n", + " # If we've processed all items (less than limit returned), exit the loop\n", + " if len(items) < limit:\n", + " break\n", + " \n", + " # Check results\n", + " if not matching_resources:\n", + " context = list_method_kwargs.get(\"project_id\", \"\")\n", + " context_msg = f\" in project {context}\" if context else \"\"\n", + " raise not_found_error(\n", + " f\"No {resource_type_name} with name '{resource_name}' found{context_msg}\"\n", + " )\n", + " \n", + " if len(matching_resources) > 1:\n", + " # Multiple matches found - construct an informative error 
message\n", + " resource_ids = [r.get(\"id\") for r in matching_resources]\n", + " context = list_method_kwargs.get(\"project_id\", \"\")\n", + " context_msg = f\" in project {context}\" if context else \"\"\n", + " \n", + " raise duplicate_error(\n", + " f\"Multiple {resource_type_name}s found with name '{resource_name}'{context_msg}. \"\n", + " f\"{resource_type_name.capitalize()} IDs: {', '.join(resource_ids)}. \"\n", + " f\"Please use get_{resource_type_name}() with a specific ID instead.\"\n", + " )\n", + " \n", + " # Exactly one match found - retrieve full details\n", + " if \"project_id\" in list_method_kwargs:\n", + " return await get_method(list_method_kwargs[\"project_id\"], matching_resources[0].get(\"id\"))\n", + " else:\n", + " return await get_method(matching_resources[0].get(\"id\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Projects" + ] + }, { "cell_type": "code", "execution_count": null, @@ -224,13 +333,6 @@ " print(f\"Error: {e}\")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Projects" - ] - }, { "cell_type": "code", "execution_count": null, @@ -263,69 +365,89 @@ { "data": { "text/plain": [ - "{'items': [{'id': '26b0e577-8ff8-4014-bc7a-cfc410df3488',\n", + "{'items': [{'id': '1ef0843b-231f-4a2c-b64d-d39bcee9d830',\n", + " 'title': 'yann-lecun-wisdom',\n", + " 'description': 'Yann LeCun Wisdom',\n", + " 'created_at': '2025-04-15T03:27:08.962384+00:00',\n", + " 'updated_at': '2025-04-15T03:27:08.962384+00:00'},\n", + " {'id': 'c2d788ec-a602-495b-8ddc-f457ce11b414',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-12T19:47:10.928422+00:00',\n", + " 'updated_at': '2025-04-12T19:47:10.928422+00:00'},\n", + " {'id': '0d465f02-c88f-454e-9ff3-780a001e3e21',\n", " 'title': 'test project',\n", " 'description': 'test description',\n", - " 'created_at': '2025-04-10T00:12:34.606398+00:00',\n", - " 'updated_at': '2025-04-10T00:12:34.606398+00:00'},\n", - " {'id': '18df1385-8ff6-4c3f-bd8d-a4fcbfba5f57',\n", - " 'title': 'My Project',\n", + " 'created_at': '2025-04-12T19:46:36.221385+00:00',\n", + " 'updated_at': '2025-04-12T19:46:36.221385+00:00'},\n", + " {'id': '2ae1434c-e700-44a7-9528-7c2f03cfb491',\n", + " 'title': 'Demo Project',\n", " 'description': None,\n", - " 'created_at': '2025-04-09T23:49:26.491293+00:00',\n", - " 'updated_at': '2025-04-09T23:49:26.491293+00:00'},\n", - " {'id': '940f2c5a-4c28-4dde-9643-97daede5408e',\n", - " 'title': 'My Project',\n", + " 'created_at': '2025-04-12T19:46:36.157122+00:00',\n", + " 'updated_at': '2025-04-12T19:46:36.157122+00:00'},\n", + " {'id': 'adb45ec6-6902-4339-b05f-3b86fd256c7e',\n", + " 'title': 'Demo Project',\n", " 'description': None,\n", - " 'created_at': '2025-04-09T23:24:34.922817+00:00',\n", - " 'updated_at': '2025-04-09T23:24:34.922817+00:00'},\n", - " {'id': '5f892600-e620-448c-a699-aa5971abefdc',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-09T23:00:39.272415+00:00',\n", - " 'updated_at': '2025-04-09T23:00:39.272415+00:00'},\n", - " {'id': 'd428d6fc-0348-4208-8187-6fa70ae8115e',\n", - " 'title': 'My Project',\n", + " 'created_at': '2025-04-12T19:45:54.430913+00:00',\n", + " 'updated_at': '2025-04-12T19:45:54.430913+00:00'},\n", + " {'id': '6f26bf5b-af4d-48b5-af2d-13d3e671bbbf',\n", + " 'title': 'Demo Project',\n", " 'description': None,\n", - " 'created_at': '2025-04-09T22:49:59.790937+00:00',\n", - " 'updated_at': 
'2025-04-09T22:49:59.790937+00:00'},\n", - " {'id': 'a00d417a-f1dc-4373-ad0e-a3bfff5c3c55',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-09T18:32:42.772189+00:00',\n", - " 'updated_at': '2025-04-09T18:32:42.772189+00:00'},\n", - " {'id': '3b8396ab-592e-4898-bac3-845fe2d5fbd0',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-09T18:31:36.496048+00:00',\n", - " 'updated_at': '2025-04-09T18:31:36.496048+00:00'},\n", - " {'id': '59cf2483-d2c7-4306-af87-bfac2813f27b',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-09T05:57:57.991728+00:00',\n", - " 'updated_at': '2025-04-09T05:57:57.991728+00:00'},\n", - " {'id': 'c026b63c-d618-42c0-81c3-d7824c976eb1',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-08T22:46:04.045516+00:00',\n", - " 'updated_at': '2025-04-08T22:46:04.045516+00:00'},\n", - " {'id': '3dd738de-49f7-494c-aa0a-f6531d3b603a',\n", - " 'title': 'RagasTest',\n", - " 'description': '',\n", - " 'created_at': '2025-04-08T17:45:32.759553+00:00',\n", - " 'updated_at': '2025-04-08T17:45:32.759553+00:00'},\n", - " {'id': '2f45d026-1b13-4851-a36d-c7680edb6380',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-04-08T14:38:14.61165+00:00',\n", - " 'updated_at': '2025-04-08T14:38:14.61165+00:00'},\n", - " {'id': 'e1b3f1e4-d344-48f4-a178-84e7e32e6ab6',\n", - " 'title': 'test project',\n", - " 'description': 'test description',\n", - " 'created_at': '2025-03-30T02:33:38.751793+00:00',\n", - " 'updated_at': '2025-03-30T02:33:38.751793+00:00'}],\n", + " 'created_at': '2025-04-11T00:56:30.085249+00:00',\n", + " 'updated_at': '2025-04-11T00:56:30.085249+00:00'},\n", + " {'id': '63e4fc0f-1a60-441b-bd71-f21ce8e35c7e',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:44:56.031721+00:00',\n", + " 'updated_at': '2025-04-11T00:44:56.031721+00:00'},\n", + " {'id': 'db0bedd6-6cfa-4551-b1ab-af78fa82dca7',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:44:17.601598+00:00',\n", + " 'updated_at': '2025-04-11T00:44:17.601598+00:00'},\n", + " {'id': '80c8ef9a-23d7-4a9f-a7d7-36c6472ab51e',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:42:37.287184+00:00',\n", + " 'updated_at': '2025-04-11T00:42:37.287184+00:00'},\n", + " {'id': 'ae2a5a5c-3902-4ef6-af50-f2d8f27feea6',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:40:53.71528+00:00',\n", + " 'updated_at': '2025-04-11T00:40:53.71528+00:00'},\n", + " {'id': '96618f8b-d3a1-4998-9a66-155f8f254512',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:31:21.410658+00:00',\n", + " 'updated_at': '2025-04-11T00:31:21.410658+00:00'},\n", + " {'id': '4515aa23-cb4c-4c0a-b833-fefd0a30fdcc',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:27:49.977435+00:00',\n", + " 'updated_at': '2025-04-11T00:27:49.977435+00:00'},\n", + " {'id': '138098a4-651e-4dca-b226-d70956b3e039',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-11T00:24:03.39505+00:00',\n", + " 'updated_at': '2025-04-11T00:24:03.39505+00:00'},\n", + " {'id': 
'bbe45632-3268-43a6-9694-b020b3f5226f',\n", + " 'title': 'Demo Project',\n", + " 'description': None,\n", + " 'created_at': '2025-04-10T22:41:14.663646+00:00',\n", + " 'updated_at': '2025-04-10T22:41:14.663646+00:00'},\n", + " {'id': 'df764139-bac7-4aec-af24-5c6886189f84',\n", + " 'title': 'SuperMe-Demo',\n", + " 'description': 'SuperMe demo to show the team',\n", + " 'created_at': '2025-04-10T04:35:18.631257+00:00',\n", + " 'updated_at': '2025-04-10T04:35:18.631257+00:00'},\n", + " {'id': 'a6ccabe0-7b8d-4866-98af-f167a36b94ff',\n", + " 'title': 'SuperMe',\n", + " 'description': 'SuperMe demo to show the team',\n", + " 'created_at': '2025-04-10T03:10:29.153622+00:00',\n", + " 'updated_at': '2025-04-10T03:10:29.153622+00:00'}],\n", " 'pagination': {'offset': 0,\n", " 'limit': 50,\n", - " 'total': 12,\n", + " 'total': 16,\n", " 'order_by': 'created_at',\n", " 'sort_dir': 'desc'}}" ] @@ -345,10 +467,68 @@ "metadata": {}, "outputs": [], "source": [ - "TEST_PROJECT_ID = \"e1b3f1e4-d344-48f4-a178-84e7e32e6ab6\"\n", + "TEST_PROJECT_ID = \"a6ccabe0-7b8d-4866-98af-f167a36b94ff\"\n", "project = await client.get_project(TEST_PROJECT_ID)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "@patch\n", + "async def get_project_by_name(\n", + " self: RagasApiClient, project_name: str\n", + ") -> t.Dict:\n", + " \"\"\"Get a project by its name.\n", + " \n", + " Args:\n", + " project_name: Name of the project to find\n", + " \n", + " Returns:\n", + " The project information dictionary\n", + " \n", + " Raises:\n", + " ProjectNotFoundError: If no project with the given name is found\n", + " DuplicateProjectError: If multiple projects with the given name are found\n", + " \"\"\"\n", + " return await self._get_resource_by_name(\n", + " list_method=self.list_projects,\n", + " get_method=self.get_project,\n", + " resource_name=project_name,\n", + " name_field=\"title\", # Projects use 'title' instead of 'name'\n", + " not_found_error=ProjectNotFoundError,\n", + " duplicate_error=DuplicateProjectError,\n", + " resource_type_name=\"project\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'id': 'a6ccabe0-7b8d-4866-98af-f167a36b94ff',\n", + " 'title': 'SuperMe',\n", + " 'description': 'SuperMe demo to show the team',\n", + " 'created_at': '2025-04-10T03:10:29.153622+00:00',\n", + " 'updated_at': '2025-04-10T03:10:29.153622+00:00'}" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await client.get_project_by_name(\"SuperMe\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -427,8 +607,8 @@ { "data": { "text/plain": [ - "('26b0e577-8ff8-4014-bc7a-cfc410df3488',\n", - " 'e1b3f1e4-d344-48f4-a178-84e7e32e6ab6')" + "('1ef0843b-231f-4a2c-b64d-d39bcee9d830',\n", + " 'a6ccabe0-7b8d-4866-98af-f167a36b94ff')" ] }, "execution_count": null, @@ -451,7 +631,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "New dataset created: {'id': 'e75f6ee2-2b36-4599-8e21-8da7f4006aa2', 'name': 'New Dataset', 'description': 'This is a new dataset', 'updated_at': '2025-04-10T00:12:36.940908+00:00', 'created_at': '2025-04-10T00:12:36.940908+00:00', 'version_counter': 0, 'project_id': '26b0e577-8ff8-4014-bc7a-cfc410df3488'}\n" + "New dataset created: {'id': '2382037f-906c-45a0-9b9f-702d32903efd', 'name': 'New Dataset', 'description': 'This is a new dataset', 'updated_at': 
'2025-04-16T03:52:01.91574+00:00', 'created_at': '2025-04-16T03:52:01.91574+00:00', 'version_counter': 0, 'project_id': '1ef0843b-231f-4a2c-b64d-d39bcee9d830'}\n" ] } ], @@ -491,7 +671,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Updated dataset: {'id': 'e75f6ee2-2b36-4599-8e21-8da7f4006aa2', 'name': 'Updated Dataset', 'description': 'This is an updated dataset', 'created_at': '2025-04-10T00:12:36.940908+00:00', 'updated_at': '2025-04-10T00:12:38.564782+00:00', 'version_counter': 0, 'project_id': '26b0e577-8ff8-4014-bc7a-cfc410df3488'}\n" + "Updated dataset: {'id': '8572180f-fddf-46c5-b943-e6ff6448eb01', 'name': 'Updated Dataset', 'description': 'This is an updated dataset', 'created_at': '2025-04-15T03:28:09.050125+00:00', 'updated_at': '2025-04-16T03:52:09.627448+00:00', 'version_counter': 0, 'project_id': '1ef0843b-231f-4a2c-b64d-d39bcee9d830'}\n" ] } ], @@ -524,6 +704,72 @@ "print(\"Dataset deleted\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For the time being I've also added another option to get the dataset by name too" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "@patch\n", + "async def get_dataset_by_name(\n", + " self: RagasApiClient, project_id: str, dataset_name: str\n", + ") -> t.Dict:\n", + " \"\"\"Get a dataset by its name.\n", + " \n", + " Args:\n", + " project_id: ID of the project\n", + " dataset_name: Name of the dataset to find\n", + " \n", + " Returns:\n", + " The dataset information dictionary\n", + " \n", + " Raises:\n", + " DatasetNotFoundError: If no dataset with the given name is found\n", + " DuplicateDatasetError: If multiple datasets with the given name are found\n", + " \"\"\"\n", + " return await self._get_resource_by_name(\n", + " list_method=self.list_datasets,\n", + " get_method=self.get_dataset,\n", + " resource_name=dataset_name,\n", + " name_field=\"name\",\n", + " not_found_error=DatasetNotFoundError,\n", + " duplicate_error=DuplicateDatasetError,\n", + " resource_type_name=\"dataset\",\n", + " project_id=project_id\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "DuplicateDatasetError", + "evalue": "Multiple datasets found with name 'test' in project a6ccabe0-7b8d-4866-98af-f167a36b94ff. Dataset IDs: 9a48d5d1-531f-424f-b2d2-d8f9bcaeec1e, 483477a4-3d00-4010-a253-c92dee3bc092. 
Please use get_dataset() with a specific ID instead.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mDuplicateDatasetError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[19]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m client.get_dataset_by_name(project_id=TEST_PROJECT_ID, dataset_name=\u001b[33m\"\u001b[39m\u001b[33mtest\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[18]\u001b[39m\u001b[32m, line 18\u001b[39m, in \u001b[36mget_dataset_by_name\u001b[39m\u001b[34m(self, project_id, dataset_name)\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;129m@patch\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mget_dataset_by_name\u001b[39m(\n\u001b[32m 3\u001b[39m \u001b[38;5;28mself\u001b[39m: RagasApiClient, project_id: \u001b[38;5;28mstr\u001b[39m, dataset_name: \u001b[38;5;28mstr\u001b[39m\n\u001b[32m 4\u001b[39m ) -> t.Dict:\n\u001b[32m 5\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Get a dataset by its name.\u001b[39;00m\n\u001b[32m 6\u001b[39m \n\u001b[32m 7\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 16\u001b[39m \u001b[33;03m DuplicateDatasetError: If multiple datasets with the given name are found\u001b[39;00m\n\u001b[32m 17\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m._get_resource_by_name(\n\u001b[32m 19\u001b[39m list_method=\u001b[38;5;28mself\u001b[39m.list_datasets,\n\u001b[32m 20\u001b[39m get_method=\u001b[38;5;28mself\u001b[39m.get_dataset,\n\u001b[32m 21\u001b[39m resource_name=dataset_name,\n\u001b[32m 22\u001b[39m name_field=\u001b[33m\"\u001b[39m\u001b[33mname\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 23\u001b[39m not_found_error=DatasetNotFoundError,\n\u001b[32m 24\u001b[39m duplicate_error=DuplicateDatasetError,\n\u001b[32m 25\u001b[39m resource_type_name=\u001b[33m\"\u001b[39m\u001b[33mdataset\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 26\u001b[39m project_id=project_id\n\u001b[32m 27\u001b[39m )\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 76\u001b[39m, in \u001b[36m_get_resource_by_name\u001b[39m\u001b[34m(self, list_method, get_method, resource_name, name_field, not_found_error, duplicate_error, resource_type_name, **list_method_kwargs)\u001b[39m\n\u001b[32m 73\u001b[39m context = list_method_kwargs.get(\u001b[33m\"\u001b[39m\u001b[33mproject_id\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 74\u001b[39m context_msg = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m in project \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcontext\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m context \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m76\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m duplicate_error(\n\u001b[32m 77\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMultiple 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms found with name \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcontext_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name.capitalize()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m IDs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m, \u001b[39m\u001b[33m'\u001b[39m.join(resource_ids)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 79\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mPlease use get_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m() with a specific ID instead.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 80\u001b[39m )\n\u001b[32m 82\u001b[39m \u001b[38;5;66;03m# Exactly one match found - retrieve full details\u001b[39;00m\n\u001b[32m 83\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mproject_id\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m list_method_kwargs:\n", + "\u001b[31mDuplicateDatasetError\u001b[39m: Multiple datasets found with name 'test' in project a6ccabe0-7b8d-4866-98af-f167a36b94ff. Dataset IDs: 9a48d5d1-531f-424f-b2d2-d8f9bcaeec1e, 483477a4-3d00-4010-a253-c92dee3bc092. Please use get_dataset() with a specific ID instead." + ] + } + ], + "source": [ + "await client.get_dataset_by_name(project_id=TEST_PROJECT_ID, dataset_name=\"test\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -673,6 +919,65 @@ "await client.list_experiments(TEST_PROJECT_ID)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "@patch\n", + "async def get_experiment_by_name(\n", + " self: RagasApiClient, project_id: str, experiment_name: str\n", + ") -> t.Dict:\n", + " \"\"\"Get an experiment by its name.\n", + " \n", + " Args:\n", + " project_id: ID of the project containing the experiment\n", + " experiment_name: Name of the experiment to find\n", + " \n", + " Returns:\n", + " The experiment information dictionary\n", + " \n", + " Raises:\n", + " ExperimentNotFoundError: If no experiment with the given name is found\n", + " DuplicateExperimentError: If multiple experiments with the given name are found\n", + " \"\"\"\n", + " return await self._get_resource_by_name(\n", + " list_method=self.list_experiments,\n", + " get_method=self.get_experiment,\n", + " resource_name=experiment_name,\n", + " name_field=\"name\",\n", + " not_found_error=ExperimentNotFoundError,\n", + " duplicate_error=DuplicateExperimentError,\n", + " resource_type_name=\"experiment\",\n", + " project_id=project_id\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "ename": "DuplicateExperimentError", + "evalue": "Multiple experiments found with name 'test' in project a6ccabe0-7b8d-4866-98af-f167a36b94ff. Experiment IDs: e1ae15aa-2e0e-40dd-902a-0f0e0fd4df69, 52428c79-afdf-468e-82dc-6ef82c5b71d2, 55e14ac3-0037-4909-898f-eee9533a6d3f, 9adfa008-b479-41cf-ba28-c860e01401ea, 233d28c8-6556-49c5-b146-1e001720c214, 6aed5143-3f60-4bf2-bcf2-ecfdb950e992. 
Please use get_experiment() with a specific ID instead.", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mDuplicateExperimentError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[23]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m client.get_experiment_by_name(TEST_PROJECT_ID, \u001b[33m\"\u001b[39m\u001b[33mtest\u001b[39m\u001b[33m\"\u001b[39m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[22]\u001b[39m\u001b[32m, line 19\u001b[39m, in \u001b[36mget_experiment_by_name\u001b[39m\u001b[34m(self, project_id, experiment_name)\u001b[39m\n\u001b[32m 2\u001b[39m \u001b[38;5;129m@patch\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mget_experiment_by_name\u001b[39m(\n\u001b[32m 4\u001b[39m \u001b[38;5;28mself\u001b[39m: RagasApiClient, project_id: \u001b[38;5;28mstr\u001b[39m, experiment_name: \u001b[38;5;28mstr\u001b[39m\n\u001b[32m 5\u001b[39m ) -> t.Dict:\n\u001b[32m 6\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Get an experiment by its name.\u001b[39;00m\n\u001b[32m 7\u001b[39m \n\u001b[32m 8\u001b[39m \u001b[33;03m Args:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 17\u001b[39m \u001b[33;03m DuplicateExperimentError: If multiple experiments with the given name are found\u001b[39;00m\n\u001b[32m 18\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mawait\u001b[39;00m \u001b[38;5;28mself\u001b[39m._get_resource_by_name(\n\u001b[32m 20\u001b[39m list_method=\u001b[38;5;28mself\u001b[39m.list_experiments,\n\u001b[32m 21\u001b[39m get_method=\u001b[38;5;28mself\u001b[39m.get_experiment,\n\u001b[32m 22\u001b[39m resource_name=experiment_name,\n\u001b[32m 23\u001b[39m name_field=\u001b[33m\"\u001b[39m\u001b[33mname\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 24\u001b[39m not_found_error=ExperimentNotFoundError,\n\u001b[32m 25\u001b[39m duplicate_error=DuplicateExperimentError,\n\u001b[32m 26\u001b[39m resource_type_name=\u001b[33m\"\u001b[39m\u001b[33mexperiment\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 27\u001b[39m project_id=project_id\n\u001b[32m 28\u001b[39m )\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[12]\u001b[39m\u001b[32m, line 76\u001b[39m, in \u001b[36m_get_resource_by_name\u001b[39m\u001b[34m(self, list_method, get_method, resource_name, name_field, not_found_error, duplicate_error, resource_type_name, **list_method_kwargs)\u001b[39m\n\u001b[32m 73\u001b[39m context = list_method_kwargs.get(\u001b[33m\"\u001b[39m\u001b[33mproject_id\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 74\u001b[39m context_msg = \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33m in project \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcontext\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m context \u001b[38;5;28;01melse\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m76\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m duplicate_error(\n\u001b[32m 77\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mMultiple 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33ms found with name \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcontext_msg\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 78\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name.capitalize()\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m IDs: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m, \u001b[39m\u001b[33m'\u001b[39m.join(resource_ids)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 79\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mPlease use get_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresource_type_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m() with a specific ID instead.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 80\u001b[39m )\n\u001b[32m 82\u001b[39m \u001b[38;5;66;03m# Exactly one match found - retrieve full details\u001b[39;00m\n\u001b[32m 83\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mproject_id\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m list_method_kwargs:\n", + "\u001b[31mDuplicateExperimentError\u001b[39m: Multiple experiments found with name 'test' in project a6ccabe0-7b8d-4866-98af-f167a36b94ff. Experiment IDs: e1ae15aa-2e0e-40dd-902a-0f0e0fd4df69, 52428c79-afdf-468e-82dc-6ef82c5b71d2, 55e14ac3-0037-4909-898f-eee9533a6d3f, 9adfa008-b479-41cf-ba28-c860e01401ea, 233d28c8-6556-49c5-b146-1e001720c214, 6aed5143-3f60-4bf2-bcf2-ecfdb950e992. Please use get_experiment() with a specific ID instead." 
+ ] + } + ], + "source": [ + "await client.get_experiment_by_name(TEST_PROJECT_ID, \"test\")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/nbs/exceptions.ipynb b/nbs/exceptions.ipynb index 38dfb74..758b65d 100644 --- a/nbs/exceptions.ipynb +++ b/nbs/exceptions.ipynb @@ -18,6 +18,18 @@ "# | default_exp exceptions" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class RagasError(Exception):\n", + " \"\"\"Base class for all Ragas-related exceptions.\"\"\"\n", + " pass" + ] + }, { "cell_type": "code", "execution_count": null, @@ -42,6 +54,56 @@ "\n", " pass" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class ResourceNotFoundError(RagasError):\n", + " \"\"\"Exception raised when a requested resource doesn't exist.\"\"\"\n", + " pass\n", + "\n", + "class ProjectNotFoundError(ResourceNotFoundError):\n", + " \"\"\"Exception raised when a project doesn't exist.\"\"\"\n", + " pass\n", + "\n", + "class DatasetNotFoundError(ResourceNotFoundError):\n", + " \"\"\"Exception raised when a dataset doesn't exist.\"\"\"\n", + " pass\n", + "\n", + "class ExperimentNotFoundError(ResourceNotFoundError):\n", + " \"\"\"Exception raised when an experiment doesn't exist.\"\"\"\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class DuplicateResourceError(RagasError):\n", + " \"\"\"Exception raised when multiple resources exist with the same identifier.\"\"\"\n", + " pass\n", + "\n", + "class DuplicateProjectError(DuplicateResourceError):\n", + " \"\"\"Exception raised when multiple projects exist with the same name.\"\"\"\n", + " pass\n", + "\n", + "class DuplicateDatasetError(DuplicateResourceError):\n", + " \"\"\"Exception raised when multiple datasets exist with the same name.\"\"\"\n", + " pass\n", + "\n", + "class DuplicateExperimentError(DuplicateResourceError):\n", + " \"\"\"Exception raised when multiple experiments exist with the same name.\"\"\"\n", + " pass" + ] } ], "metadata": { diff --git a/nbs/project/core.ipynb b/nbs/project/core.ipynb index c53c029..9db4b92 100644 --- a/nbs/project/core.ipynb +++ b/nbs/project/core.ipynb @@ -144,6 +144,53 @@ "project" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "@patch(cls_method=True)\n", + "def get(cls: Project, name: str, ragas_app_client: t.Optional[RagasApiClient] = None) -> Project:\n", + " \"\"\"Get an existing project by name.\"\"\"\n", + " # Search for project with given name\n", + " if ragas_app_client is None:\n", + " ragas_app_client = RagasApiClientFactory.create()\n", + "\n", + " # get the project by name\n", + " sync_version = async_to_sync(ragas_app_client.get_project_by_name)\n", + " project_info = sync_version(\n", + " project_name=name\n", + " )\n", + "\n", + " # Return Project instance\n", + " return Project(\n", + " project_id=project_info[\"id\"],\n", + " ragas_app_client=ragas_app_client,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Project(name='SuperMe')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Project.get(\"SuperMe\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -268,7 +315,7 @@ { 
"data": { "text/plain": [ - "'bbe45632-3268-43a6-9694-b020b3f5226f'" + "'3d9b529b-c23f-4e87-8a26-dd1923749aa7'" ] }, "execution_count": null, @@ -288,7 +335,7 @@ { "data": { "text/plain": [ - "'0fee5330-9f6e-44a9-a85c-e3b947b697de'" + "'5f7839e7-f4f9-4ada-9082-bc1c5528c259'" ] }, "execution_count": null, @@ -308,7 +355,7 @@ "source": [ "# | export\n", "@patch\n", - "def get_dataset(self: Project, dataset_id: str, model) -> Dataset:\n", + "def get_dataset_by_id(self: Project, dataset_id: str, model) -> Dataset:\n", " \"\"\"Get an existing dataset by name.\"\"\"\n", " # Search for database with given name\n", " sync_version = async_to_sync(self._ragas_api_client.get_dataset)\n", @@ -344,7 +391,34 @@ } ], "source": [ - "project.get_dataset(test_dataset.dataset_id, TestModel)" + "project.get_dataset_by_id(test_dataset.dataset_id, TestModel)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "@patch\n", + "def get_dataset(self: Project, dataset_name: str, model) -> Dataset:\n", + " \"\"\"Get an existing dataset by name.\"\"\"\n", + " # Search for dataset with given name\n", + " sync_version = async_to_sync(self._ragas_api_client.get_dataset_by_name)\n", + " dataset_info = sync_version(\n", + " project_id=self.project_id,\n", + " dataset_name=dataset_name\n", + " )\n", + "\n", + " # Return Dataset instance\n", + " return Dataset(\n", + " name=dataset_info[\"name\"],\n", + " model=model,\n", + " project_id=self.project_id,\n", + " dataset_id=dataset_info[\"id\"],\n", + " ragas_api_client=self._ragas_api_client,\n", + " )" ] }, { @@ -355,7 +429,7 @@ { "data": { "text/plain": [ - "'0a7c4ecb-b313-4bb0-81c0-852c9634ce03'" + "Dataset(name=TestModel, model=TestModel, len=0)" ] }, "execution_count": null, @@ -364,15 +438,8 @@ } ], "source": [ - "project.project_id" + "project.get_dataset(\"TestModel\", TestModel)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/nbs/project/experiments.ipynb b/nbs/project/experiments.ipynb index 891d94c..06e43b0 100644 --- a/nbs/project/experiments.ipynb +++ b/nbs/project/experiments.ipynb @@ -41,6 +41,13 @@ "import ragas_experimental.typing as rt" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basics" + ] + }, { "cell_type": "code", "execution_count": null, @@ -155,7 +162,7 @@ { "data": { "text/plain": [ - "Experiment(name=just name, desc, price 2, model=TestModel)" + "Experiment(name=test-exp, model=TestModel)" ] }, "execution_count": null, @@ -165,7 +172,7 @@ ], "source": [ "experiment_id = \"5d7752ab-17bf-46bc-a302-afe04ce1a763\"\n", - "exp = p.create_experiment(name=\"just name, desc, price 2\", model=TestModel)\n", + "exp = p.create_experiment(name=\"test-exp\", model=TestModel)\n", "#exp = p.create_dataset(name=\"just name and desc 2\", model=TestModel)\n", "\n", "exp" @@ -179,7 +186,7 @@ "source": [ "# | export\n", "@patch\n", - "def get_experiment(self: Project, experiment_id: str, model: t.Type[BaseModel]) -> Experiment:\n", + "def get_experiment_by_id(self: Project, experiment_id: str, model: t.Type[BaseModel]) -> Experiment:\n", " \"\"\"Get an existing experiment by ID.\"\"\"\n", " # Get experiment info\n", " sync_version = async_to_sync(self._ragas_api_client.get_experiment)\n", @@ -205,7 +212,7 @@ { "data": { "text/plain": [ - "'22bbb40c-1fc0-4a09-b26a-ccc93c8bd595'" + "'effe0e10-916d-4530-b974-91d5115f5dc2'" ] }, "execution_count": null, @@ -214,7 +221,7 
@@ } ], "source": [ - "exp.dataset_id" + "exp.experiment_id" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Project(name='SuperMe')" + "Experiment(name=test-exp, model=TestModel)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "p" + "p.get_experiment_by_id(exp.experiment_id, TestModel)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# | export\n", + "@patch\n", + "def get_experiment(self: Project, dataset_name: str, model) -> Experiment:\n", + " \"\"\"Get an existing experiment by name.\"\"\"\n", + " # Search for experiment with given name\n", + " sync_version = async_to_sync(self._ragas_api_client.get_experiment_by_name)\n", + " exp_info = sync_version(\n", + " project_id=self.project_id,\n", + " experiment_name=dataset_name\n", + " )\n", + "\n", + " # Return Experiment instance\n", + " return Experiment(\n", + " name=exp_info[\"name\"],\n", + " model=model,\n", + " project_id=self.project_id,\n", + " experiment_id=exp_info[\"id\"],\n", + " ragas_api_client=self._ragas_api_client,\n", + " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Experiment(name=just name, desc, price 2, model=TestModel)" + "Experiment(name=test-exp, model=TestModel)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "p.get_experiment(exp.dataset_id, TestModel)" + "p.get_experiment(\"test-exp\", TestModel)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Experiment Wrapper" ] }, { @@ -337,25 +378,50 @@ " if name_prefix:\n", " name = f\"{name_prefix}-{name}\"\n", "\n", - " # Create tasks for all items\n", - " tasks = []\n", - " for item in dataset:\n", - " tasks.append(wrapped_experiment(item))\n", - "\n", - " # Use as_completed with tqdm for progress tracking\n", - " results = []\n", - " for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):\n", - " result = await future\n", - " # Add each result to experiment view as it completes\n", - " if result is not None:\n", - " results.append(result)\n", - "\n", - " # upload results to experiment view\n", - " experiment_view = self.create_experiment(name=name, model=experiment_model)\n", - " for result in results:\n", - " experiment_view.append(result)\n", - "\n", - " return experiment_view\n", + " experiment_view = None\n", + " try:\n", + " # Create the experiment view upfront\n", + " experiment_view = self.create_experiment(name=name, model=experiment_model)\n", + " \n", + " # Create tasks for all items\n", + " tasks = []\n", + " for item in dataset:\n", + " tasks.append(wrapped_experiment(item))\n", + "\n", + " # Calculate total operations (processing + appending)\n", + " total_operations = len(tasks) * 2 # Each item requires processing and appending\n", + " \n", + " # Use tqdm for combined progress tracking\n", + " results = []\n", + " progress_bar = tqdm(total=total_operations, desc=\"Running experiment\")\n", + " \n", + " # Process all items\n", + " for future in asyncio.as_completed(tasks):\n", + " result = await future\n", + " if result is not None:\n", + " results.append(result)\n", + " progress_bar.update(1) # Update for task completion\n", + " \n", + " # Append results to experiment view\n", + " for result in results:\n", + " experiment_view.append(result)\n", + " progress_bar.update(1) # Update for append operation\n", + " \n", + " progress_bar.close()\n", + " return experiment_view\n", + " \n", + " except Exception as e:\n", + " # Clean up the experiment if there was an error and it was created\n", +
" if experiment_view is not None:\n", + " try:\n", + " # Delete the experiment (you might need to implement this method)\n", + " sync_version = async_to_sync(self._ragas_api_client.delete_experiment)\n", + " sync_version(project_id=self.project_id, experiment_id=experiment_view.experiment_id)\n", + " except Exception as cleanup_error:\n", + " print(f\"Failed to clean up experiment after error: {cleanup_error}\")\n", + " \n", + " # Re-raise the original exception\n", + " raise e\n", "\n", " wrapped_experiment.__setattr__(\"run_async\", run_async)\n", " return t.cast(ExperimentProtocol, wrapped_experiment)\n", @@ -390,7 +456,6 @@ "# create a test experiment function\n", "@p.experiment(TextExperimentModel)\n", "async def test_experiment(item: TestModel):\n", - " print(item)\n", " return TextExperimentModel(**item.model_dump(), response=\"test response\", is_correct=\"yes\")\n" ] }, @@ -403,22 +468,13 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 3/3 [00:00<00:00, 7752.87it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "name='test item 2' description='test item 2 description' price=200.0\n", - "name='test item 1' description='test item 1 description' price=100.0\n", - "name='test item 3' description='test item 3 description' price=300.0\n" + "Running experiment: 100%|██████████| 6/6 [00:01<00:00, 3.84it/s]\n" ] }, { "data": { "text/plain": [ - "Experiment(name=keen_backus, model=TextExperimentModel)" + "Experiment(name=gallant_torvalds, model=TextExperimentModel)" ] }, "execution_count": null, diff --git a/ragas_experimental/_modidx.py b/ragas_experimental/_modidx.py index e58fd07..0ab2f12 100644 --- a/ragas_experimental/_modidx.py +++ b/ragas_experimental/_modidx.py @@ -107,6 +107,8 @@ 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient._get_resource': ( 'backends/ragas_api_client.html#ragasapiclient._get_resource', 'ragas_experimental/backends/ragas_api_client.py'), + 'ragas_experimental.backends.ragas_api_client.RagasApiClient._get_resource_by_name': ( 'backends/ragas_api_client.html#ragasapiclient._get_resource_by_name', + 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient._list_resources': ( 'backends/ragas_api_client.html#ragasapiclient._list_resources', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient._request': ( 'backends/ragas_api_client.html#ragasapiclient._request', @@ -155,18 +157,24 @@ 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_dataset': ( 'backends/ragas_api_client.html#ragasapiclient.get_dataset', 'ragas_experimental/backends/ragas_api_client.py'), + 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_dataset_by_name': ( 'backends/ragas_api_client.html#ragasapiclient.get_dataset_by_name', + 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_dataset_column': ( 'backends/ragas_api_client.html#ragasapiclient.get_dataset_column', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_dataset_row': ( 'backends/ragas_api_client.html#ragasapiclient.get_dataset_row', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_experiment': ( 
'backends/ragas_api_client.html#ragasapiclient.get_experiment', 'ragas_experimental/backends/ragas_api_client.py'), + 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_experiment_by_name': ( 'backends/ragas_api_client.html#ragasapiclient.get_experiment_by_name', + 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_experiment_column': ( 'backends/ragas_api_client.html#ragasapiclient.get_experiment_column', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_experiment_row': ( 'backends/ragas_api_client.html#ragasapiclient.get_experiment_row', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_project': ( 'backends/ragas_api_client.html#ragasapiclient.get_project', 'ragas_experimental/backends/ragas_api_client.py'), + 'ragas_experimental.backends.ragas_api_client.RagasApiClient.get_project_by_name': ( 'backends/ragas_api_client.html#ragasapiclient.get_project_by_name', + 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.list_dataset_columns': ( 'backends/ragas_api_client.html#ragasapiclient.list_dataset_columns', 'ragas_experimental/backends/ragas_api_client.py'), 'ragas_experimental.backends.ragas_api_client.RagasApiClient.list_dataset_rows': ( 'backends/ragas_api_client.html#ragasapiclient.list_dataset_rows', @@ -255,10 +263,28 @@ 'ragas_experimental/embedding/base.py'), 'ragas_experimental.embedding.base.ragas_embedding': ( 'embedding/base.html#ragas_embedding', 'ragas_experimental/embedding/base.py')}, - 'ragas_experimental.exceptions': { 'ragas_experimental.exceptions.DuplicateError': ( 'exceptions.html#duplicateerror', + 'ragas_experimental.exceptions': { 'ragas_experimental.exceptions.DatasetNotFoundError': ( 'exceptions.html#datasetnotfounderror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.DuplicateDatasetError': ( 'exceptions.html#duplicatedataseterror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.DuplicateError': ( 'exceptions.html#duplicateerror', 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.DuplicateExperimentError': ( 'exceptions.html#duplicateexperimenterror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.DuplicateProjectError': ( 'exceptions.html#duplicateprojecterror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.DuplicateResourceError': ( 'exceptions.html#duplicateresourceerror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.ExperimentNotFoundError': ( 'exceptions.html#experimentnotfounderror', + 'ragas_experimental/exceptions.py'), 'ragas_experimental.exceptions.NotFoundError': ( 'exceptions.html#notfounderror', 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.ProjectNotFoundError': ( 'exceptions.html#projectnotfounderror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.RagasError': ( 'exceptions.html#ragaserror', + 'ragas_experimental/exceptions.py'), + 'ragas_experimental.exceptions.ResourceNotFoundError': ( 'exceptions.html#resourcenotfounderror', + 'ragas_experimental/exceptions.py'), 'ragas_experimental.exceptions.ValidationError': ( 'exceptions.html#validationerror', 'ragas_experimental/exceptions.py')}, 'ragas_experimental.experiment': { 'ragas_experimental.experiment.Experiment': ( 
'experiment.html#experiment', @@ -559,8 +585,12 @@ 'ragas_experimental/project/core.py'), 'ragas_experimental.project.core.Project.delete': ( 'project/core.html#project.delete', 'ragas_experimental/project/core.py'), + 'ragas_experimental.project.core.Project.get': ( 'project/core.html#project.get', + 'ragas_experimental/project/core.py'), 'ragas_experimental.project.core.Project.get_dataset': ( 'project/core.html#project.get_dataset', 'ragas_experimental/project/core.py'), + 'ragas_experimental.project.core.Project.get_dataset_by_id': ( 'project/core.html#project.get_dataset_by_id', + 'ragas_experimental/project/core.py'), 'ragas_experimental.project.core.create_dataset_columns': ( 'project/core.html#create_dataset_columns', 'ragas_experimental/project/core.py')}, 'ragas_experimental.project.experiments': { 'ragas_experimental.project.experiments.ExperimentProtocol': ( 'project/experiments.html#experimentprotocol', @@ -575,6 +605,8 @@ 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.Project.get_experiment': ( 'project/experiments.html#project.get_experiment', 'ragas_experimental/project/experiments.py'), + 'ragas_experimental.project.experiments.Project.get_experiment_by_id': ( 'project/experiments.html#project.get_experiment_by_id', + 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.Project.langfuse_experiment': ( 'project/experiments.html#project.langfuse_experiment', 'ragas_experimental/project/experiments.py'), 'ragas_experimental.project.experiments.create_experiment_columns': ( 'project/experiments.html#create_experiment_columns', diff --git a/ragas_experimental/backends/ragas_api_client.py b/ragas_experimental/backends/ragas_api_client.py index 04ce591..4a3ce84 100644 --- a/ragas_experimental/backends/ragas_api_client.py +++ b/ragas_experimental/backends/ragas_api_client.py @@ -5,13 +5,20 @@ # %% auto 0 __all__ = ['DEFAULT_SETTINGS', 'RagasApiClient', 'create_nano_id', 'Column', 'RowCell', 'Row'] -# %% ../../nbs/backends/ragas_api_client.ipynb 4 +# %% ../../nbs/backends/ragas_api_client.ipynb 3 import httpx import asyncio import typing as t from pydantic import BaseModel, Field from fastcore.utils import patch +# %% ../../nbs/backends/ragas_api_client.ipynb 4 +from ragas_experimental.exceptions import ( + DatasetNotFoundError, DuplicateDatasetError, + ProjectNotFoundError, DuplicateProjectError, + ExperimentNotFoundError, DuplicateExperimentError +) + # %% ../../nbs/backends/ragas_api_client.ipynb 5 class RagasApiClient(): """Client for the Ragas Relay API.""" @@ -85,6 +92,94 @@ async def _delete_resource(self, path): return await self._request("DELETE", path) # %% ../../nbs/backends/ragas_api_client.ipynb 6 +@patch +async def _get_resource_by_name( + self: RagasApiClient, + list_method: t.Callable, + get_method: t.Callable, + resource_name: str, + name_field: str, + not_found_error: t.Type[Exception], + duplicate_error: t.Type[Exception], + resource_type_name: str, + **list_method_kwargs +) -> t.Dict: + """Generic method to get a resource by name. 
+ + Args: + list_method: Method to list resources + get_method: Method to get a specific resource + resource_name: Name to search for + name_field: Field name that contains the resource name + not_found_error: Exception to raise when resource is not found + duplicate_error: Exception to raise when multiple resources are found + resource_type_name: Human-readable name of the resource type + **list_method_kwargs: Additional arguments to pass to list_method + + Returns: + The resource information dictionary + + Raises: + Exception: If resource is not found or multiple resources are found + """ + # Initial pagination parameters + limit = 50 # Number of items per page + offset = 0 # Starting position + matching_resources = [] + + while True: + # Get a page of resources + response = await list_method( + limit=limit, + offset=offset, + **list_method_kwargs + ) + + items = response.get("items", []) + + # If no items returned, we've reached the end + if not items: + break + + # Collect all resources with the matching name in this page + for resource in items: + if resource.get(name_field) == resource_name: + matching_resources.append(resource) + + # Update offset for the next page + offset += limit + + # If we've processed all items (less than limit returned), exit the loop + if len(items) < limit: + break + + # Check results + if not matching_resources: + context = list_method_kwargs.get("project_id", "") + context_msg = f" in project {context}" if context else "" + raise not_found_error( + f"No {resource_type_name} with name '{resource_name}' found{context_msg}" + ) + + if len(matching_resources) > 1: + # Multiple matches found - construct an informative error message + resource_ids = [r.get("id") for r in matching_resources] + context = list_method_kwargs.get("project_id", "") + context_msg = f" in project {context}" if context else "" + + raise duplicate_error( + f"Multiple {resource_type_name}s found with name '{resource_name}'{context_msg}. " + f"{resource_type_name.capitalize()} IDs: {', '.join(resource_ids)}. " + f"Please use get_{resource_type_name}() with a specific ID instead." + ) + + # Exactly one match found - retrieve full details + if "project_id" in list_method_kwargs: + return await get_method(list_method_kwargs["project_id"], matching_resources[0].get("id")) + else: + return await get_method(matching_resources[0].get("id")) + +# %% ../../nbs/backends/ragas_api_client.ipynb 8 #---- Projects ---- @patch async def list_projects( @@ -147,6 +242,33 @@ async def delete_project(self: RagasApiClient, project_id: str) -> None: # %% ../../nbs/backends/ragas_api_client.ipynb 13 +@patch +async def get_project_by_name( + self: RagasApiClient, project_name: str +) -> t.Dict: + """Get a project by its name. 
+ + Args: + project_name: Name of the project to find + + Returns: + The project information dictionary + + Raises: + ProjectNotFoundError: If no project with the given name is found + DuplicateProjectError: If multiple projects with the given name are found + """ + return await self._get_resource_by_name( + list_method=self.list_projects, + get_method=self.get_project, + resource_name=project_name, + name_field="title", # Projects use 'title' instead of 'name' + not_found_error=ProjectNotFoundError, + duplicate_error=DuplicateProjectError, + resource_type_name="project" + ) + +# %% ../../nbs/backends/ragas_api_client.ipynb 16 #---- Datasets ---- @patch async def list_datasets( @@ -201,7 +323,36 @@ async def delete_dataset(self: RagasApiClient, project_id: str, dataset_id: str) """Delete a dataset.""" await self._delete_resource(f"projects/{project_id}/datasets/{dataset_id}") -# %% ../../nbs/backends/ragas_api_client.ipynb 20 +# %% ../../nbs/backends/ragas_api_client.ipynb 23 +@patch +async def get_dataset_by_name( + self: RagasApiClient, project_id: str, dataset_name: str +) -> t.Dict: + """Get a dataset by its name. + + Args: + project_id: ID of the project + dataset_name: Name of the dataset to find + + Returns: + The dataset information dictionary + + Raises: + DatasetNotFoundError: If no dataset with the given name is found + DuplicateDatasetError: If multiple datasets with the given name are found + """ + return await self._get_resource_by_name( + list_method=self.list_datasets, + get_method=self.get_dataset, + resource_name=dataset_name, + name_field="name", + not_found_error=DatasetNotFoundError, + duplicate_error=DuplicateDatasetError, + resource_type_name="dataset", + project_id=project_id + ) + +# %% ../../nbs/backends/ragas_api_client.ipynb 26 #---- Experiments ---- @patch async def list_experiments( @@ -257,10 +408,39 @@ async def delete_experiment(self: RagasApiClient, project_id: str, experiment_id await self._delete_resource(f"projects/{project_id}/experiments/{experiment_id}") -# %% ../../nbs/backends/ragas_api_client.ipynb 25 +# %% ../../nbs/backends/ragas_api_client.ipynb 29 +@patch +async def get_experiment_by_name( + self: RagasApiClient, project_id: str, experiment_name: str +) -> t.Dict: + """Get an experiment by its name. 
+ + Args: + project_id: ID of the project containing the experiment + experiment_name: Name of the experiment to find + + Returns: + The experiment information dictionary + + Raises: + ExperimentNotFoundError: If no experiment with the given name is found + DuplicateExperimentError: If multiple experiments with the given name are found + """ + return await self._get_resource_by_name( + list_method=self.list_experiments, + get_method=self.get_experiment, + resource_name=experiment_name, + name_field="name", + not_found_error=ExperimentNotFoundError, + duplicate_error=DuplicateExperimentError, + resource_type_name="experiment", + project_id=project_id + ) + +# %% ../../nbs/backends/ragas_api_client.ipynb 33 from ..typing import ColumnType -# %% ../../nbs/backends/ragas_api_client.ipynb 26 +# %% ../../nbs/backends/ragas_api_client.ipynb 34 #---- Dataset Columns ---- @patch async def list_dataset_columns( @@ -331,7 +511,7 @@ async def delete_dataset_column( f"projects/{project_id}/datasets/{dataset_id}/columns/{column_id}" ) -# %% ../../nbs/backends/ragas_api_client.ipynb 34 +# %% ../../nbs/backends/ragas_api_client.ipynb 42 #---- Dataset Rows ---- @patch async def list_dataset_rows( @@ -393,11 +573,11 @@ async def delete_dataset_row( ) -# %% ../../nbs/backends/ragas_api_client.ipynb 46 +# %% ../../nbs/backends/ragas_api_client.ipynb 54 import uuid import string -# %% ../../nbs/backends/ragas_api_client.ipynb 47 +# %% ../../nbs/backends/ragas_api_client.ipynb 55 def create_nano_id(size=12): # Define characters to use (alphanumeric) alphabet = string.ascii_letters + string.digits @@ -414,11 +594,11 @@ def create_nano_id(size=12): # Pad if necessary and return desired length return result[:size] -# %% ../../nbs/backends/ragas_api_client.ipynb 49 +# %% ../../nbs/backends/ragas_api_client.ipynb 57 import uuid import string -# %% ../../nbs/backends/ragas_api_client.ipynb 50 +# %% ../../nbs/backends/ragas_api_client.ipynb 58 def create_nano_id(size=12): # Define characters to use (alphanumeric) alphabet = string.ascii_letters + string.digits @@ -435,7 +615,7 @@ def create_nano_id(size=12): # Pad if necessary and return desired length return result[:size] -# %% ../../nbs/backends/ragas_api_client.ipynb 52 +# %% ../../nbs/backends/ragas_api_client.ipynb 60 # Default settings for columns DEFAULT_SETTINGS = { "is_required": False, @@ -458,7 +638,7 @@ class Row(BaseModel): id: str = Field(default_factory=create_nano_id) data: t.List[RowCell] = Field(...) 
-# %% ../../nbs/backends/ragas_api_client.ipynb 53
+# %% ../../nbs/backends/ragas_api_client.ipynb 61
 #---- Resource With Data Helper Methods ----
 @patch
 async def _create_with_data(
@@ -585,7 +765,7 @@ async def create_dataset_with_data(
         "dataset", project_id, name, description, columns, rows, batch_size
     )
 
-# %% ../../nbs/backends/ragas_api_client.ipynb 59
+# %% ../../nbs/backends/ragas_api_client.ipynb 67
 #---- Experiment Columns ----
 @patch
 async def list_experiment_columns(
@@ -716,7 +896,7 @@ async def delete_experiment_row(
         f"projects/{project_id}/experiments/{experiment_id}/rows/{row_id}"
     )
 
-# %% ../../nbs/backends/ragas_api_client.ipynb 62
+# %% ../../nbs/backends/ragas_api_client.ipynb 70
 @patch
 async def create_experiment_with_data(
     self: RagasApiClient,
@@ -747,7 +927,7 @@ async def create_experiment_with_data(
         "experiment", project_id, name, description, columns, rows, batch_size
     )
 
-# %% ../../nbs/backends/ragas_api_client.ipynb 63
+# %% ../../nbs/backends/ragas_api_client.ipynb 71
 #---- Utility Methods ----
 @patch
 def create_column(
diff --git a/ragas_experimental/exceptions.py b/ragas_experimental/exceptions.py
index 051a2f3..d0bd6f8 100644
--- a/ragas_experimental/exceptions.py
+++ b/ragas_experimental/exceptions.py
@@ -3,9 +3,16 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/exceptions.ipynb.
 
 # %% auto 0
-__all__ = ['ValidationError', 'DuplicateError', 'NotFoundError']
+__all__ = ['RagasError', 'ValidationError', 'DuplicateError', 'NotFoundError', 'ResourceNotFoundError', 'ProjectNotFoundError',
+           'DatasetNotFoundError', 'ExperimentNotFoundError', 'DuplicateResourceError', 'DuplicateProjectError',
+           'DuplicateDatasetError', 'DuplicateExperimentError']
 
 # %% ../nbs/exceptions.ipynb 2
+class RagasError(Exception):
+    """Base class for all Ragas-related exceptions."""
+    pass
+
+# %% ../nbs/exceptions.ipynb 3
 class ValidationError(Exception):
     """Raised when field validation fails."""
 
@@ -22,3 +29,37 @@ class NotFoundError(Exception):
     """Raised when an item is not found."""
 
     pass
+
+# %% ../nbs/exceptions.ipynb 4
+class ResourceNotFoundError(RagasError):
+    """Exception raised when a requested resource doesn't exist."""
+    pass
+
+class ProjectNotFoundError(ResourceNotFoundError):
+    """Exception raised when a project doesn't exist."""
+    pass
+
+class DatasetNotFoundError(ResourceNotFoundError):
+    """Exception raised when a dataset doesn't exist."""
+    pass
+
+class ExperimentNotFoundError(ResourceNotFoundError):
+    """Exception raised when an experiment doesn't exist."""
+    pass
+
+# %% ../nbs/exceptions.ipynb 5
+class DuplicateResourceError(RagasError):
+    """Exception raised when multiple resources exist with the same identifier."""
+    pass
+
+class DuplicateProjectError(DuplicateResourceError):
+    """Exception raised when multiple projects exist with the same name."""
+    pass
+
+class DuplicateDatasetError(DuplicateResourceError):
+    """Exception raised when multiple datasets exist with the same name."""
+    pass
+
+class DuplicateExperimentError(DuplicateResourceError):
+    """Exception raised when multiple experiments exist with the same name."""
+    pass
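A sketch of the intended error handling around the new name-based lookups (illustrative only, not part of the diff; the client setup and names shown are assumptions). The specific subclasses let callers distinguish which lookup failed, while RagasError remains a catch-all base.

# Inside an async function, with `client` being an already-configured RagasApiClient
try:
    project = await client.get_project_by_name("Demo Project")
    dataset = await client.get_dataset_by_name(
        project_id=project["id"], dataset_name="test_dataset"
    )
except (ProjectNotFoundError, DatasetNotFoundError) as err:
    print(f"Nothing with that name: {err}")
except DuplicateProjectError as err:
    # The error message lists the candidate IDs; fall back to an ID-based getter
    print(err)
except RagasError:
    raise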
diff --git a/ragas_experimental/project/core.py b/ragas_experimental/project/core.py
index a4a1f27..b892208 100644
--- a/ragas_experimental/project/core.py
+++ b/ragas_experimental/project/core.py
@@ -63,7 +63,27 @@ def delete(self):
 def __repr__(self):
     return f"Project(name='{self.name}')"
 
-# %% ../../nbs/project/core.ipynb 10
+# %% ../../nbs/project/core.ipynb 8
+@patch(cls_method=True)
+def get(cls: Project, name: str, ragas_app_client: t.Optional[RagasApiClient] = None) -> Project:
+    """Get an existing project by name."""
+    # Search for project with given name
+    if ragas_app_client is None:
+        ragas_app_client = RagasApiClientFactory.create()
+
+    # Get the project info by name
+    sync_version = async_to_sync(ragas_app_client.get_project_by_name)
+    project_info = sync_version(
+        project_name=name
+    )
+
+    # Return Project instance
+    return Project(
+        project_id=project_info["id"],
+        ragas_app_client=ragas_app_client,
+    )
+
+# %% ../../nbs/project/core.ipynb 12
 async def create_dataset_columns(project_id, dataset_id, columns, create_dataset_column_func):
     tasks = []
     for column in columns:
@@ -81,7 +101,7 @@ async def create_dataset_columns(project_id, dataset_id, columns, create_dataset
     return await asyncio.gather(*tasks)
 
-# %% ../../nbs/project/core.ipynb 11
+# %% ../../nbs/project/core.ipynb 13
 @patch
 def create_dataset(
     self: Project, model: t.Type[BaseModel], name: t.Optional[str] = None
@@ -121,9 +141,9 @@ def create_dataset(
         ragas_api_client=self._ragas_api_client,
     )
 
-# %% ../../nbs/project/core.ipynb 15
+# %% ../../nbs/project/core.ipynb 17
 @patch
-def get_dataset(self: Project, dataset_id: str, model) -> Dataset:
+def get_dataset_by_id(self: Project, dataset_id: str, model) -> Dataset:
     """Get an existing dataset by ID."""
     # Look up the dataset with the given ID
     sync_version = async_to_sync(self._ragas_api_client.get_dataset)
@@ -140,3 +160,23 @@ def get_dataset(self: Project, dataset_id: str, model) -> Dataset:
         dataset_id=dataset_id,
         ragas_api_client=self._ragas_api_client,
     )
+
+# %% ../../nbs/project/core.ipynb 19
+@patch
+def get_dataset(self: Project, dataset_name: str, model) -> Dataset:
+    """Get an existing dataset by name."""
+    # Search for dataset with given name
+    sync_version = async_to_sync(self._ragas_api_client.get_dataset_by_name)
+    dataset_info = sync_version(
+        project_id=self.project_id,
+        dataset_name=dataset_name
+    )
+
+    # Return Dataset instance
+    return Dataset(
+        name=dataset_info["name"],
+        model=model,
+        project_id=self.project_id,
+        dataset_id=dataset_info["id"],
+        ragas_api_client=self._ragas_api_client,
+    )
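A sketch of the resulting Project-level API (illustrative only, not part of the diff; MyModel and the names are placeholders): Project.get resolves a project by name, get_dataset resolves a dataset by name, and get_dataset_by_id keeps the ID-based path.

# MyModel is a placeholder pydantic.BaseModel describing the dataset schema
project = Project.get("yann-lecun-wisdom")
dataset = project.get_dataset("test_dataset", MyModel)
# ID-based variant, if the dataset id is already known:
# dataset_again = project.get_dataset_by_id("<dataset-id>", MyModel)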
diff --git a/ragas_experimental/project/experiments.py b/ragas_experimental/project/experiments.py
index b39c6f0..83b0169 100644
--- a/ragas_experimental/project/experiments.py
+++ b/ragas_experimental/project/experiments.py
@@ -21,7 +21,7 @@ from ..experiment import Experiment
 import ragas_experimental.typing as rt
 
-# %% ../../nbs/project/experiments.ipynb 3
+# %% ../../nbs/project/experiments.ipynb 4
 @patch
 def create_experiment(
     self: Project, name: str, model: t.Type[BaseModel]
@@ -78,9 +78,9 @@ async def create_experiment_columns(project_id, experiment_id, columns, create_e
         ))
     return await asyncio.gather(*tasks)
 
-# %% ../../nbs/project/experiments.ipynb 7
+# %% ../../nbs/project/experiments.ipynb 8
 @patch
-def get_experiment(self: Project, experiment_id: str, model: t.Type[BaseModel]) -> Experiment:
+def get_experiment_by_id(self: Project, experiment_id: str, model: t.Type[BaseModel]) -> Experiment:
     """Get an existing experiment by ID."""
     # Get experiment info
     sync_version = async_to_sync(self._ragas_api_client.get_experiment)
@@ -98,22 +98,42 @@ def get_experiment(self: Project, experiment_id: str, model: t.Type[BaseModel])
     )
 
 # %% ../../nbs/project/experiments.ipynb 11
+@patch
+def get_experiment(self: Project, experiment_name: str, model) -> Experiment:
+    """Get an existing experiment by name."""
+    # Search for experiment with given name
+    sync_version = async_to_sync(self._ragas_api_client.get_experiment_by_name)
+    exp_info = sync_version(
+        project_id=self.project_id,
+        experiment_name=experiment_name
+    )
+
+    # Return Experiment instance
+    return Experiment(
+        name=exp_info["name"],
+        model=model,
+        project_id=self.project_id,
+        experiment_id=exp_info["id"],
+        ragas_api_client=self._ragas_api_client,
+    )
+
+# %% ../../nbs/project/experiments.ipynb 14
 @t.runtime_checkable
 class ExperimentProtocol(t.Protocol):
     async def __call__(self, *args, **kwargs): ...
     async def run_async(self, name: str, dataset: Dataset): ...
 
-# %% ../../nbs/project/experiments.ipynb 12
+# %% ../../nbs/project/experiments.ipynb 15
 # this one we have to clean up
 from langfuse.decorators import observe
 
-# %% ../../nbs/project/experiments.ipynb 13
+# %% ../../nbs/project/experiments.ipynb 16
 from .naming import MemorableNames
 
-# %% ../../nbs/project/experiments.ipynb 14
+# %% ../../nbs/project/experiments.ipynb 17
 memorable_names = MemorableNames()
 
-# %% ../../nbs/project/experiments.ipynb 15
+# %% ../../nbs/project/experiments.ipynb 18
 @patch
 def experiment(
     self: Project, experiment_model, name_prefix: str = ""
@@ -142,32 +162,57 @@ async def run_async(dataset: Dataset, name: t.Optional[str] = None):
             if name_prefix:
                 name = f"{name_prefix}-{name}"
 
-            # Create tasks for all items
-            tasks = []
-            for item in dataset:
-                tasks.append(wrapped_experiment(item))
-
-            # Use as_completed with tqdm for progress tracking
-            results = []
-            for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
-                result = await future
-                # Add each result to experiment view as it completes
-                if result is not None:
-                    results.append(result)
-
-            # upload results to experiment view
-            experiment_view = self.create_experiment(name=name, model=experiment_model)
-            for result in results:
-                experiment_view.append(result)
-
-            return experiment_view
+            experiment_view = None
+            try:
+                # Create the experiment view upfront
+                experiment_view = self.create_experiment(name=name, model=experiment_model)
+
+                # Create tasks for all items
+                tasks = []
+                for item in dataset:
+                    tasks.append(wrapped_experiment(item))
+
+                # Calculate total operations (processing + appending)
+                total_operations = len(tasks) * 2  # Each item requires processing and appending
+
+                # Use tqdm for combined progress tracking
+                results = []
+                progress_bar = tqdm(total=total_operations, desc="Running experiment")
+
+                # Process all items
+                for future in asyncio.as_completed(tasks):
+                    result = await future
+                    if result is not None:
+                        results.append(result)
+                    progress_bar.update(1)  # Update for task completion
+
+                # Append results to experiment view
+                for result in results:
+                    experiment_view.append(result)
+                    progress_bar.update(1)  # Update for append operation
+
+                progress_bar.close()
+                return experiment_view
+
+            except Exception as e:
+                # Clean up the experiment if there was an error and it was created
+                if experiment_view is not None:
+                    try:
+                        # Delete the partially created experiment via the API client
+                        sync_version = async_to_sync(self._ragas_api_client.delete_experiment)
+                        sync_version(project_id=self.project_id, experiment_id=experiment_view.experiment_id)
+                    except Exception as cleanup_error:
+                        print(f"Failed to clean up experiment after error: {cleanup_error}")
+
+                # Re-raise the original exception
+                raise e
 
         wrapped_experiment.__setattr__("run_async", run_async)
         return t.cast(ExperimentProtocol, wrapped_experiment)
 
     return decorator
 
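A sketch of driving the refactored decorator (illustrative only, not part of the diff; the result model and the system under test are placeholders). run_async creates the experiment view first, tracks processing and appending in one progress bar, and deletes the view again if anything raises.

# ExperimentModel is a placeholder pydantic.BaseModel for per-item results
@project.experiment(ExperimentModel, name_prefix="baseline")
async def baseline_experiment(item):
    answer = await my_app(item.query)  # placeholder call to the system being evaluated
    return ExperimentModel(query=item.query, response=answer)

experiment_view = await baseline_experiment.run_async(dataset)  # dataset: an existing Dataset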
-# %% ../../nbs/project/experiments.ipynb 19
+# %% ../../nbs/project/experiments.ipynb 22
 @patch
 def langfuse_experiment(
     self: Project, experiment_model, name_prefix: str = ""