diff --git a/nbs/embedding/base.ipynb b/nbs/embedding/base.ipynb new file mode 100644 index 0000000..fd23689 --- /dev/null +++ b/nbs/embedding/base.ipynb @@ -0,0 +1,1150 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp embedding.base" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## RagasEmbedding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import typing as t\n", + "from abc import ABC, abstractmethod\n", + "\n", + "#TODO: Add support for other providers like HuggingFace, Cohere, etc.\n", + "#TODO: handle async calls properly and ensure that the client supports async if needed.\n", + "\n", + "class BaseEmbedding(ABC):\n", + " @abstractmethod\n", + " def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:\n", + " pass\n", + " \n", + " @abstractmethod\n", + " async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:\n", + " pass\n", + " \n", + " @abstractmethod\n", + " def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:\n", + " pass\n", + " \n", + " @abstractmethod\n", + " async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:\n", + " pass\n", + "\n", + "\n", + "class OpenAIEmbeddings(BaseEmbedding):\n", + " def __init__(self, client: t.Any, model: str):\n", + " self.client = client\n", + " self.model = model\n", + " \n", + " def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:\n", + " return self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding\n", + " \n", + " async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:\n", + " response = await self.client.embeddings.create(input=text, model=self.model, **kwargs)\n", + " return response.data[0].embedding\n", + " \n", + " def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:\n", + " embeddings = self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n", + " return [embedding.embedding for embedding in embeddings.data]\n", + " \n", + " async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:\n", + " embeddings = await self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n", + " return [embedding.embedding for embedding in embeddings.data]\n", + " \n", + " \n", + "def ragas_embedding(provider: str, model: str, client: t.Any) -> BaseEmbedding:\n", + " \"\"\"\n", + " Factory function to create an embedding instance based on the provider.\n", + " \n", + " Args:\n", + " provider (str): The name of the embedding provider (e.g., \"openai\").\n", + " model (str): The model name to use for embeddings.\n", + " client (Any): A pre-initialized client object for the chosen provider.\n", + " \n", + " Returns:\n", + " BaseEmbedding: An instance of the specified embedding provider.\n", + " \n", + " Raises:\n", + " ValueError: If the provider is not supported.\n", + " \"\"\"\n", + " if provider.lower() == \"openai\":\n", + " return OpenAIEmbeddings(client=client, model=model)\n", + " \n", + " raise ValueError(f\"Unsupported provider: {provider}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example Usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.019184619188308716,\n", + " -0.025279032066464424,\n", + " -0.0017195191467180848,\n", + " 0.01884828321635723,\n", + " -0.033795066177845,\n", + " -0.01969585195183754,\n", + " -0.02094702236354351,\n", + " 0.051580529659986496,\n", + " -0.03212684020400047,\n", + " -0.030377890914678574,\n", + " -0.002145825419574976,\n", + " -0.028978731483221054,\n", + " -0.0024737531784921885,\n", + " -0.031481072306632996,\n", + " 0.010332250036299229,\n", + " 0.018606122583150864,\n", + " -0.04614533483982086,\n", + " 0.04146353527903557,\n", + " 0.0004418617463670671,\n", + " 0.04122137278318405,\n", + " 0.05367926508188248,\n", + " 0.0018733929609879851,\n", + " 0.0045674461871385574,\n", + " 0.010022819973528385,\n", + " 0.04786737635731697,\n", + " 0.0022013208363205194,\n", + " -0.009834472090005875,\n", + " 0.03847686946392059,\n", + " 0.00089213193859905,\n", + " -0.05211866647005081,\n", + " 0.051150016486644745,\n", + " -0.032557349652051926,\n", + " -0.014031948521733284,\n", + " -0.012632790021598339,\n", + " 0.013271828182041645,\n", + " 0.018565760925412178,\n", + " 0.0016068464610725641,\n", + " -0.0008185583865270019,\n", + " -0.012753871269524097,\n", + " -0.029705218970775604,\n", + " -0.004443001933395863,\n", + " -0.015323479659855366,\n", + " 0.025655729696154594,\n", + " 0.009107985533773899,\n", + " -0.03686245530843735,\n", + " 0.020328164100646973,\n", + " -0.04071014001965523,\n", + " -0.002621741034090519,\n", + " 0.03549019992351532,\n", + " 0.04851314052939415,\n", + " -0.03368743881583214,\n", + " -0.002441801130771637,\n", + " 0.017260776832699776,\n", + " 0.07598508894443512,\n", + " 0.0009232430020347238,\n", + " -0.04267434403300285,\n", + " 0.008381499908864498,\n", + " 0.0760388970375061,\n", + " -0.047275424003601074,\n", + " 0.015081318095326424,\n", + " 0.014247204177081585,\n", + " 0.024700535461306572,\n", + " 0.010197714902460575,\n", + " -0.000978738535195589,\n", + " 0.013789786025881767,\n", + " -0.010103541426360607,\n", + " -0.020704859867691994,\n", + " -0.001531170797534287,\n", + " -0.011717955581843853,\n", + " 0.04934725537896156,\n", + " 0.0010939337080344558,\n", + " 0.037831101566553116,\n", + " -0.019332608208060265,\n", + " 0.005855614319443703,\n", + " -0.046279869973659515,\n", + " -0.0045439028181135654,\n", + " -0.022359633818268776,\n", + " 0.008751469664275646,\n", + " -0.02657056413590908,\n", + " -0.05440575256943703,\n", + " -0.04423494264483452,\n", + " 0.019332608208060265,\n", + " -0.03091602772474289,\n", + " -0.06037908419966698,\n", + " -0.018888644874095917,\n", + " 0.004372371360659599,\n", + " -0.02389332838356495,\n", + " -0.012027384713292122,\n", + " -0.016601556912064552,\n", + " 0.0022013208363205194,\n", + " -0.00802498310804367,\n", + " 0.01529657281935215,\n", + " -0.014960236847400665,\n", + " 0.01245789509266615,\n", + " 0.014502819627523422,\n", + " -0.027687201276421547,\n", + " -0.022790145128965378,\n", + " 0.05666593089699745,\n", + " 0.061024848371744156,\n", + " -0.04929343983530998,\n", + " 0.014610446989536285,\n", + " -0.027323957532644272,\n", + " 0.013251648284494877,\n", + " -0.0205434188246727,\n", + " 0.0298666600137949,\n", + " 0.022507622838020325,\n", + " 0.00819987803697586,\n", + " -0.04068323224782944,\n", + " -0.026584018021821976,\n", + " 0.004533812869340181,\n", + " -0.12474039196968079,\n", + " 0.009417415596544743,\n", + " 0.031803958117961884,\n", + " -0.031077470630407333,\n", + " 0.005801800638437271,\n", + " 
0.030835308134555817,\n", + " 0.05367926508188248,\n", + " -0.039553143084049225,\n", + " 0.02342245727777481,\n", + " -0.05375998839735985,\n", + " 0.00868420209735632,\n", + " -0.01152287982404232,\n", + " 0.019534409046173096,\n", + " -0.04184022918343544,\n", + " -0.043131761252880096,\n", + " -0.04297031834721565,\n", + " 0.005852250847965479,\n", + " 0.057526953518390656,\n", + " -0.031481072306632996,\n", + " 0.019911106675863266,\n", + " 0.03944551572203636,\n", + " 0.03982221335172653,\n", + " 0.01127399131655693,\n", + " -0.0002850449818652123,\n", + " -0.045553382486104965,\n", + " 0.0018666662508621812,\n", + " -0.040656328201293945,\n", + " -0.013446723110973835,\n", + " -0.049105092883110046,\n", + " 0.047275424003601074,\n", + " 0.056450676172971725,\n", + " -0.047248516231775284,\n", + " -0.010890567675232887,\n", + " -0.00996228028088808,\n", + " -0.005926244892179966,\n", + " -0.04119446501135826,\n", + " -0.008791829459369183,\n", + " 0.026086239144206047,\n", + " -0.009948826394975185,\n", + " -0.00625585438683629,\n", + " 0.030377890914678574,\n", + " 0.060648154467344284,\n", + " -0.051230739802122116,\n", + " 0.025776810944080353,\n", + " 0.00377705623395741,\n", + " -0.002621741034090519,\n", + " 0.024512186646461487,\n", + " -0.016816813498735428,\n", + " -0.02782173454761505,\n", + " 0.015054411254823208,\n", + " 0.05510533228516579,\n", + " 0.039580050855875015,\n", + " -0.04436947777867317,\n", + " -0.007897174917161465,\n", + " -0.008146064355969429,\n", + " 0.00850930716842413,\n", + " -0.011744862422347069,\n", + " 0.002426665974780917,\n", + " -0.04361608624458313,\n", + " -0.002248407807201147,\n", + " 0.023974047973752022,\n", + " 0.020933568477630615,\n", + " -0.0211219172924757,\n", + " -0.04509596526622772,\n", + " -0.0192249808460474,\n", + " 0.02634185552597046,\n", + " 0.023449363186955452,\n", + " -0.04958941787481308,\n", + " -0.01622486114501953,\n", + " -0.025238672271370888,\n", + " 0.02852131426334381,\n", + " 0.04541884735226631,\n", + " 0.0022921315394341946,\n", + " 0.019090445712208748,\n", + " -0.026584018021821976,\n", + " -0.011179816909134388,\n", + " -0.004473272245377302,\n", + " -0.006804082542657852,\n", + " -0.011913030408322811,\n", + " 0.0008563962182961404,\n", + " -0.03298785910010338,\n", + " 0.056235421448946,\n", + " 0.023476270958781242,\n", + " 0.0019675670191645622,\n", + " 0.004510269034653902,\n", + " -0.03659338504076004,\n", + " 0.0669981837272644,\n", + " 0.00536792678758502,\n", + " -0.021565880626440048,\n", + " 0.02427002415060997,\n", + " -0.00038993984344415367,\n", + " 0.012706783600151539,\n", + " -0.05136527121067047,\n", + " -0.031884677708148956,\n", + " -0.02342245727777481,\n", + " -0.04186713695526123,\n", + " -1.4254876077757217e-05,\n", + " 0.07087277621030807,\n", + " -0.00837477296590805,\n", + " -0.05246845632791519,\n", + " 0.058603230863809586,\n", + " -0.014677714556455612,\n", + " -0.0541904978454113,\n", + " -0.0020482877735048532,\n", + " -0.04932034760713577,\n", + " -0.017879635095596313,\n", + " 0.041275184601545334,\n", + " 0.02229236625134945,\n", + " -0.011226904578506947,\n", + " -0.03161560744047165,\n", + " -0.07937535643577576,\n", + " 0.07157235592603683,\n", + " 0.08513343334197998,\n", + " -0.04122137278318405,\n", + " 0.030889121815562248,\n", + " -0.013339095748960972,\n", + " -0.008536214008927345,\n", + " -0.008213330991566181,\n", + " 0.04996611550450325,\n", + " 0.01458354014903307,\n", + " 0.020879754796624184,\n", + " 0.01826978474855423,\n", + " 0.02429693192243576,\n", + 
" -0.021431345492601395,\n", + " -0.010500418022274971,\n", + " -0.004325284156948328,\n", + " 0.036727920174598694,\n", + " -0.021350625902414322,\n", + " -0.005657176021486521,\n", + " -0.0071572354063391685,\n", + " -0.0387459360063076,\n", + " -0.0011199996806681156,\n", + " -0.006037235725671053,\n", + " 0.034252483397722244,\n", + " 0.04563410207629204,\n", + " -0.016103779897093773,\n", + " -0.042728159576654434,\n", + " -0.022413447499275208,\n", + " 0.011119276285171509,\n", + " 0.04076395556330681,\n", + " 0.017960356548428535,\n", + " 0.02724323607981205,\n", + " 0.005418376997113228,\n", + " -0.02036852389574051,\n", + " 0.017166603356599808,\n", + " -0.01021116878837347,\n", + " 0.006659457925707102,\n", + " -0.027458492666482925,\n", + " 0.042728159576654434,\n", + " -0.02106810361146927,\n", + " -0.048728395253419876,\n", + " -0.062101125717163086,\n", + " -0.035301852971315384,\n", + " -0.02779482863843441,\n", + " 0.012632790021598339,\n", + " -0.027404678985476494,\n", + " 0.004089849069714546,\n", + " -0.013897414319217205,\n", + " -0.016615010797977448,\n", + " -0.013164200820028782,\n", + " 0.04385824874043465,\n", + " -0.0075810193084180355,\n", + " 0.03266497701406479,\n", + " -0.004355554468929768,\n", + " -0.025803716853260994,\n", + " 0.0032876869663596153,\n", + " -0.005179578438401222,\n", + " -0.017328044399619102,\n", + " -0.01981693133711815,\n", + " 0.0369969867169857,\n", + " -0.025763357058167458,\n", + " -0.0014664260670542717,\n", + " 0.010513870976865292,\n", + " 0.033983416855335236,\n", + " -0.05131145939230919,\n", + " 0.008832190185785294,\n", + " 0.027081795036792755,\n", + " -0.01144888624548912,\n", + " 0.007722280453890562,\n", + " -0.02479470893740654,\n", + " 0.03277260437607765,\n", + " 0.02774101495742798,\n", + " 0.016278674826025963,\n", + " -0.02039542980492115,\n", + " 0.025911344215273857,\n", + " -0.002879038453102112,\n", + " -0.0013175972271710634,\n", + " -0.041651882231235504,\n", + " 0.038153983652591705,\n", + " 0.0025460654869675636,\n", + " 0.07695373892784119,\n", + " 0.0007592791225761175,\n", + " 0.04294341430068016,\n", + " -0.005845523905009031,\n", + " -0.001709428965114057,\n", + " 0.04154425486922264,\n", + " 0.015901979058980942,\n", + " -0.01701861433684826,\n", + " 0.05951806530356407,\n", + " -0.0013714110245928168,\n", + " -0.008959997445344925,\n", + " 0.009585583582520485,\n", + " 0.05666593089699745,\n", + " -0.02784864231944084,\n", + " 0.01347362995147705,\n", + " -0.045849356800317764,\n", + " 0.019857292994856834,\n", + " -0.019332608208060265,\n", + " 0.0009694892796687782,\n", + " -0.04003746807575226,\n", + " 0.023449363186955452,\n", + " -0.06199349835515022,\n", + " 0.009477955289185047,\n", + " -0.015713630244135857,\n", + " -0.015162038616836071,\n", + " -0.00862366147339344,\n", + " 0.045553382486104965,\n", + " 0.021538974717259407,\n", + " 0.0020180174615234137,\n", + " 0.013756153173744678,\n", + " 0.014664260670542717,\n", + " -0.02706834115087986,\n", + " -0.004664984066039324,\n", + " 0.010830027051270008,\n", + " 0.007224502973258495,\n", + " -0.016951346769928932,\n", + " -0.04372371360659599,\n", + " 0.05427121743559837,\n", + " 0.012767324224114418,\n", + " 0.04579554498195648,\n", + " -0.02657056413590908,\n", + " -0.027902456000447273,\n", + " 0.02179458923637867,\n", + " -0.03651266545057297,\n", + " -0.011987023986876011,\n", + " -0.0041941129602491856,\n", + " 0.033929601311683655,\n", + " -0.02712215483188629,\n", + " 0.004288287367671728,\n", + " 0.004399278201162815,\n", + " 
-0.017381858080625534,\n", + " -0.005243482068181038,\n", + " 0.016413209959864616,\n", + " -0.02464671991765499,\n", + " -0.01762402057647705,\n", + " -0.009868105873465538,\n", + " 0.0716799795627594,\n", + " -0.024727441370487213,\n", + " -0.019534409046173096,\n", + " 0.021256450563669205,\n", + " -0.006609007250517607,\n", + " -0.006915073376148939,\n", + " 0.00413020933046937,\n", + " -0.01210810523480177,\n", + " 0.03384888172149658,\n", + " 0.030431704595685005,\n", + " -0.007258136291056871,\n", + " -0.04081776738166809,\n", + " -0.007345583755522966,\n", + " 0.04385824874043465,\n", + " 0.013298735953867435,\n", + " 0.01475843507796526,\n", + " 0.032153744250535965,\n", + " -0.0036324316170066595,\n", + " -0.03479062393307686,\n", + " -0.015175491571426392,\n", + " 0.0117986761033535,\n", + " -0.00017373869195580482,\n", + " 0.059625692665576935,\n", + " -0.009249246679246426,\n", + " 0.04036035016179085,\n", + " 0.03371434658765793,\n", + " -0.019736211746931076,\n", + " -0.026610923931002617,\n", + " 0.010325523093342781,\n", + " -0.005855614319443703,\n", + " 0.0206914059817791,\n", + " 0.011381618678569794,\n", + " -0.01701861433684826,\n", + " 0.008576574735343456,\n", + " 0.03352599963545799,\n", + " -0.011563240550458431,\n", + " 0.004426185041666031,\n", + " 0.00951158907264471,\n", + " 0.007809727918356657,\n", + " -0.01757020689547062,\n", + " -0.021808043122291565,\n", + " -0.015188945457339287,\n", + " -0.022682517766952515,\n", + " -0.05763458088040352,\n", + " 0.04716779664158821,\n", + " -0.023664619773626328,\n", + " 0.007527205627411604,\n", + " 0.011401799507439137,\n", + " -0.02022053487598896,\n", + " -0.03347218409180641,\n", + " 0.012229186482727528,\n", + " 0.05112311244010925,\n", + " -0.0036391583271324635,\n", + " -0.023503176867961884,\n", + " 0.004083122126758099,\n", + " -0.052280109375715256,\n", + " 0.033956509083509445,\n", + " 0.03191158547997475,\n", + " -0.025036871433258057,\n", + " 0.00199615559540689,\n", + " -0.023261016234755516,\n", + " -0.03928407281637192,\n", + " -0.0007407806115224957,\n", + " -0.0041201189160346985,\n", + " 0.00614150008186698,\n", + " 0.019036632031202316,\n", + " -0.014153029769659042,\n", + " 0.025911344215273857,\n", + " -0.032557349652051926,\n", + " 0.04006437584757805,\n", + " 0.03062005341053009,\n", + " -0.028063897043466568,\n", + " 0.0187944695353508,\n", + " -0.08260418474674225,\n", + " -0.0015959155280143023,\n", + " -0.03573236241936684,\n", + " -0.00360216130502522,\n", + " 0.03624359518289566,\n", + " 0.02631494775414467,\n", + " -0.04617224261164665,\n", + " 0.002162642078474164,\n", + " -0.006302941590547562,\n", + " 0.058603230863809586,\n", + " 0.02322065457701683,\n", + " -0.0025494287256151438,\n", + " 0.009013812057673931,\n", + " 0.008832190185785294,\n", + " 0.0022988582495599985,\n", + " -0.009350148029625416,\n", + " -0.05384070798754692,\n", + " -0.003153152298182249,\n", + " -0.013857053592801094,\n", + " -0.040548697113990784,\n", + " 0.017812367528676987,\n", + " 0.0035248040221631527,\n", + " -0.04358917847275734,\n", + " 0.013177654705941677,\n", + " 0.013978134840726852,\n", + " 0.03134653717279434,\n", + " 0.015175491571426392,\n", + " -0.0002869368763640523,\n", + " 0.01687062717974186,\n", + " 0.01992456056177616,\n", + " 0.026449482887983322,\n", + " -0.0039048639591783285,\n", + " 0.0231668408960104,\n", + " -0.04773284122347832,\n", + " 0.052172478288412094,\n", + " 0.006410568952560425,\n", + " -0.0035718909930437803,\n", + " -0.02284395880997181,\n", + " 
0.023328281939029694,\n", + " -0.016305582597851753,\n", + " -0.02229236625134945,\n", + " -0.012525161728262901,\n", + " 0.025077231228351593,\n", + " 0.008226784877479076,\n", + " -0.023758793249726295,\n", + " -0.020314710214734077,\n", + " -0.018202519044280052,\n", + " -0.05445956811308861,\n", + " 0.01547146774828434,\n", + " -0.044154223054647446,\n", + " 0.0001709008647594601,\n", + " 0.027525758370757103,\n", + " 0.007002520840615034,\n", + " 0.04143662750720978,\n", + " 0.02919398620724678,\n", + " -0.003316275542601943,\n", + " 0.009773931466042995,\n", + " -0.07211049646139145,\n", + " 0.026732005178928375,\n", + " -0.004042761866003275,\n", + " -0.010231348685920238,\n", + " -0.034333206713199615,\n", + " 0.06193968653678894,\n", + " 0.0640922337770462,\n", + " -0.015484921634197235,\n", + " -0.009706663899123669,\n", + " -0.008280598558485508,\n", + " 0.005670629441738129,\n", + " -0.013251648284494877,\n", + " -0.002973212394863367,\n", + " -0.02879038266837597,\n", + " -0.007143781986087561,\n", + " -0.04157116264104843,\n", + " -0.0066998181864619255,\n", + " 0.01987074688076973,\n", + " 0.06199349835515022,\n", + " -0.006968887057155371,\n", + " -0.04687182232737541,\n", + " -0.014193389564752579,\n", + " 0.007399397436529398,\n", + " -0.03374125435948372,\n", + " -0.043481551110744476,\n", + " -0.008139337413012981,\n", + " 0.007634832989424467,\n", + " -0.005532731302082539,\n", + " 0.012087925337255001,\n", + " -0.003134653903543949,\n", + " 0.009518316015601158,\n", + " 0.028252245858311653,\n", + " -0.012000477872788906,\n", + " -0.030835308134555817,\n", + " 0.026624377816915512,\n", + " 0.032557349652051926,\n", + " -0.006575373932719231,\n", + " -0.00798462238162756,\n", + " -0.0033515908289700747,\n", + " 0.019386421889066696,\n", + " -0.05160743370652199,\n", + " -0.022104019299149513,\n", + " 0.008516034111380577,\n", + " 0.027875548228621483,\n", + " 0.019628584384918213,\n", + " 0.004991230089217424,\n", + " 0.028655849397182465,\n", + " 0.01359471119940281,\n", + " -0.007782821077853441,\n", + " -0.01109909638762474,\n", + " -0.0005763962399214506,\n", + " 0.011953390203416348,\n", + " -0.004738977644592524,\n", + " -0.022790145128965378,\n", + " 0.007096694782376289,\n", + " 0.02948996238410473,\n", + " -0.006481199525296688,\n", + " -0.0007987986318767071,\n", + " -0.011475793085992336,\n", + " -0.00785008817911148,\n", + " 0.04687182232737541,\n", + " 0.006397115532308817,\n", + " -0.002424984471872449,\n", + " 0.025157952681183815,\n", + " 0.00809897668659687,\n", + " -0.016332488507032394,\n", + " -0.013897414319217205,\n", + " -0.012081198394298553,\n", + " 0.03387578949332237,\n", + " 0.0027613206766545773,\n", + " -0.02149861305952072,\n", + " -0.006656094454228878,\n", + " 0.015148584730923176,\n", + " 0.06586809456348419,\n", + " 0.004765884950757027,\n", + " -0.010439877398312092,\n", + " 0.013762879185378551,\n", + " 0.027956269681453705,\n", + " -9.002249862533063e-05,\n", + " 0.03177705034613609,\n", + " 0.007190869189798832,\n", + " -0.0212699044495821,\n", + " -0.03772347420454025,\n", + " -0.038530681282281876,\n", + " -0.03616287559270859,\n", + " -0.024014407768845558,\n", + " -0.026032425463199615,\n", + " -0.06387697905302048,\n", + " 0.021175730973482132,\n", + " -0.007587745785713196,\n", + " 0.033929601311683655,\n", + " 0.026355309411883354,\n", + " 0.0013167564757168293,\n", + " -0.004880239255726337,\n", + " -0.004715434275567532,\n", + " -0.0167495459318161,\n", + " -0.0015866663306951523,\n", + " 0.029705218970775604,\n", + " 
-0.04119446501135826,\n", + " 0.048755303025245667,\n", + " 0.02182149700820446,\n", + " 0.014368284493684769,\n", + " 0.024700535461306572,\n", + " -0.032207559794187546,\n", + " 0.012188825756311417,\n", + " 0.003978857770562172,\n", + " 0.009249246679246426,\n", + " 0.04264743626117706,\n", + " 0.0012848045444115996,\n", + " -0.0352480411529541,\n", + " -0.018000716343522072,\n", + " -0.02034161612391472,\n", + " -0.029382335022091866,\n", + " 0.03702389448881149,\n", + " 0.011785222217440605,\n", + " 0.006400479003787041,\n", + " -0.022238552570343018,\n", + " -0.04845932871103287,\n", + " 0.027552666142582893,\n", + " -0.014166482724249363,\n", + " -0.01102510280907154,\n", + " -0.0018464860040694475,\n", + " 0.0025527921970933676,\n", + " -0.04958941787481308,\n", + " -0.024956149980425835,\n", + " 0.03772347420454025,\n", + " -0.021565880626440048,\n", + " -0.05410977825522423,\n", + " -0.004147026222199202,\n", + " 0.03053933195769787,\n", + " -0.011354711838066578,\n", + " 0.011778495274484158,\n", + " -0.015202398411929607,\n", + " -0.021888762712478638,\n", + " -0.008253691717982292,\n", + " -0.042378369718790054,\n", + " 0.0026671465020626783,\n", + " 0.028225338086485863,\n", + " -0.00250906846486032,\n", + " 0.016789905726909637,\n", + " -0.018606122583150864,\n", + " 0.0023072666954249144,\n", + " -0.02369152568280697,\n", + " 0.01987074688076973,\n", + " 0.012901858426630497,\n", + " 0.014960236847400665,\n", + " 0.0059800585731863976,\n", + " -0.0016825221246108413,\n", + " -0.006575373932719231,\n", + " -0.005008046980947256,\n", + " -0.008657295256853104,\n", + " -0.01654774323105812,\n", + " 0.00396204087883234,\n", + " -0.02334173582494259,\n", + " 0.04958941787481308,\n", + " 0.020852847024798393,\n", + " 0.0028454046696424484,\n", + " -0.01757020689547062,\n", + " 0.05203794687986374,\n", + " 0.014260657131671906,\n", + " 0.013083480298519135,\n", + " 0.03137344494462013,\n", + " 0.009531769901514053,\n", + " -0.013339095748960972,\n", + " 0.026705099269747734,\n", + " 0.004022581502795219,\n", + " 0.0033717709593474865,\n", + " 0.0017573569202795625,\n", + " 0.012908585369586945,\n", + " -0.020489605143666267,\n", + " -0.028117710724473,\n", + " -0.01844467967748642,\n", + " -0.021027741953730583,\n", + " 0.02234617993235588,\n", + " -0.004634713754057884,\n", + " 0.07496262341737747,\n", + " -0.016278674826025963,\n", + " -0.006239037495106459,\n", + " -0.009074351750314236,\n", + " 0.010049727745354176,\n", + " 0.019467143341898918,\n", + " 0.014193389564752579,\n", + " -0.008072069846093655,\n", + " -0.019561316817998886,\n", + " 0.00862366147339344,\n", + " -0.014314470812678337,\n", + " 0.04251290112733841,\n", + " 0.0033566358033567667,\n", + " 0.03659338504076004,\n", + " 0.0019103899830952287,\n", + " -0.030108822509646416,\n", + " -0.007305223494768143,\n", + " 0.0018733929609879851,\n", + " -0.024431465193629265,\n", + " 0.01335927564650774,\n", + " 0.006326484959572554,\n", + " -0.04105992987751961,\n", + " -0.03629740700125694,\n", + " -0.0020953749772161245,\n", + " 0.028924917802214622,\n", + " 0.029785938560962677,\n", + " 0.01069549284875393,\n", + " -0.003615614725276828,\n", + " -0.0005154352984391153,\n", + " -0.02922089397907257,\n", + " -0.021808043122291565,\n", + " -0.0036324316170066595,\n", + " 0.04243218153715134,\n", + " -0.010480238124728203,\n", + " -0.03156179562211037,\n", + " 0.022709423676133156,\n", + " 0.004443001933395863,\n", + " -0.01286149863153696,\n", + " -0.03826161101460457,\n", + " 0.024660173803567886,\n", + " 
-0.011004921980202198,\n", + " -0.006393752060830593,\n", + " 0.02114882320165634,\n", + " 0.026906900107860565,\n", + " -0.023462817072868347,\n", + " -0.024135489016771317,\n", + " 0.03446773812174797,\n", + " 0.028036991134285927,\n", + " 0.014341377653181553,\n", + " -0.04700635373592377,\n", + " 0.005378016736358404,\n", + " -0.02914017252624035,\n", + " 0.0093232411891222,\n", + " -0.05881848558783531,\n", + " -0.0029210804495960474,\n", + " -0.029678311198949814,\n", + " -0.060701966285705566,\n", + " -0.006797355599701405,\n", + " 0.002322401851415634,\n", + " -0.034306298941373825,\n", + " 0.0004843242058996111,\n", + " -0.023651165887713432,\n", + " 0.01073585357517004,\n", + " -0.021310264244675636,\n", + " -0.035005878657102585,\n", + " 0.0028050444088876247,\n", + " -0.01596924476325512,\n", + " 0.03126581758260727,\n", + " 0.018256332725286484,\n", + " 0.0285482220351696,\n", + " -0.01844467967748642,\n", + " 0.013688885606825352,\n", + " 0.02581717073917389,\n", + " 0.0167495459318161,\n", + " -0.0010073271114379168,\n", + " -0.023826060816645622,\n", + " -0.01404540240764618,\n", + " 0.015054411254823208,\n", + " -0.01493333000689745,\n", + " -0.022978492081165314,\n", + " 0.02494269609451294,\n", + " 0.04407350346446037,\n", + " 0.022938132286071777,\n", + " -0.016655370593070984,\n", + " 0.012807684950530529,\n", + " 0.001075435196980834,\n", + " 0.001704383990727365,\n", + " -0.016386302188038826,\n", + " -7.651649502804503e-05,\n", + " 0.011771769262850285,\n", + " 0.01046005729585886,\n", + " -0.028575127944350243,\n", + " -0.003598797833546996,\n", + " 0.004406005144119263,\n", + " -0.012377174571156502,\n", + " 0.017704740166664124,\n", + " -0.0015740536618977785,\n", + " -0.017112787812948227,\n", + " 0.021565880626440048,\n", + " -0.01887519098818302,\n", + " 0.030862214043736458,\n", + " 0.00434210104867816,\n", + " 0.05147290229797363,\n", + " -0.020449243485927582,\n", + " 0.006454292684793472,\n", + " 0.011926483362913132,\n", + " 0.0012721918756142259,\n", + " -0.001787627232261002,\n", + " 0.003323002252727747,\n", + " 0.04606461524963379,\n", + " -0.003995674662292004,\n", + " 0.01133453194051981,\n", + " 0.0022013208363205194,\n", + " 0.0026419213972985744,\n", + " 0.0064273858442902565,\n", + " -0.04157116264104843,\n", + " 0.022332727909088135,\n", + " -0.042324554175138474,\n", + " -0.018431227654218674,\n", + " -0.006249127443879843,\n", + " 0.009444322437047958,\n", + " -0.024108583107590675,\n", + " -0.0015706903068348765,\n", + " 0.01404540240764618,\n", + " -0.017812367528676987,\n", + " 0.0015967563958838582,\n", + " 0.011516153812408447,\n", + " 0.022211646661162376,\n", + " -0.04229764640331268,\n", + " -0.024175850674510002,\n", + " -0.046279869973659515,\n", + " -0.01168432179838419,\n", + " 0.005357836373150349,\n", + " 0.005263662431389093,\n", + " 0.044907618314027786,\n", + " -0.01824287883937359,\n", + " -0.032207559794187546,\n", + " 0.010641679167747498,\n", + " 0.003783782944083214,\n", + " 0.004570809658616781,\n", + " -0.04751758649945259,\n", + " 0.02071831375360489,\n", + " 0.04009127989411354,\n", + " 0.004762521479278803,\n", + " -0.026678191497921944,\n", + " -0.014395191334187984,\n", + " 0.008838917128741741,\n", + " 0.006434112787246704,\n", + " -0.008267145603895187,\n", + " 0.021525520831346512,\n", + " 0.03406413644552231,\n", + " -0.012101378291845322,\n", + " -0.012356993742287159,\n", + " 0.005690809339284897,\n", + " -0.03982221335172653,\n", + " 0.006400479003787041,\n", + " 0.0035483473911881447,\n", + " 
0.02304575964808464,\n", + " -0.00011897894728463143,\n", + " 0.02071831375360489,\n", + " 0.008327685296535492,\n", + " -0.018552307039499283,\n", + " -0.014206843450665474,\n", + " 0.046898726373910904,\n", + " 0.0218484029173851,\n", + " -0.023974047973752022,\n", + " 0.014287563972175121,\n", + " 0.03376815840601921,\n", + " -0.003514713840559125,\n", + " -0.018565760925412178,\n", + " 0.0023139934055507183,\n", + " -0.006820899434387684,\n", + " -0.006615734193474054,\n", + " 0.006568646989762783,\n", + " 0.02922089397907257,\n", + " 0.00862366147339344,\n", + " -0.01687062717974186,\n", + " -0.03522113338112831,\n", + " -0.010668586008250713,\n", + " 0.0003584083169698715,\n", + " -0.0030942936427891254,\n", + " 0.0010552549501881003,\n", + " -0.0161710474640131,\n", + " 0.02601897343993187,\n", + " -0.008072069846093655,\n", + " 0.021538974717259407,\n", + " -0.02456600032746792,\n", + " -0.0029093085322529078,\n", + " 0.012942219153046608,\n", + " -0.043454643338918686,\n", + " -0.012854771688580513,\n", + " 0.026207320392131805,\n", + " -0.006733451969921589,\n", + " -0.03209993243217468,\n", + " 0.016063420102000237,\n", + " -0.026032425463199615,\n", + " -0.012195552699267864,\n", + " -0.002974894130602479,\n", + " -0.01949404925107956,\n", + " -0.005391470156610012,\n", + " 0.019655490294098854,\n", + " 0.018485041335225105,\n", + " 0.017139695584774017,\n", + " 0.033283837139606476,\n", + " -0.014731528237462044,\n", + " -0.0006108707166276872,\n", + " -0.012377174571156502,\n", + " 0.0495356023311615,\n", + " 0.050100646913051605,\n", + " 0.0015606002416461706,\n", + " -0.00031111104181036353,\n", + " 0.001344504184089601,\n", + " -0.02948996238410473,\n", + " 0.020758673548698425,\n", + " 0.04474617540836334,\n", + " -0.05475554242730141,\n", + " 0.02784864231944084,\n", + " -0.006649367976933718,\n", + " -0.007708827033638954,\n", + " 0.022790145128965378,\n", + " 0.04264743626117706,\n", + " 0.010338976047933102,\n", + " 0.006767085287719965,\n", + " -0.036028340458869934,\n", + " -0.026032425463199615,\n", + " 0.01494678296148777,\n", + " 0.02324756234884262,\n", + " 0.01347362995147705,\n", + " 0.008246964775025845,\n", + " -0.014341377653181553,\n", + " 0.003151470795273781,\n", + " -0.0016581377713009715,\n", + " -0.00967303104698658,\n", + " 0.006259217858314514,\n", + " -0.02124299854040146,\n", + " -0.010675312951207161,\n", + " 0.027202876284718513,\n", + " 0.014879516325891018,\n", + " 0.009269427508115768,\n", + " 0.010675312951207161,\n", + " -0.007937535643577576,\n", + " 0.02121609076857567,\n", + " 0.02779482863843441,\n", + " 0.018135251477360725,\n", + " -0.007567565888166428,\n", + " -0.0042714704759418964,\n", + " -0.002071831375360489,\n", + " -0.006245764438062906,\n", + " 0.0018363959388807416,\n", + " -0.014650807715952396,\n", + " -0.0521455742418766,\n", + " 0.02922089397907257,\n", + " 0.024485278874635696,\n", + " 0.047975003719329834,\n", + " 0.009081078693270683,\n", + " 0.015592548996210098,\n", + " 0.022238552570343018,\n", + " -0.0061784968711435795,\n", + " 0.006121319718658924,\n", + " -0.01894245855510235,\n", + " -0.04353536665439606,\n", + " -0.016937894746661186,\n", + " 0.0056975362822413445,\n", + " -0.004089849069714546,\n", + " -0.009121439419686794,\n", + " -0.032853323966264725,\n", + " 0.0556434690952301,\n", + " 0.006935253739356995,\n", + " -0.017435671761631966,\n", + " 0.029086358845233917,\n", + " 0.029624497517943382,\n", + " -0.016036512330174446,\n", + " 0.01809488981962204,\n", + " 0.007897174917161465,\n", + " 
-0.013453450053930283,\n", + " -0.051580529659986496,\n", + " 0.030512424185872078,\n", + " 0.0027512304950505495,\n", + " -0.031104376539587975,\n", + " -0.03099674917757511,\n", + " 0.03879975154995918,\n", + " 0.0193729680031538,\n", + " 0.00539819709956646,\n", + " 0.06226256862282753,\n", + " 0.00551255140453577,\n", + " 0.017906542867422104,\n", + " -0.004089849069714546,\n", + " -0.015229305252432823,\n", + " -0.0192249808460474,\n", + " -0.023651165887713432,\n", + " -0.002043242799118161,\n", + " 0.0007563361432403326,\n", + " 0.007587745785713196,\n", + " -0.010830027051270008,\n", + " 0.008246964775025845,\n", + " 0.044127315282821655,\n", + " -0.008919637650251389,\n", + " -0.005472190678119659,\n", + " 0.012404081411659718,\n", + " -0.01666882447898388,\n", + " -0.016426661983132362,\n", + " -0.02474089525640011,\n", + " -0.012195552699267864,\n", + " -0.0016488884575664997,\n", + " -0.004607806913554668,\n", + " -0.01870029605925083,\n", + " -0.013830146752297878,\n", + " 0.009713390842080116,\n", + " 0.015632908791303635,\n", + " -0.0273912250995636,\n", + " 0.0006550148827955127,\n", + " 0.03656647726893425,\n", + " -0.01140852551907301,\n", + " 0.0023745340295135975,\n", + " -0.017287682741880417,\n", + " -0.035328760743141174,\n", + " 0.025884438306093216,\n", + " 0.04052179306745529,\n", + " -0.006302941590547562,\n", + " 0.023624258115887642,\n", + " 0.02266906388103962,\n", + " 0.02584407851099968,\n", + " -0.005145944654941559,\n", + " -0.005293932743370533,\n", + " 0.001347026671282947,\n", + " 0.01459699310362339,\n", + " 0.006010328885167837,\n", + " -0.016184501349925995,\n", + " -0.014475912787020206,\n", + " 0.007305223494768143,\n", + " -0.006706545129418373,\n", + " -0.02092011459171772,\n", + " 0.03452155366539955,\n", + " 0.03976839780807495,\n", + " -0.003048888174816966,\n", + " -0.025938251987099648,\n", + " -0.011354711838066578,\n", + " -0.02129681222140789,\n", + " -0.0167495459318161,\n", + " ...]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#| eval: false\n", + "\n", + "## change to this design\n", + "from openai import OpenAI\n", + "embedding_model = ragas_embedding(provider=\"openai\", model=\"text-embedding-3-small\", client=OpenAI())\n", + "embedding_model.embed_text(\"Hello, world!\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/llm/llm.ipynb b/nbs/llm/llm.ipynb new file mode 100644 index 0000000..f98037e --- /dev/null +++ b/nbs/llm/llm.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp llm.llm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLM Interface for Ragas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import typing as t\n", + "import asyncio\n", + "import inspect\n", + "import threading\n", + "from pydantic import BaseModel\n", + "import instructor\n", + "\n", + "T = t.TypeVar('T', bound=BaseModel)\n", + "\n", + "class RagasLLM:\n", + " def __init__(self, provider: str, model: str, client: t.Any, **model_args):\n", + " self.provider = provider.lower()\n", + " self.model = model\n", + " 
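 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Batch and async usage follow the same pattern. The cell below is an untested sketch: it assumes the `a*` methods are backed by an async-capable client such as `openai.AsyncOpenAI`, since `aembed_text`/`aembed_document` await `client.embeddings.create` internally." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "\n", + "from openai import AsyncOpenAI\n", + "\n", + "# Batch embedding with the sync client created above\n", + "embedding_model.embed_document([\"Hello, world!\", \"How are you?\"])\n", + "\n", + "# The async methods await client.embeddings.create, so they need an async client (assumption: AsyncOpenAI)\n", + "async_embedding_model = ragas_embedding(provider=\"openai\", model=\"text-embedding-3-small\", client=AsyncOpenAI())\n", + "await async_embedding_model.aembed_text(\"Hello, world!\")" + ] + },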
 + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/llm/llm.ipynb b/nbs/llm/llm.ipynb new file mode 100644 index 0000000..f98037e --- /dev/null +++ b/nbs/llm/llm.ipynb @@ -0,0 +1,257 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp llm.llm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLM Interface for Ragas" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import typing as t\n", + "import asyncio\n", + "import inspect\n", + "import threading\n", + "from pydantic import BaseModel\n", + "import instructor\n", + "\n", + "T = t.TypeVar('T', bound=BaseModel)\n", + "\n", + "class RagasLLM:\n", + " def __init__(self, provider: str, model: str, client: t.Any, **model_args):\n", + " self.provider = provider.lower()\n", + " self.model = model\n", + " self.model_args = model_args or {}\n", + " self.client = self._initialize_client(provider, client)\n", + " # Check if client is async-capable at initialization\n", + " self.is_async = self._check_client_async()\n", + " \n", + " def _check_client_async(self) -> bool:\n", + " \"\"\"Determine if the client is async-capable.\"\"\"\n", + " try:\n", + " # Check if this is an async client by checking for a coroutine method\n", + " if hasattr(self.client.chat.completions, 'create'):\n", + " return inspect.iscoroutinefunction(self.client.chat.completions.create)\n", + " return False\n", + " except (AttributeError, TypeError):\n", + " return False\n", + " \n", + " def _initialize_client(self, provider: str, client: t.Any) -> t.Any:\n", + " provider = provider.lower()\n", + " \n", + " if provider == \"openai\":\n", + " return instructor.from_openai(client)\n", + " elif provider == \"anthropic\":\n", + " return instructor.from_anthropic(client)\n", + " elif provider == \"cohere\":\n", + " return instructor.from_cohere(client)\n", + " elif provider == \"gemini\":\n", + " return instructor.from_gemini(client)\n", + " elif provider == \"litellm\":\n", + " return instructor.from_litellm(client)\n", + " else:\n", + " raise ValueError(f\"Unsupported provider: {provider}\")\n", + " \n", + " def _run_async_in_current_loop(self, coro):\n", + " \"\"\"Run an async coroutine in the current event loop if possible.\n", + " \n", + " This handles Jupyter environments correctly by using a separate thread\n", + " when a running event loop is detected.\n", + " \"\"\"\n", + " try:\n", + " # Try to get the current event loop\n", + " loop = asyncio.get_event_loop()\n", + " \n", + " if loop.is_running():\n", + " # If the loop is already running (like in Jupyter notebooks),\n", + " # we run the coroutine in a separate thread with its own event loop\n", + " result_container = {'result': None, 'exception': None}\n", + " \n", + " def run_in_thread():\n", + " # Create a new event loop for this thread\n", + " new_loop = asyncio.new_event_loop()\n", + " asyncio.set_event_loop(new_loop)\n", + " try:\n", + " # Run the coroutine in this thread's event loop\n", + " result_container['result'] = new_loop.run_until_complete(coro)\n", + " except Exception as e:\n", + " # Capture any exceptions to re-raise in the main thread\n", + " result_container['exception'] = e\n", + " finally:\n", + " # Clean up the event loop\n", + " new_loop.close()\n", + " \n", + " # Start the thread and wait for it to complete\n", + " thread = threading.Thread(target=run_in_thread)\n", + " thread.start()\n", + " thread.join()\n", + " \n", + " # Re-raise any exceptions that occurred in the thread\n", + " if result_container['exception']:\n", + " raise result_container['exception']\n", + " \n", + " return result_container['result']\n", + " else:\n", + " # Standard case - event loop exists but isn't running\n", + " return loop.run_until_complete(coro)\n", + " \n", + " except RuntimeError:\n", + " # If we get a runtime error about no event loop, create a new one\n", + " loop = asyncio.new_event_loop()\n", + " asyncio.set_event_loop(loop)\n", + " try:\n", + " return loop.run_until_complete(coro)\n", + " finally:\n", + " # Clean up\n", + " loop.close()\n", + " asyncio.set_event_loop(None)\n", + " \n", + " def generate(self, prompt: str, response_model: t.Type[T]) -> T:\n", + " \"\"\"Generate a response using the configured LLM.\n", + " \n", + " For async clients, this will run the async method in the appropriate event loop.\n", + " \"\"\"\n", + " messages = [{\"role\": 
\"user\", \"content\": prompt}]\n", + " \n", + " # If client is async, use the appropriate method to run it\n", + " if self.is_async:\n", + " return self._run_async_in_current_loop(\n", + " self.agenerate(prompt, response_model)\n", + " )\n", + " else:\n", + " # Regular sync client, just call the method directly\n", + " return self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=messages,\n", + " response_model=response_model,\n", + " **self.model_args,\n", + " )\n", + " \n", + " async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T:\n", + " \"\"\"Asynchronously generate a response using the configured LLM.\"\"\"\n", + " messages = [{\"role\": \"user\", \"content\": prompt}]\n", + " \n", + " # If client is not async, raise a helpful error\n", + " if not self.is_async:\n", + " raise TypeError(\n", + " \"Cannot use agenerate() with a synchronous client. Use generate() instead.\"\n", + " )\n", + " \n", + " # Regular async client, call the method directly\n", + " return await self.client.chat.completions.create(\n", + " model=self.model,\n", + " messages=messages,\n", + " response_model=response_model,\n", + " **self.model_args,\n", + " )\n", + "\n", + "def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM:\n", + " return RagasLLM(provider=provider, client=client, model=model, **model_args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example Usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "\n", + "from openai import OpenAI\n", + "class Response(BaseModel):\n", + " response: str\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", + "llm.generate(\"What is the capital of India?\",response_model=Response) #works fine\n", + "\n", + "try:\n", + " await llm.agenerate(\"What is the capital of India?\", response_model=Response)\n", + "except TypeError as e:\n", + " assert isinstance(e, TypeError)\n", + "#gives TypeError: object Response can't be used in 'await' expression\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Response(response='The capital of India is New Delhi.')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#| eval: false\n", + "\n", + "from openai import AsyncOpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=AsyncOpenAI())\n", + "await llm.agenerate(\"What is the capital of India?\",response_model=Response)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Response(response='The capital of India is New Delhi.')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#| eval: false\n", + "\n", + "from anthropic import Anthropic\n", + "\n", + "llm = ragas_llm(provider=\"anthropic\",model=\"claude-3-opus-20240229\",client=Anthropic(),max_tokens=1024)\n", + "llm.generate(\"What is the capital of India?\",response_model=Response)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/metric/base.ipynb b/nbs/metric/base.ipynb index c6d6e24..c31ecaa 100644 --- a/nbs/metric/base.ipynb +++ 
 + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/metric/base.ipynb b/nbs/metric/base.ipynb index c6d6e24..c31ecaa 100644 --- a/nbs/metric/base.ipynb +++ b/nbs/metric/base.ipynb @@ -42,19 +42,32 @@ "from dataclasses import dataclass, field\n", "from pydantic import BaseModel\n", "import typing as t\n", + "import json\n", + "from tqdm import tqdm\n", + "\n", + "from ragas_annotator.prompt.base import Prompt\n", + "from ragas_annotator.embedding.base import BaseEmbedding\n", "from ragas_annotator.metric import MetricResult\n", - "from ragas_annotator.metric import LLM\n", + "from ragas_annotator.llm import RagasLLM\n", + "from ragas_annotator.project.core import Project\n", + "from ragas_annotator.model.notion_model import NotionModel\n", + "from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt\n", + "\n", "\n", "@dataclass\n", "class Metric(ABC):\n", " \"\"\"Base class for all metrics in the LLM evaluation library.\"\"\"\n", " name: str\n", - " prompt: str\n", - " llm: LLM\n", + " prompt: str | Prompt\n", + " llm: RagasLLM\n", " _response_models: t.Dict[bool, t.Type[BaseModel]] = field(\n", " default_factory=dict, init=False, repr=False\n", " )\n", " \n", + " def __post_init__(self):\n", + " if isinstance(self.prompt,str):\n", + " self.prompt = Prompt(self.prompt)\n", + " \n", " @abstractmethod\n", " def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]:\n", " \"\"\"Get the appropriate response model.\"\"\"\n", @@ -67,22 +80,32 @@ " \n", " def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any:\n", " responses = []\n", + " traces = {}\n", + " traces[\"input\"] = kwargs\n", " prompt_input = self.prompt.format(**kwargs)\n", " for _ in range(n):\n", " response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) \n", + " traces['output'] = response.model_dump()\n", " response = MetricResult(**response.model_dump())\n", " responses.append(response)\n", - " return self._ensemble(responses)\n", + " results = self._ensemble(responses)\n", + " results.traces = traces\n", + " return results\n", "\n", "\n", " async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult:\n", " responses = [] # Added missing initialization\n", + " traces = {}\n", + " traces[\"input\"] = kwargs\n", " prompt_input = self.prompt.format(**kwargs)\n", " for _ in range(n):\n", " response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning))\n", + " traces['output'] = response.model_dump()\n", " response = MetricResult(**response.model_dump()) # Fixed missing parentheses\n", " responses.append(response)\n", - " return self._ensemble(responses)\n", + " results = self._ensemble(responses)\n", + " results.traces = traces\n", + " return results\n", " \n", " def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]:\n", " return [self.score(reasoning, n, **input_dict) for input_dict in inputs]\n", @@ -94,7 +117,34 @@ " async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))\n", " \n", " # Run all tasks concurrently and return results\n", - " return await asyncio.gather(*async_tasks)" + " return await asyncio.gather(*async_tasks)\n", + " \n", + " def train(self, project: Project, experiment_names: t.List[str], model: NotionModel, embedding_model: BaseEmbedding, method: t.Dict[str, t.Any]):\n", + " \n", + " assert isinstance(self.prompt, Prompt)\n", + " self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)\n", + " datasets = []\n", + " for experiment_name in experiment_names:\n", + " experiment_data = 
project.get_experiment(experiment_name,model)\n", + " experiment_data.load()\n", + " datasets.append(experiment_data)\n", + " \n", + " total_items = sum([len(dataset) for dataset in datasets])\n", + " with tqdm(total=total_items, desc=\"Processing examples\") as pbar:\n", + " for dataset in datasets:\n", + " for row in dataset:\n", + " if hasattr(row, f'{self.name}_traces'):\n", + " traces = json.loads(getattr(row, f'{self.name}_traces'))\n", + " if traces:\n", + " self.prompt.add_example(traces['input'],traces['output'])\n", + " pbar.update(1)" ] }, { @@ -114,6 +164,11 @@ "source": [ "#| eval: false\n", "\n", + "from ragas_annotator.llm import ragas_llm\n", + "from openai import OpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", + "\n", "@dataclass\n", "class CustomMetric(Metric):\n", " values: t.List[str] = field(default_factory=lambda: [\"pass\", \"fail\"])\n", @@ -131,12 +186,18 @@ " \n", " return results[0] # Placeholder for ensemble logic\n", "\n", - "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=LLM())\n", + "my_metric = CustomMetric(name=\"example\", prompt=\"What is the result of {input}?\", llm=llm)\n", "my_metric.score(input=\"test\")" ] } ], - "metadata": {}, + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, "nbformat": 4, "nbformat_minor": 2 } diff --git a/nbs/metric/decorator.ipynb b/nbs/metric/decorator.ipynb index 70131f0..7f9752a 100644 --- a/nbs/metric/decorator.ipynb +++ b/nbs/metric/decorator.ipynb @@ -30,6 +30,8 @@ "import asyncio\n", "from dataclasses import dataclass\n", "from ragas_annotator.metric import MetricResult\n", + "from ragas_annotator.llm import RagasLLM\n", + "from ragas_annotator.prompt.base import Prompt\n", "\n", "\n", "\n", @@ -44,7 +46,7 @@ " Returns:\n", " A decorator factory function for the specified metric type\n", " \"\"\"\n", - " def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params):\n", + " def decorator_factory(llm: RagasLLM, prompt: t.Union[str, Prompt], name: t.Optional[str] = None, **metric_params):\n", " \"\"\"\n", " Creates a decorator that wraps a function into a metric instance.\n", " \n", @@ -63,17 +65,9 @@ " metric_name = name or func.__name__\n", " is_async = inspect.iscoroutinefunction(func)\n", " \n", + " #TODO: Move to dataclass type implementation\n", " @dataclass\n", " class CustomMetric(metric_class):\n", - " def _extract_result(self, result, reasoning: bool):\n", - " \"\"\"Extract score and reason from the result.\"\"\"\n", - " if isinstance(result, tuple) and len(result) == 2:\n", - " score, reason = result\n", - " else:\n", - " score, reason = result, None\n", - " \n", - " # Use \"result\" instead of \"score\" for the new MetricResult implementation\n", - " return MetricResult(result=score, reason=reason if reasoning else None)\n", + " \n", " def _run_sync_in_async(self, func, *args, **kwargs):\n", " \"\"\"Run a synchronous function in an async context.\"\"\"\n", @@ -100,7 +94,7 @@ " # Sync function implementation\n", " result = func(self.llm, self.prompt, **kwargs)\n", " \n", - " return self._extract_result(result, reasoning)\n", + " return result\n", " except Exception as e:\n", " # Handle errors gracefully\n", " error_msg = f\"Error executing metric {self.name}: {str(e)}\"\n", @@ -119,7 +113,7 @@ " else:\n", " # For sync functions, run normally\n", " result = 
self._run_sync_in_async(func, self.llm, self.prompt, **kwargs)\n", - " return self._extract_result(result, reasoning)\n", + " return result\n", " \n", " # Create the metric instance with all parameters\n", " metric_instance = CustomMetric(\n", @@ -158,8 +152,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "high\n", - "reason\n" + "low\n", + "The context or details of the user's response ('my response') are not provided, making it impossible to evaluate its helpfulness accurately.\n" ] } ], @@ -167,13 +161,17 @@ "#| eval: false\n", "\n", "\n", - "from ragas_annotator.metric import DiscreteMetric\n", - "from ragas_annotator.metric.llm import LLM\n", + "from ragas_annotator.metric import DiscreteMetric, MetricResult\n", "from pydantic import BaseModel\n", "\n", + "from ragas_annotator.llm import ragas_llm\n", + "from openai import OpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", + "\n", "discrete_metric = create_metric_decorator(DiscreteMetric)\n", "\n", - "@discrete_metric(llm=LLM(),\n", + "@discrete_metric(llm=llm,\n", " prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n", " name='new_metric',values=[\"low\",\"med\",\"high\"])\n", "def my_metric(llm,prompt,**kwargs):\n", @@ -188,7 +186,7 @@ " score = 'low'\n", " else:\n", " score = 'high'\n", - " return score,\"reason\"\n", + " return MetricResult(result=score, reason=response.reason)\n", "\n", "result = my_metric.score(response='my response') # result\n", "print(result)\n", diff --git a/nbs/metric/discrete.ipynb b/nbs/metric/discrete.ipynb index c27815c..bdb8a28 100644 --- a/nbs/metric/discrete.ipynb +++ b/nbs/metric/discrete.ipynb @@ -90,8 +90,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "low\n", - "No context or content was provided for evaluation.\n" + "med\n", + "The given input \"this is my response\" is too vague to provide a comprehensive evaluation.\n", + "\n", + "Positives:\n", + "1. Clear Statement: It's a straightforward indication that a response has been provided.\n", + "\n", + "Negatives:\n", + "1. Lack of Context: Without context or additional information, it's impossible to assess the relevance or accuracy of the response.\n", + "2. No Specificity: The response doesn't convey any specific information or insight related to a topic or question.\n", + "\n", + "If this response was intended to be part of a conversation or instruction, more detail would be required to make it highly effective. At present, it serves as a neutral statement without actionable or informative content.\n" ] } ], @@ -99,10 +108,14 @@ "\n", "#| eval: false\n", "\n", - "from ragas_annotator.metric.llm import LLM\n", + "from ragas_annotator.llm import ragas_llm\n", + "from openai import OpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", + "\n", "\n", "my_metric = DiscreteMetric(\n", - " llm=LLM(),\n", + " llm=llm,\n", " name='helpfulness',\n", " prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n", " values=[\"low\",\"med\",\"high\"],\n", @@ -131,13 +144,15 @@ "output_type": "stream", "text": [ "low\n", - "reason\n" + "The prompt 'my response' does not provide sufficient information or context for me to evaluate its helpfulness. 
An answer needs to be specific and provide insight or information relative to a clear question or context.\n" ] } ], "source": [ "#| eval: false\n", - "@discrete_metric(llm=LLM(),\n", + "from ragas_annotator.metric.result import MetricResult\n", + "\n", + "@discrete_metric(llm=llm,\n", " prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n", " name='new_metric',values=[\"low\",\"med\",\"high\"])\n", "def my_metric(llm,prompt,**kwargs):\n", @@ -145,14 +160,17 @@ " class response_model(BaseModel):\n", " output: t.List[bool]\n", " reason: str\n", - " \n", + " traces = {}\n", + " traces['input'] = kwargs\n", " response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n", + " traces['output'] = response.model_dump()\n", " total = sum(response.output)\n", " if total < 1:\n", " score = 'low'\n", " else:\n", " score = 'high'\n", - " return score,\"reason\"\n", + " \n", + " return MetricResult(result=score,reason=response.reason,traces=traces)\n", "\n", "result = my_metric.score(response='my response') # result\n", "print(result)\n", diff --git a/nbs/metric/llm.ipynb b/nbs/metric/llm.ipynb deleted file mode 100644 index 6ceca63..0000000 --- a/nbs/metric/llm.ipynb +++ /dev/null @@ -1,61 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| default_exp metric.llm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "#| export\n", - "\n", - "import openai\n", - "import instructor\n", - "from dataclasses import dataclass\n", - "\n", - "@dataclass\n", - "class LLM:\n", - "\n", - " def __post_init__(self):\n", - " self.aclient = instructor.from_openai(openai.AsyncOpenAI())\n", - " self.client = instructor.from_openai(openai.OpenAI())\n", - "\n", - " \n", - " def generate(self,prompt,response_model):\n", - " return self.client.chat.completions.create(\n", - " model=\"gpt-4o-mini\",\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": prompt},\n", - " ],\n", - " response_model=response_model,\n", - " )\n", - "\n", - " async def agenerate(self,prompt,response_model):\n", - " return await self.aclient.chat.completions.create(\n", - " model=\"gpt-4o-mini\",\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": prompt},\n", - " ],\n", - " response_model=response_model,\n", - " )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "python3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/nbs/metric/numeric.ipynb b/nbs/metric/numeric.ipynb index e3b08b0..e6e5681 100644 --- a/nbs/metric/numeric.ipynb +++ b/nbs/metric/numeric.ipynb
 @@ -81,7 +90,7 @@ { "data": { "text/plain": [ - "'The response does not provide any context or information that can be evaluated as helpful.'" + "\"The provided input lacks context or content to determine if it is helpful as it merely states 'this is my response' without any additional information.\"" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#| eval: false\n", "\n", - "from ragas_annotator.metric.llm import LLM\n", + "from ragas_annotator.llm import ragas_llm\n", + "from openai import OpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", + "\n", "\n", "my_metric = NumericMetric(\n", " name='helpfulness',\n", - " llm=LLM(),\n", + " llm=llm,\n", " prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n", " range=(0,10),\n", ")\n", @@ -122,7 +136,7 @@ { "data": { "text/plain": [ - "20" + "10" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "#| eval: false\n", + "from ragas_annotator.metric import MetricResult\n", "\n", - "@numeric_metric(llm=LLM(),\n", + "@numeric_metric(llm=llm,\n", " prompt=\"Evaluate if given answer is helpful\\n\\n{response}\",\n", " name='new_metric',range=(0,10))\n", "def my_metric(llm,prompt,**kwargs):\n", " class response_model(BaseModel):\n", " output: int\n", " reason: str\n", " \n", + " traces = {}\n", + " traces['input'] = kwargs\n", " response = llm.generate(prompt.format(**kwargs),response_model=response_model)\n", + " traces['output'] = response.model_dump()\n", " total = response.output\n", " if total < 1:\n", " score = 0\n", " else:\n", " score = 10\n", - " return score,\"reason\"\n", + " return MetricResult(result=score,reason=response.reason,traces=traces)\n", "\n", "result = my_metric.score(response='my response') # result\n", "result # 10\n", diff --git a/nbs/metric/ranking.ipynb b/nbs/metric/ranking.ipynb index 48e2aa3..1c7cd4c 100644 --- a/nbs/metric/ranking.ipynb +++ b/nbs/metric/ranking.ipynb @@ -117,11 +126,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "[0, 1, 2]\n", + "[2, 1, 0]\n", "Ensemble ranking based on multiple evaluations.\n", - "The ranking is based on the length and detail of the responses, with 'short answer.' being the least detailed (rank 0), 'a bit more detailed.' being moderate (rank 1), and 'the longest and most detailed answer.' being the most comprehensive (rank 2).\n", - "The ranking is based on the length and detail of the responses. The shortest response is ranked the lowest (0), the moderately detailed response is ranked higher (1), and the longest and most detailed response is ranked the highest (2).\n", - "Ranking is based on length and detail; the longest answer (2) is most detailed, followed by a bit more detailed (1), and the shortest answer (0) is the least detailed.\n" + "The ranking is based on the length and detail of each response. 
'the longest and most detailed answer.' is the most comprehensive, followed by 'a bit more detailed.', and 'short answer.' is the briefest.\n", + "The ranking is based on the length and detail of each response. The response 'the longest and most detailed answer.' is ranked highest (2) because it is the most detailed, followed by 'a bit more detailed.' (1), and finally 'short answer.' (0) as it is the least detailed.\n", + "The responses are ranked based on the level of detail and length. 'short answer.' is the least detailed, 'a bit more detailed.' provides more information, and 'the longest and most detailed answer.' offers the most comprehensive explanation.\n" ] } ], @@ -129,11 +138,14 @@ "\n", "#| eval: false\n", "\n", - "from ragas_annotator.metric.llm import LLM\n", + "from ragas_annotator.llm import ragas_llm\n", + "from openai import OpenAI\n", + "\n", + "llm = ragas_llm(provider=\"openai\",model=\"gpt-4o\",client=OpenAI())\n", "\n", "my_ranking_metric = RankingMetric(\n", " name='response_ranking',\n", - " llm=LLM(), # Your language model instance\n", + " llm=llm, # Your language model instance\n", " prompt=\"Rank the following responses:\\n{candidates}\",\n", " num_ranks=3,\n", ")\n", @@ -173,9 +185,10 @@ "source": [ "#| eval: false\n", "\n", + "from ragas_annotator.metric import MetricResult\n", "\n", "@ranking_metric(\n", - " llm=LLM(), # Your language model instance\n", + " llm=llm, # Your language model instance\n", " prompt=\"Rank the following responses:\\n{candidates}\",\n", " name='new_ranking_metric',\n", " num_ranks=3\n", @@ -185,7 +198,7 @@ " # For example, process the prompt (formatted with candidates) and produce a ranking.\n", " ranking = [1, 0, 2] # Dummy ranking: second candidate is best, then first, then third.\n", " reason = \"Ranked based on response clarity and detail.\"\n", - " return ranking, reason\n", + " return MetricResult(result=ranking, reason=reason)\n", "\n", "# Using the decorator-based ranking metric:\n", "result = my_ranking_metric.score(candidates=[\n", diff --git a/nbs/metric/result.ipynb b/nbs/metric/result.ipynb index cba95c7..26149b3 100644 --- a/nbs/metric/result.ipynb +++ b/nbs/metric/result.ipynb @@ -46,9 +46,14 @@ " - RankingMetrics (list results)\n", " \"\"\"\n", " \n", - " def __init__(self, result: t.Any, reason: t.Optional[str] = None):\n", + " def __init__(self, result: t.Any, reason: t.Optional[str] = None, traces: t.Optional[t.Dict[str, t.Any]] = None):\n", + " if traces is not None:\n", + " invalid_keys = [key for key in traces.keys() if key not in {\"input\", \"output\"}]\n", + " if invalid_keys:\n", + " raise ValueError(f\"Invalid keys in traces: {invalid_keys}. 
Allowed keys are 'input' and 'output'.\")\n", " self._result = result\n", " self.reason = reason\n", + " self.traces = traces\n", " \n", " def __repr__(self):\n", " return repr(self._result)\n", diff --git a/nbs/prompt/base.ipynb b/nbs/prompt/base.ipynb new file mode 100644 index 0000000..879b673 --- /dev/null +++ b/nbs/prompt/base.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp prompt.base" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prompt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import typing as t\n", + "import re\n", + "\n", + "class Prompt:\n", + " def __init__(\n", + " self,\n", + " instruction: str,\n", + " examples: t.Optional[t.List[t.Tuple[t.Dict, t.Dict]]] = None\n", + " ):\n", + " \"\"\"\n", + " Create a simple prompt object.\n", + " \n", + " Parameters:\n", + " -----------\n", + " instruction : str\n", + " The prompt instruction template with placeholders like {response}, {expected_answer}\n", + " examples : Optional[List[Tuple[Dict, Dict]]]\n", + " List of (input_dict, output_dict) pairs for few-shot learning\n", + " \"\"\"\n", + " self.instruction = instruction\n", + " self.examples = []\n", + " \n", + " # Validate the instruction\n", + " self._validate_instruction()\n", + " \n", + " # Add examples if provided\n", + " if examples:\n", + " for inputs, output in examples:\n", + " self.add_example(inputs, output)\n", + " \n", + " def _validate_instruction(self):\n", + " \"\"\"Ensure the instruction contains at least one placeholder.\"\"\"\n", + " if not re.findall(r\"\\{(\\w+)\\}\", self.instruction):\n", + " raise ValueError(\"Instruction must contain at least one placeholder like {response}\")\n", + " \n", + " def format(self, **kwargs) -> str:\n", + " \"\"\"Format the prompt with the provided variables.\"\"\"\n", + "\n", + " prompt_parts = []\n", + " prompt_parts.append(self.instruction.format(**kwargs))\n", + " prompt_parts.append(self._format_examples())\n", + "\n", + " # Combine all parts\n", + " return \"\\n\\n\".join(prompt_parts)\n", + " \n", + " def _format_examples(self) -> str:\n", + " \n", + " # Add examples in a simple format\n", + " examples = []\n", + " if self.examples:\n", + " examples.append(\"Examples:\")\n", + " for i, (inputs, output) in enumerate(self.examples, 1):\n", + " example_input = \"\\n\".join([f\"{k}: {v}\" for k, v in inputs.items()])\n", + " example_output = \"\\n\".join([f\"{k}: {v}\" for k, v in output.items()])\n", + " \n", + " examples.append(f\"Example {i}:\\nInput:\\n{example_input}\\nOutput:\\n{example_output}\")\n", + " \n", + " return \"\\n\\n\".join(examples) if examples else \"\"\n", + " \n", + " \n", + " def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n", + " \"\"\"\n", + " Add an example to the prompt.\n", + " \n", + " Parameters:\n", + " -----------\n", + " inputs : Dict\n", + " Dictionary of input values\n", + " output : Dict\n", + " Dictionary of output values\n", + " \n", + " Raises:\n", + " -------\n", + " TypeError\n", + " If inputs or output is not a dictionary\n", + " \"\"\"\n", + " if not isinstance(inputs, dict):\n", + " raise TypeError(f\"Expected inputs to be dict, got {type(inputs).__name__}\")\n", + " if not isinstance(output, dict):\n", + " raise TypeError(f\"Expected output to be dict, got {type(output).__name__}\")\n", + " \n", + " 
self.examples.append((inputs, output))\n", + " \n", + " def __str__(self) -> str:\n", + " \"\"\"String representation showing the instruction.\"\"\"\n", + " return f\"Prompt(instruction='{self.instruction}',\\n examples={self.examples})\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example Usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate if given answer You can get a full refund if you miss your flight. is same as expected answer Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n", + "\n", + "Examples:\n", + "\n", + "Example 1:\n", + "Input:\n", + "response: You can get a full refund if you miss your flight.\n", + "expected_answer: Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n", + "Output:\n", + "score: fail\n", + "\n", + "Example 2:\n", + "Input:\n", + "response: Each passenger gets 1 free checked bag up to 23kg.\n", + "expected_answer: Each passenger gets 1 free checked bag up to 23kg.\n", + "Output:\n", + "score: pass\n" + ] + } + ], + "source": [ + "# Create a basic prompt\n", + "prompt = Prompt(\n", + " instruction=\"Evaluate if given answer {response} is same as expected answer {expected_answer}\"\n", + ")\n", + "\n", + "# Add examples with dict inputs and dict outputs\n", + "prompt.add_example(\n", + " {\n", + " \"response\": \"You can get a full refund if you miss your flight.\",\n", + " \"expected_answer\": \"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"\n", + " },\n", + " {\"score\": \"fail\"}\n", + ")\n", + "\n", + "prompt.add_example(\n", + " {\n", + " \"response\": \"Each passenger gets 1 free checked bag up to 23kg.\",\n", + " \"expected_answer\": \"Each passenger gets 1 free checked bag up to 23kg.\"\n", + " },\n", + " {\"score\": \"pass\"}\n", + ")\n", + "\n", + "print(prompt.format(response=\"You can get a full refund if you miss your flight.\", expected_answer=\"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt(instruction='Evaluate if given answer {response} is same as expected answer {expected_answer}',\n", + " examples=Examples:\n", + "\n", + "Example 1:\n", + "Input:\n", + "response: You can get a full refund if you miss your flight.\n", + "expected_answer: Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n", + "Output:\n", + "score: fail\n", + "\n", + "Example 2:\n", + "Input:\n", + "response: Each passenger gets 1 free checked bag up to 23kg.\n", + "expected_answer: Each passenger gets 1 free checked bag up to 23kg.\n", + "Output:\n", + "score: pass)\n" + ] + } + ], + "source": [ + "print(str(prompt))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/prompt/dynamic_few_shot.ipynb b/nbs/prompt/dynamic_few_shot.ipynb new file mode 100644 index 0000000..10bb417 --- /dev/null +++ b/nbs/prompt/dynamic_few_shot.ipynb @@ -0,0 +1,319 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "#| default_exp prompt.dynamic_few_shot" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dynamic few-shot learning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "import typing as t\n", + "import numpy as np\n", + "from abc import ABC, abstractmethod\n", + "\n", + "from ragas_annotator.prompt.base import Prompt\n", + "from ragas_annotator.embedding import BaseEmbedding\n", + "\n", + "class ExampleStore(ABC):\n", + " @abstractmethod\n", + " def get_examples(\n", + " self, data: t.Dict, top_k: int = 5\n", + " ) -> t.List[t.Tuple[t.Dict, t.Dict]]:\n", + " \"\"\"Get top_k most similar examples to data.\"\"\"\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n", + " \"\"\"Add an example to the store.\"\"\"\n", + " pass\n", + "\n", + "\n", + "class InMemoryExampleStore(ExampleStore):\n", + " def __init__(self, embedding_model=None):\n", + " \"\"\"\n", + " Initialize an in-memory example store with optional embedding model.\n", + " \n", + " Args:\n", + " embedding_model: Model used to generate embeddings (OpenAI or similar)\n", + " \"\"\"\n", + " self.embedding_model = embedding_model\n", + " self._examples: t.List[t.Tuple[t.Dict, t.Dict]] = []\n", + " self._embeddings_list: t.List[t.List[float]] = []\n", + " \n", + " def _get_embedding(self, data: t.Dict) -> t.List[float]:\n", + " \"\"\"Convert input dict to an embedding vector.\"\"\"\n", + " if self.embedding_model is None:\n", + " return []\n", + " \n", + " # Serialize the dictionary to text\n", + " text = \"\\n\".join([f\"{k}: {v}\" for k, v in data.items()])\n", + " return self.embedding_model.embed_text(text)\n", + " \n", + " def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n", + " \"\"\"Add an example to the store with its embedding.\"\"\"\n", + " if not isinstance(inputs, dict):\n", + " raise TypeError(f\"Expected inputs to be dict, got {type(inputs).__name__}\")\n", + " if not isinstance(output, dict):\n", + " raise TypeError(f\"Expected output to be dict, got {type(output).__name__}\")\n", + " \n", + " self._examples.append((inputs, output))\n", + " \n", + " if self.embedding_model:\n", + " embedding = self._get_embedding(inputs)\n", + " self._embeddings_list.append(embedding)\n", + " \n", + " def get_examples(\n", + " self, data: t.Dict, top_k: int = 5, threshold: float = 0.7\n", + " ) -> t.List[t.Tuple[t.Dict, t.Dict]]:\n", + " \"\"\"Get examples most similar to the input data.\"\"\"\n", + " if not self._examples:\n", + " return []\n", + " \n", + " if not self.embedding_model or not self._embeddings_list:\n", + " # If no embedding model, return the most recent examples\n", + " return self._examples[-top_k:]\n", + " \n", + " # Get embedding for the query\n", + " query_embedding = self._get_embedding(data)\n", + " \n", + " # Find most similar examples\n", + " indices = self._get_nearest_examples(\n", + " query_embedding, self._embeddings_list, top_k, threshold\n", + " )\n", + " \n", + " # Return the examples at those indices\n", + " return [self._examples[i] for i in indices]\n", + " \n", + " def _get_nearest_examples(\n", + " self,\n", + " query_embedding: t.List[float],\n", + " embeddings: t.List[t.List[float]],\n", + " top_k: int = 3,\n", + " threshold: float = 0.7,\n", + " ) -> t.List[int]:\n", + " \"\"\"Find indices of the nearest examples based on cosine similarity.\"\"\"\n", + " # Convert to numpy 
arrays for efficient computation\n", + " query = np.array(query_embedding)\n", + " embed_matrix = np.array(embeddings)\n", + " \n", + " # Calculate cosine similarity\n", + " similarities = np.dot(embed_matrix, query) / (\n", + " np.linalg.norm(embed_matrix, axis=1) * np.linalg.norm(query) + 1e-8\n", + " )\n", + " \n", + " # Get indices of similarities above threshold\n", + " valid_indices = np.where(similarities >= threshold)[0]\n", + " \n", + " # Sort by similarity and get top-k\n", + " if len(valid_indices) > 0:\n", + " top_indices = valid_indices[np.argsort(similarities[valid_indices])[-top_k:]]\n", + " # Convert numpy indices to Python ints\n", + " return [int(idx) for idx in top_indices]\n", + " \n", + " # If no examples meet threshold, return most recent examples\n", + " return list(range(max(0, len(embeddings) - top_k), len(embeddings)))\n", + " \n", + " def __len__(self):\n", + " return len(self._examples)\n", + "\n", + "\n", + "\n", + "\n", + "class DynamicFewShotPrompt(Prompt):\n", + " \n", + " def __init__(\n", + " self,\n", + " prompt: Prompt,\n", + " example_store: InMemoryExampleStore,\n", + " num_examples: int = 3\n", + " ):\n", + " \n", + " self.example_store = example_store\n", + " super().__init__(prompt.instruction, prompt.examples)\n", + " self.num_examples = num_examples\n", + " \n", + " # Note: the parent constructor routes prompt.examples through the\n", + " # overridden add_example, which already adds each example to the\n", + " # store once; re-adding them here would duplicate store entries\n", + " # and embeddings.\n", + " \n", + " def format(self, **kwargs) -> str:\n", + " \"\"\"Format the prompt with dynamically retrieved examples.\"\"\"\n", + " prompt_parts = []\n", + " \n", + " # Add instruction with variables filled in\n", + " prompt_parts.append(self.instruction.format(**kwargs))\n", + " \n", + " # Get dynamic examples if we have a store and inputs\n", + " dynamic_examples = []\n", + " if self.example_store and kwargs:\n", + " dynamic_examples = self.example_store.get_examples(kwargs, self.num_examples)\n", + " \n", + " # Add examples in a simple format\n", + " if dynamic_examples:\n", + " prompt_parts.append(\"Examples:\")\n", + " for i, (inputs, output) in enumerate(dynamic_examples, 1):\n", + " example_input = \"\\n\".join([f\"{k}: {v}\" for k, v in inputs.items()])\n", + " example_output = \"\\n\".join([f\"{k}: {v}\" for k, v in output.items()])\n", + " \n", + " prompt_parts.append(f\"Example {i}:\\nInput:\\n{example_input}\\nOutput:\\n{example_output}\")\n", + " \n", + " \n", + " \n", + " # Combine all parts\n", + " return \"\\n\\n\".join(prompt_parts)\n", + " \n", + " def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n", + " \"\"\"\n", + " Add an example to both the prompt and the example store.\n", + " \n", + " Parameters:\n", + " -----------\n", + " inputs : Dict\n", + " Dictionary of input values\n", + " output : Dict\n", + " Dictionary of output values\n", + " \n", + " Raises:\n", + " -------\n", + " TypeError\n", + " If inputs or output is not a dictionary\n", + " \"\"\"\n", + " if (inputs, output) not in self.examples:\n", + " self.examples.append((inputs, output))\n", + " \n", + " # Add to example store\n", + " if isinstance(self.example_store, ExampleStore) and (inputs, output) not in self.example_store._examples:\n", + " self.example_store.add_example(inputs, output)\n", + " \n", + " @classmethod\n", + " def from_prompt(\n", + " cls,\n", + " prompt: Prompt,\n", + " embedding_model: BaseEmbedding,\n", + " num_examples: int = 3\n", + " ) -> \"DynamicFewShotPrompt\":\n", + " \"\"\"Create a DynamicFewShotPrompt from a Prompt object.\"\"\"\n", + " example_store = 
InMemoryExampleStore(embedding_model=embedding_model)\n", + " \n", + " few_shot_prompt = cls(\n", + " prompt=prompt,\n", + " example_store=example_store,\n", + " num_examples=num_examples\n", + " )\n", + " \n", + " return few_shot_prompt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example Usage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate if given answer Regularly updating your software reduces the risk of vulnerabilities. is same as expected answer Keeping software up to date helps patch known security flaws and prevents exploits.\n", + "\n", + "Examples:\n", + "\n", + "Example 1:\n", + "Input:\n", + "response: Using two-factor authentication greatly enhances account security.\n", + "expected_answer: Two-factor authentication adds a layer of protection by requiring a second form of identity verification.\n", + "Output:\n", + "score: fail\n" + ] + } + ], + "source": [ + "#| eval: false\n", + "from ragas_annotator.embedding import ragas_embedding\n", + "from ragas_annotator.prompt import Prompt\n", + "from openai import OpenAI\n", + "\n", + "embedding = ragas_embedding(provider=\"openai\", client=OpenAI(),model=\"text-embedding-3-small\")\n", + "\n", + "# Create a basic prompt\n", + "prompt = Prompt(\n", + " instruction=\"Evaluate if given answer {response} is same as expected answer {expected_answer}\"\n", + ")\n", + "\n", + "# Add examples with dict inputs and dict outputs\n", + "prompt.add_example(\n", + " {\n", + " \"response\": \"You can get a full refund if you miss your flight.\",\n", + " \"expected_answer\": \"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"\n", + " },\n", + " {\"score\": \"fail\"}\n", + ")\n", + "\n", + "prompt = DynamicFewShotPrompt.from_prompt(\n", + " prompt,\n", + " embedding_model=embedding,\n", + " num_examples=1\n", + ")\n", + "\n", + "prompt.add_example(\n", + " {\n", + " \"response\": \"Bananas are high in potassium and great for quick energy.\",\n", + " \"expected_answer\": \"Bananas provide potassium and are a good source of fast-digesting carbohydrates.\"\n", + " },\n", + " {\"score\": \"pass\"}\n", + ")\n", + "\n", + "prompt.add_example(\n", + " {\n", + " \"response\": \"Using two-factor authentication greatly enhances account security.\",\n", + " \"expected_answer\": \"Two-factor authentication adds a layer of protection by requiring a second form of identity verification.\"\n", + " },\n", + " {\"score\": \"fail\"}\n", + ")\n", + "\n", + "\n", + "prompt.example_store.get_examples(\n", + "{\n", + " \"response\": \"Regularly updating your software reduces the risk of vulnerabilities.\",\n", + " \"expected_answer\": \"Keeping software up to date helps patch known security flaws and prevents exploits.\"\n", + " })\n", + "\n", + "print(prompt.format(**{\n", + " \"response\": \"Regularly updating your software reduces the risk of vulnerabilities.\",\n", + " \"expected_answer\": \"Keeping software up to date helps patch known security flaws and prevents exploits.\"\n", + " }))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/nbs/sidebar.yml b/nbs/sidebar.yml index c6a1ec2..1ededb3 100644 --- a/nbs/sidebar.yml +++ b/nbs/sidebar.yml @@ -10,16 +10,17 @@ website: - backends/factory.ipynb - backends/mock_notion_client.ipynb - 
backends/notion.ipynb + - section: llm + contents: + - llm/llm.ipynb - section: metric contents: - metric/base.ipynb - metric/decorator.ipynb - metric/discrete.ipynb - - metric/llm.ipynb - metric/numeric.ipynb - metric/ranking.ipynb - metric/result.ipynb - - metric/test_base.ipynb - section: model contents: - model/notion_model.ipynb diff --git a/ragas_annotator/_modidx.py b/ragas_annotator/_modidx.py index c055929..a92767f 100644 --- a/ragas_annotator/_modidx.py +++ b/ragas_annotator/_modidx.py @@ -119,6 +119,30 @@ 'ragas_annotator.dataset.Dataset.pop': ('dataset.html#dataset.pop', 'ragas_annotator/dataset.py'), 'ragas_annotator.dataset.Dataset.save': ( 'dataset.html#dataset.save', 'ragas_annotator/dataset.py')}, + 'ragas_annotator.embedding.base': { 'ragas_annotator.embedding.base.BaseEmbedding': ( 'embedding/base.html#baseembedding', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.BaseEmbedding.aembed_document': ( 'embedding/base.html#baseembedding.aembed_document', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.BaseEmbedding.aembed_text': ( 'embedding/base.html#baseembedding.aembed_text', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.BaseEmbedding.embed_document': ( 'embedding/base.html#baseembedding.embed_document', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.BaseEmbedding.embed_text': ( 'embedding/base.html#baseembedding.embed_text', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings': ( 'embedding/base.html#openaiembeddings', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings.__init__': ( 'embedding/base.html#openaiembeddings.__init__', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings.aembed_document': ( 'embedding/base.html#openaiembeddings.aembed_document', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings.aembed_text': ( 'embedding/base.html#openaiembeddings.aembed_text', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings.embed_document': ( 'embedding/base.html#openaiembeddings.embed_document', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.OpenAIEmbeddings.embed_text': ( 'embedding/base.html#openaiembeddings.embed_text', + 'ragas_annotator/embedding/base.py'), + 'ragas_annotator.embedding.base.ragas_embedding': ( 'embedding/base.html#ragas_embedding', + 'ragas_annotator/embedding/base.py')}, 'ragas_annotator.exceptions': { 'ragas_annotator.exceptions.DuplicateError': ( 'utils/exceptions.html#duplicateerror', 'ragas_annotator/exceptions.py'), 'ragas_annotator.exceptions.NotFoundError': ( 'utils/exceptions.html#notfounderror', @@ -131,8 +155,24 @@ 'ragas_annotator/experiment.py'), 'ragas_annotator.experiment.Experiment.__str__': ( 'experiment.html#experiment.__str__', 'ragas_annotator/experiment.py')}, + 'ragas_annotator.llm.llm': { 'ragas_annotator.llm.llm.RagasLLM': ('llm/llm.html#ragasllm', 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.RagasLLM.__init__': ( 'llm/llm.html#ragasllm.__init__', + 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.RagasLLM._check_client_async': ( 'llm/llm.html#ragasllm._check_client_async', + 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.RagasLLM._initialize_client': ( 'llm/llm.html#ragasllm._initialize_client', + 'ragas_annotator/llm/llm.py'), + 
'ragas_annotator.llm.llm.RagasLLM._run_async_in_current_loop': ( 'llm/llm.html#ragasllm._run_async_in_current_loop', + 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.RagasLLM.agenerate': ( 'llm/llm.html#ragasllm.agenerate', + 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.RagasLLM.generate': ( 'llm/llm.html#ragasllm.generate', + 'ragas_annotator/llm/llm.py'), + 'ragas_annotator.llm.llm.ragas_llm': ('llm/llm.html#ragas_llm', 'ragas_annotator/llm/llm.py')}, 'ragas_annotator.metric.base': { 'ragas_annotator.metric.base.Metric': ( 'metric/base.html#metric', 'ragas_annotator/metric/base.py'), + 'ragas_annotator.metric.base.Metric.__post_init__': ( 'metric/base.html#metric.__post_init__', + 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric._ensemble': ( 'metric/base.html#metric._ensemble', 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric._get_response_model': ( 'metric/base.html#metric._get_response_model', @@ -144,6 +184,8 @@ 'ragas_annotator.metric.base.Metric.batch_score': ( 'metric/base.html#metric.batch_score', 'ragas_annotator/metric/base.py'), 'ragas_annotator.metric.base.Metric.score': ( 'metric/base.html#metric.score', + 'ragas_annotator/metric/base.py'), + 'ragas_annotator.metric.base.Metric.train': ( 'metric/base.html#metric.train', 'ragas_annotator/metric/base.py')}, 'ragas_annotator.metric.decorator': { 'ragas_annotator.metric.decorator.create_metric_decorator': ( 'metric/decorator.html#create_metric_decorator', 'ragas_annotator/metric/decorator.py')}, @@ -153,13 +195,6 @@ 'ragas_annotator/metric/discrete.py'), 'ragas_annotator.metric.discrete.DiscreteMetric._get_response_model': ( 'metric/discrete.html#discretemetric._get_response_model', 'ragas_annotator/metric/discrete.py')}, - 'ragas_annotator.metric.llm': { 'ragas_annotator.metric.llm.LLM': ('metric/llm.html#llm', 'ragas_annotator/metric/llm.py'), - 'ragas_annotator.metric.llm.LLM.__post_init__': ( 'metric/llm.html#llm.__post_init__', - 'ragas_annotator/metric/llm.py'), - 'ragas_annotator.metric.llm.LLM.agenerate': ( 'metric/llm.html#llm.agenerate', - 'ragas_annotator/metric/llm.py'), - 'ragas_annotator.metric.llm.LLM.generate': ( 'metric/llm.html#llm.generate', - 'ragas_annotator/metric/llm.py')}, 'ragas_annotator.metric.numeric': { 'ragas_annotator.metric.numeric.NumericMetric': ( 'metric/numeric.html#numericmetric', 'ragas_annotator/metric/numeric.py'), 'ragas_annotator.metric.numeric.NumericMetric._ensemble': ( 'metric/numeric.html#numericmetric._ensemble', @@ -390,6 +425,50 @@ 'ragas_annotator/project/naming.py'), 'ragas_annotator.project.naming.MemorableNames.generate_unique_names': ( 'project/naming.html#memorablenames.generate_unique_names', 'ragas_annotator/project/naming.py')}, + 'ragas_annotator.prompt.base': { 'ragas_annotator.prompt.base.Prompt': ( 'prompt/base.html#prompt', + 'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt.__init__': ( 'prompt/base.html#prompt.__init__', + 'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt.__str__': ( 'prompt/base.html#prompt.__str__', + 'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt._format_examples': ( 'prompt/base.html#prompt._format_examples', + 'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt._validate_instruction': ( 'prompt/base.html#prompt._validate_instruction', + 'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt.add_example': ( 'prompt/base.html#prompt.add_example', + 
'ragas_annotator/prompt/base.py'), + 'ragas_annotator.prompt.base.Prompt.format': ( 'prompt/base.html#prompt.format', + 'ragas_annotator/prompt/base.py')}, + 'ragas_annotator.prompt.dynamic_few_shot': { 'ragas_annotator.prompt.dynamic_few_shot.DynamicFewShotPrompt': ( 'prompt/dynamic_few_shot.html#dynamicfewshotprompt', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.DynamicFewShotPrompt.__init__': ( 'prompt/dynamic_few_shot.html#dynamicfewshotprompt.__init__', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.DynamicFewShotPrompt.add_example': ( 'prompt/dynamic_few_shot.html#dynamicfewshotprompt.add_example', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.DynamicFewShotPrompt.format': ( 'prompt/dynamic_few_shot.html#dynamicfewshotprompt.format', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.DynamicFewShotPrompt.from_prompt': ( 'prompt/dynamic_few_shot.html#dynamicfewshotprompt.from_prompt', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.ExampleStore': ( 'prompt/dynamic_few_shot.html#examplestore', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.ExampleStore.add_example': ( 'prompt/dynamic_few_shot.html#examplestore.add_example', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.ExampleStore.get_examples': ( 'prompt/dynamic_few_shot.html#examplestore.get_examples', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore.__init__': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore.__init__', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore.__len__': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore.__len__', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore._get_embedding': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore._get_embedding', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore._get_nearest_examples': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore._get_nearest_examples', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore.add_example': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore.add_example', + 'ragas_annotator/prompt/dynamic_few_shot.py'), + 'ragas_annotator.prompt.dynamic_few_shot.InMemoryExampleStore.get_examples': ( 'prompt/dynamic_few_shot.html#inmemoryexamplestore.get_examples', + 'ragas_annotator/prompt/dynamic_few_shot.py')}, 'ragas_annotator.tracing.langfuse': { 'ragas_annotator.tracing.langfuse.LangfuseTrace': ( 'tracing/langfuse.html#langfusetrace', 'ragas_annotator/tracing/langfuse.py'), 'ragas_annotator.tracing.langfuse.LangfuseTrace.__init__': ( 'tracing/langfuse.html#langfusetrace.__init__', diff --git a/ragas_annotator/embedding/__init__.py b/ragas_annotator/embedding/__init__.py new file mode 100644 index 0000000..eb0ef1a --- /dev/null +++ b/ragas_annotator/embedding/__init__.py @@ -0,0 +1,4 @@ +from ragas_annotator.embedding.base import BaseEmbedding +from 
ragas_annotator.embedding.base import ragas_embedding + +__all__ = ['ragas_embedding','BaseEmbedding'] \ No newline at end of file diff --git a/ragas_annotator/embedding/base.py b/ragas_annotator/embedding/base.py new file mode 100644 index 0000000..2d8f5f3 --- /dev/null +++ b/ragas_annotator/embedding/base.py @@ -0,0 +1,67 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/embedding/base.ipynb. + +# %% auto 0 +__all__ = ['BaseEmbedding', 'OpenAIEmbeddings', 'ragas_embedding'] + +# %% ../../nbs/embedding/base.ipynb 2 +import typing as t +from abc import ABC, abstractmethod + +#TODO: Add support for other providers like HuggingFace, Cohere, etc. +#TODO: handle async calls properly and ensure that the client supports async if needed. + +class BaseEmbedding(ABC): + @abstractmethod + def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]: + pass + + @abstractmethod + async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]: + pass + + @abstractmethod + def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]: + pass + + @abstractmethod + async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]: + pass + + +class OpenAIEmbeddings(BaseEmbedding): + def __init__(self, client: t.Any, model: str): + self.client = client + self.model = model + + def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]: + return self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding + + async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]: + response = await self.client.embeddings.create(input=text, model=self.model, **kwargs) + return response.data[0].embedding + + def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]: + embeddings = self.client.embeddings.create(input=documents, model=self.model, **kwargs) + return [embedding.embedding for embedding in embeddings.data] + + async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]: + embeddings = await self.client.embeddings.create(input=documents, model=self.model, **kwargs) + return [embedding.embedding for embedding in embeddings.data] + + +def ragas_embedding(provider: str, model: str, client: t.Any) -> BaseEmbedding: + """ + Factory function to create an embedding instance based on the provider. + + Args: + provider (str): The name of the embedding provider (e.g., "openai"). + model (str): The model name to use for embeddings. + client (Any): A pre-configured client instance for the provider. + + Returns: + BaseEmbedding: An instance of the specified embedding provider. + """ + if provider.lower() == "openai": + return OpenAIEmbeddings(client=client, model=model) + + raise ValueError(f"Unsupported provider: {provider}") diff --git a/ragas_annotator/llm/__init__.py b/ragas_annotator/llm/__init__.py new file mode 100644 index 0000000..cea67d0 --- /dev/null +++ b/ragas_annotator/llm/__init__.py @@ -0,0 +1,3 @@ +from ragas_annotator.llm.llm import RagasLLM, ragas_llm + +__all__ = ["RagasLLM", "ragas_llm"] \ No newline at end of file diff --git a/ragas_annotator/llm/llm.py b/ragas_annotator/llm/llm.py new file mode 100644 index 0000000..f4e0086 --- /dev/null +++ b/ragas_annotator/llm/llm.py @@ -0,0 +1,145 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/llm/llm.ipynb. 
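+ +# Usage sketch (illustrative comment only, not generated code): RagasLLM below +# wraps a provider client with instructor; assuming an OpenAI client, +# +#   from openai import OpenAI +#   llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI()) +# +# Passing openai.AsyncOpenAI() instead yields an async-capable wrapper, which +# is detected automatically at initialization.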
+ +# %% auto 0 +__all__ = ['T', 'RagasLLM', 'ragas_llm'] + +# %% ../../nbs/llm/llm.ipynb 2 +import typing as t +import asyncio +import inspect +import threading +from pydantic import BaseModel +import instructor + +T = t.TypeVar('T', bound=BaseModel) + +class RagasLLM: + def __init__(self, provider: str, model: str, client: t.Any, **model_args): + self.provider = provider.lower() + self.model = model + self.model_args = model_args or {} + self.client = self._initialize_client(provider, client) + # Check if client is async-capable at initialization + self.is_async = self._check_client_async() + + def _check_client_async(self) -> bool: + """Determine if the client is async-capable.""" + try: + # Check if this is an async client by checking for a coroutine method + if hasattr(self.client.chat.completions, 'create'): + return inspect.iscoroutinefunction(self.client.chat.completions.create) + return False + except (AttributeError, TypeError): + return False + + def _initialize_client(self, provider: str, client: t.Any) -> t.Any: + provider = provider.lower() + + if provider == "openai": + return instructor.from_openai(client) + elif provider == "anthropic": + return instructor.from_anthropic(client) + elif provider == "cohere": + return instructor.from_cohere(client) + elif provider == "gemini": + return instructor.from_gemini(client) + elif provider == "litellm": + return instructor.from_litellm(client) + else: + raise ValueError(f"Unsupported provider: {provider}") + + def _run_async_in_current_loop(self, coro): + """Run an async coroutine in the current event loop if possible. + + This handles Jupyter environments correctly by using a separate thread + when a running event loop is detected. + """ + try: + # Try to get the current event loop + loop = asyncio.get_event_loop() + + if loop.is_running(): + # If the loop is already running (like in Jupyter notebooks), + # we run the coroutine in a separate thread with its own event loop + result_container = {'result': None, 'exception': None} + + def run_in_thread(): + # Create a new event loop for this thread + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + # Run the coroutine in this thread's event loop + result_container['result'] = new_loop.run_until_complete(coro) + except Exception as e: + # Capture any exceptions to re-raise in the main thread + result_container['exception'] = e + finally: + # Clean up the event loop + new_loop.close() + + # Start the thread and wait for it to complete + thread = threading.Thread(target=run_in_thread) + thread.start() + thread.join() + + # Re-raise any exceptions that occurred in the thread + if result_container['exception']: + raise result_container['exception'] + + return result_container['result'] + else: + # Standard case - event loop exists but isn't running + return loop.run_until_complete(coro) + + except RuntimeError: + # If we get a runtime error about no event loop, create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(coro) + finally: + # Clean up + loop.close() + asyncio.set_event_loop(None) + + def generate(self, prompt: str, response_model: t.Type[T]) -> T: + """Generate a response using the configured LLM. + + For async clients, this will run the async method in the appropriate event loop. 
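+ +        Illustrative sketch (assumes an OpenAI client and a caller-defined +        Pydantic model; both names below are stand-ins): + +            from openai import OpenAI +            from pydantic import BaseModel + +            class Verdict(BaseModel): +                verdict: str + +            llm = ragas_llm(provider="openai", model="gpt-4o", client=OpenAI()) +            result = llm.generate("Is water wet?", response_model=Verdict)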
+ """ + messages = [{"role": "user", "content": prompt}] + + # If client is async, use the appropriate method to run it + if self.is_async: + return self._run_async_in_current_loop( + self.agenerate(prompt, response_model) + ) + else: + # Regular sync client, just call the method directly + return self.client.chat.completions.create( + model=self.model, + messages=messages, + response_model=response_model, + **self.model_args, + ) + + async def agenerate(self, prompt: str, response_model: t.Type[T]) -> T: + """Asynchronously generate a response using the configured LLM.""" + messages = [{"role": "user", "content": prompt}] + + # If client is not async, raise a helpful error + if not self.is_async: + raise TypeError( + "Cannot use agenerate() with a synchronous client. Use generate() instead." + ) + + # Regular async client, call the method directly + return await self.client.chat.completions.create( + model=self.model, + messages=messages, + response_model=response_model, + **self.model_args, + ) + +def ragas_llm(provider: str, model: str, client: t.Any, **model_args) -> RagasLLM: + return RagasLLM(provider=provider, client=client, model=model, **model_args) diff --git a/ragas_annotator/metric/__init__.py b/ragas_annotator/metric/__init__.py index 57a31d3..4733fc4 100644 --- a/ragas_annotator/metric/__init__.py +++ b/ragas_annotator/metric/__init__.py @@ -1,12 +1,10 @@ from ragas_annotator.metric.result import MetricResult -from ragas_annotator.metric.llm import LLM from ragas_annotator.metric.base import Metric from ragas_annotator.metric.discrete import DiscreteMetric from ragas_annotator.metric.numeric import NumericMetric from ragas_annotator.metric.ranking import RankingMetric __all__ = ['MetricResult', - 'LLM', 'Metric', 'DiscreteMetric', 'NumericMetric', diff --git a/ragas_annotator/metric/base.py b/ragas_annotator/metric/base.py index d37b9c5..78d1104 100644 --- a/ragas_annotator/metric/base.py +++ b/ragas_annotator/metric/base.py @@ -11,19 +11,32 @@ from dataclasses import dataclass, field from pydantic import BaseModel import typing as t +import json +from tqdm import tqdm + +from ..prompt.base import Prompt +from ..embedding.base import BaseEmbedding from . import MetricResult -from . 
import LLM +from ..llm import RagasLLM +from ..project.core import Project +from ..model.notion_model import NotionModel +from ..prompt.dynamic_few_shot import DynamicFewShotPrompt + @dataclass class Metric(ABC): """Base class for all metrics in the LLM evaluation library.""" name: str - prompt: str - llm: LLM + prompt: str | Prompt + llm: RagasLLM _response_models: t.Dict[bool, t.Type[BaseModel]] = field( default_factory=dict, init=False, repr=False ) + def __post_init__(self): + if isinstance(self.prompt,str): + self.prompt = Prompt(self.prompt) + @abstractmethod def _get_response_model(self, with_reasoning: bool) -> t.Type[BaseModel]: """Get the appropriate response model.""" @@ -36,22 +49,32 @@ def _ensemble(self, results: t.List[MetricResult]) -> MetricResult: def score(self, reasoning: bool = True, n: int = 1, **kwargs) -> t.Any: responses = [] + traces = {} + traces["input"] = kwargs prompt_input = self.prompt.format(**kwargs) for _ in range(n): response = self.llm.generate(prompt_input, response_model = self._get_response_model(reasoning)) + traces['output'] = response.model_dump() response = MetricResult(**response.model_dump()) responses.append(response) - return self._ensemble(responses) + results = self._ensemble(responses) + results.traces = traces + return results async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs) -> MetricResult: responses = [] # Added missing initialization + traces = {} + traces["input"] = kwargs prompt_input = self.prompt.format(**kwargs) for _ in range(n): response = await self.llm.agenerate(prompt_input, response_model = self._get_response_model(reasoning)) + traces['output'] = response.model_dump() response = MetricResult(**response.model_dump()) # Fixed missing parentheses responses.append(response) - return self._ensemble(responses) + results = self._ensemble(responses) + results.traces = traces + return results def batch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool = True, n: int = 1) -> t.List[t.Any]: return [self.score(reasoning, n, **input_dict) for input_dict in inputs] @@ -64,3 +87,30 @@ async def abatch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool # Run all tasks concurrently and return results return await asyncio.gather(*async_tasks) + + def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: BaseEmbedding,method: t.Dict[str, t.Any]): + + assert isinstance(self.prompt, Prompt) + self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model) + datasets = [] + for experiment_name in experiment_names: + experiment_data = project.get_experiment(experiment_name,model) + experiment_data.load() + datasets.append(experiment_data) + + total_items = sum([len(dataset) for dataset in datasets]) + with tqdm(total=total_items, desc="Processing examples") as pbar: + for dataset in datasets: + for row in dataset: + if hasattr(row, f'{self.name}_traces'): + traces = json.loads(getattr(row, f'{self.name}_traces')) + if traces: + self.prompt.add_example(traces['input'],traces['output']) + pbar.update(1) + + + + + + + diff --git a/ragas_annotator/metric/decorator.py b/ragas_annotator/metric/decorator.py index 016773a..aa1b76e 100644 --- a/ragas_annotator/metric/decorator.py +++ b/ragas_annotator/metric/decorator.py @@ -11,6 +11,8 @@ import asyncio from dataclasses import dataclass from . 
import MetricResult +from ..llm import RagasLLM +from ..prompt.base import Prompt @@ -25,7 +27,7 @@ def create_metric_decorator(metric_class): Returns: A decorator factory function for the specified metric type """ - def decorator_factory(llm, prompt, name: t.Optional[str] = None, **metric_params): + def decorator_factory(llm:RagasLLM, prompt: t.Union[str, Prompt], name: t.Optional[str] = None, **metric_params): """ Creates a decorator that wraps a function into a metric instance. @@ -44,17 +46,9 @@ def decorator(func): metric_name = name or func.__name__ is_async = inspect.iscoroutinefunction(func) + #TODO: Move to dataclass type implementation @dataclass class CustomMetric(metric_class): - def _extract_result(self, result, reasoning: bool): - """Extract score and reason from the result.""" - if isinstance(result, tuple) and len(result) == 2: - score, reason = result - else: - score, reason = result, None - - # Use "result" instead of "score" for the new MetricResult implementation - return MetricResult(result=score, reason=reason if reasoning else None) def _run_sync_in_async(self, func, *args, **kwargs): """Run a synchronous function in an async context.""" @@ -81,7 +75,7 @@ def _execute_metric(self, is_async_execution, reasoning, **kwargs): # Sync function implementation result = func(self.llm, self.prompt, **kwargs) - return self._extract_result(result, reasoning) + return result except Exception as e: # Handle errors gracefully error_msg = f"Error executing metric {self.name}: {str(e)}" @@ -100,7 +94,7 @@ async def ascore(self, reasoning: bool = True, n: int = 1, **kwargs): else: # For sync functions, run normally result = self._run_sync_in_async(func, self.llm, self.prompt, **kwargs) - return self._extract_result(result, reasoning) + return result # Create the metric instance with all parameters metric_instance = CustomMetric( diff --git a/ragas_annotator/metric/llm.py b/ragas_annotator/metric/llm.py deleted file mode 100644 index c602e53..0000000 --- a/ragas_annotator/metric/llm.py +++ /dev/null @@ -1,35 +0,0 @@ -# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/metric/llm.ipynb. - -# %% auto 0 -__all__ = ['LLM'] - -# %% ../../nbs/metric/llm.ipynb 1 -import openai -import instructor -from dataclasses import dataclass - -@dataclass -class LLM: - - def __post_init__(self): - self.aclient = instructor.from_openai(openai.AsyncOpenAI()) - self.client = instructor.from_openai(openai.OpenAI()) - - - def generate(self,prompt,response_model): - return self.client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - response_model=response_model, - ) - - async def agenerate(self,prompt,response_model): - return await self.aclient.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "user", "content": prompt}, - ], - response_model=response_model, - ) diff --git a/ragas_annotator/metric/result.py b/ragas_annotator/metric/result.py index c4636c7..a50e97e 100644 --- a/ragas_annotator/metric/result.py +++ b/ragas_annotator/metric/result.py @@ -24,9 +24,14 @@ class MetricResult: - RankingMetrics (list results) """ - def __init__(self, result: t.Any, reason: t.Optional[str] = None): + def __init__(self, result: t.Any, reason: t.Optional[str] = None, traces: t.Optional[t.Dict[str, t.Any]] = None): + if traces is not None: + invalid_keys = [key for key in traces.keys() if key not in {"input", "output"}] + if invalid_keys: + raise ValueError(f"Invalid keys in traces: {invalid_keys}. 
Allowed keys are 'input' and 'output'.") self._result = result self.reason = reason + self.traces = traces def __repr__(self): return repr(self._result) diff --git a/ragas_annotator/prompt/__init__.py b/ragas_annotator/prompt/__init__.py new file mode 100644 index 0000000..a0dffbc --- /dev/null +++ b/ragas_annotator/prompt/__init__.py @@ -0,0 +1,5 @@ +from ragas_annotator.prompt.base import Prompt +from ragas_annotator.prompt.dynamic_few_shot import DynamicFewShotPrompt + + +__all__ = ['Prompt', 'DynamicFewShotPrompt'] \ No newline at end of file diff --git a/ragas_annotator/prompt/base.py b/ragas_annotator/prompt/base.py new file mode 100644 index 0000000..f691d98 --- /dev/null +++ b/ragas_annotator/prompt/base.py @@ -0,0 +1,92 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/prompt/base.ipynb. + +# %% auto 0 +__all__ = ['Prompt'] + +# %% ../../nbs/prompt/base.ipynb 2 +import typing as t +import re + +class Prompt: + def __init__( + self, + instruction: str, + examples: t.Optional[t.List[t.Tuple[t.Dict, t.Dict]]] = None + ): + """ + Create a simple prompt object. + + Parameters: + ----------- + instruction : str + The prompt instruction template with placeholders like {response}, {expected_answer} + examples : Optional[List[Tuple[Dict, Dict]]] + List of (input_dict, output_dict) pairs for few-shot learning + """ + self.instruction = instruction + self.examples = [] + + # Validate the instruction + self._validate_instruction() + + # Add examples if provided + if examples: + for inputs, output in examples: + self.add_example(inputs, output) + + def _validate_instruction(self): + """Ensure the instruction contains at least one placeholder.""" + if not re.findall(r"\{(\w+)\}", self.instruction): + raise ValueError("Instruction must contain at least one placeholder like {response}") + + def format(self, **kwargs) -> str: + """Format the prompt with the provided variables.""" + + prompt_parts = [] + prompt_parts.append(self.instruction.format(**kwargs)) + prompt_parts.append(self._format_examples()) + + # Combine all parts + return "\n\n".join(prompt_parts) + + def _format_examples(self) -> str: + + # Add examples in a simple format + examples = [] + if self.examples: + examples.append("Examples:") + for i, (inputs, output) in enumerate(self.examples, 1): + example_input = "\n".join([f"{k}: {v}" for k, v in inputs.items()]) + example_output = "\n".join([f"{k}: {v}" for k, v in output.items()]) + + examples.append(f"Example {i}:\nInput:\n{example_input}\nOutput:\n{example_output}") + + return "\n\n".join(examples) if examples else "" + + + def add_example(self, inputs: t.Dict, output: t.Dict) -> None: + """ + Add an example to the prompt. 
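+ +        Illustrative sketch (hypothetical values): + +            prompt = Prompt("Evaluate if {response} is correct") +            prompt.add_example({"response": "Paris"}, {"score": "pass"})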
+ + Parameters: + ----------- + inputs : Dict + Dictionary of input values + output : Dict + Dictionary of output values + + Raises: + ------- + TypeError + If inputs or output is not a dictionary + """ + if not isinstance(inputs, dict): + raise TypeError(f"Expected inputs to be dict, got {type(inputs).__name__}") + if not isinstance(output, dict): + raise TypeError(f"Expected output to be dict, got {type(output).__name__}") + + self.examples.append((inputs, output)) + + def __str__(self) -> str: + """String representation showing the instruction.""" + return f"Prompt(instruction='{self.instruction}',\n examples={self.examples})" diff --git a/ragas_annotator/prompt/dynamic_few_shot.py b/ragas_annotator/prompt/dynamic_few_shot.py new file mode 100644 index 0000000..56e2f7a --- /dev/null +++ b/ragas_annotator/prompt/dynamic_few_shot.py @@ -0,0 +1,200 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/prompt/dynamic_few_shot.ipynb. + +# %% auto 0 +__all__ = ['ExampleStore', 'InMemoryExampleStore', 'DynamicFewShotPrompt'] + +# %% ../../nbs/prompt/dynamic_few_shot.ipynb 2 +import typing as t +import numpy as np +from abc import ABC, abstractmethod + +from .base import Prompt +from ..embedding import BaseEmbedding + +class ExampleStore(ABC): + @abstractmethod + def get_examples( + self, data: t.Dict, top_k: int = 5 + ) -> t.List[t.Tuple[t.Dict, t.Dict]]: + """Get top_k most similar examples to data.""" + pass + + @abstractmethod + def add_example(self, inputs: t.Dict, output: t.Dict) -> None: + """Add an example to the store.""" + pass + + +class InMemoryExampleStore(ExampleStore): + def __init__(self, embedding_model=None): + """ + Initialize an in-memory example store with optional embedding model. + + Args: + embedding_model: Model used to generate embeddings (OpenAI or similar) + """ + self.embedding_model = embedding_model + self._examples: t.List[t.Tuple[t.Dict, t.Dict]] = [] + self._embeddings_list: t.List[t.List[float]] = [] + + def _get_embedding(self, data: t.Dict) -> t.List[float]: + """Convert input dict to an embedding vector.""" + if self.embedding_model is None: + return [] + + # Serialize the dictionary to text + text = "\n".join([f"{k}: {v}" for k, v in data.items()]) + return self.embedding_model.embed_text(text) + + def add_example(self, inputs: t.Dict, output: t.Dict) -> None: + """Add an example to the store with its embedding.""" + if not isinstance(inputs, dict): + raise TypeError(f"Expected inputs to be dict, got {type(inputs).__name__}") + if not isinstance(output, dict): + raise TypeError(f"Expected output to be dict, got {type(output).__name__}") + + self._examples.append((inputs, output)) + + if self.embedding_model: + embedding = self._get_embedding(inputs) + self._embeddings_list.append(embedding) + + def get_examples( + self, data: t.Dict, top_k: int = 5, threshold: float = 0.7 + ) -> t.List[t.Tuple[t.Dict, t.Dict]]: + """Get examples most similar to the input data.""" + if not self._examples: + return [] + + if not self.embedding_model or not self._embeddings_list: + # If no embedding model, return the most recent examples + return self._examples[-top_k:] + + # Get embedding for the query + query_embedding = self._get_embedding(data) + + # Find most similar examples + indices = self._get_nearest_examples( + query_embedding, self._embeddings_list, top_k, threshold + ) + + # Return the examples at those indices + return [self._examples[i] for i in indices] + + def _get_nearest_examples( + self, + query_embedding: t.List[float], + embeddings: 
t.List[t.List[float]], + top_k: int = 3, + threshold: float = 0.7, + ) -> t.List[int]: + """Find indices of the nearest examples based on cosine similarity.""" + # Convert to numpy arrays for efficient computation + query = np.array(query_embedding) + embed_matrix = np.array(embeddings) + + # Calculate cosine similarity + similarities = np.dot(embed_matrix, query) / ( + np.linalg.norm(embed_matrix, axis=1) * np.linalg.norm(query) + 1e-8 + ) + + # Get indices of similarities above threshold + valid_indices = np.where(similarities >= threshold)[0] + + # Sort by similarity and get top-k + if len(valid_indices) > 0: + top_indices = valid_indices[np.argsort(similarities[valid_indices])[-top_k:]] + # Convert numpy indices to Python ints + return [int(idx) for idx in top_indices] + + # If no examples meet threshold, return most recent examples + return list(range(max(0, len(embeddings) - top_k), len(embeddings))) + + def __len__(self): + return len(self._examples) + + + + +class DynamicFewShotPrompt(Prompt): + + def __init__( + self, + prompt: Prompt, + example_store: InMemoryExampleStore, + num_examples: int = 3 + ): + + self.example_store = example_store + super().__init__(prompt.instruction, prompt.examples) + self.num_examples = num_examples + + # Note: the parent constructor routes prompt.examples through the + # overridden add_example, which already adds each example to the store + # once; re-adding them here would duplicate store entries and embeddings. + + def format(self, **kwargs) -> str: + """Format the prompt with dynamically retrieved examples.""" + prompt_parts = [] + + # Add instruction with variables filled in + prompt_parts.append(self.instruction.format(**kwargs)) + + # Get dynamic examples if we have a store and inputs + dynamic_examples = [] + if self.example_store and kwargs: + dynamic_examples = self.example_store.get_examples(kwargs, self.num_examples) + + # Add examples in a simple format + if dynamic_examples: + prompt_parts.append("Examples:") + for i, (inputs, output) in enumerate(dynamic_examples, 1): + example_input = "\n".join([f"{k}: {v}" for k, v in inputs.items()]) + example_output = "\n".join([f"{k}: {v}" for k, v in output.items()]) + + prompt_parts.append(f"Example {i}:\nInput:\n{example_input}\nOutput:\n{example_output}") + + + + # Combine all parts + return "\n\n".join(prompt_parts) + + def add_example(self, inputs: t.Dict, output: t.Dict) -> None: + """ + Add an example to both the prompt and the example store. 
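+ +        Illustrative sketch (assumes a prompt already wrapped via from_prompt; +        ``embedder`` is a stand-in embedding model): + +            few_shot = DynamicFewShotPrompt.from_prompt(prompt, embedding_model=embedder) +            few_shot.add_example({"response": "2 + 2 = 4"}, {"score": "pass"})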
+ + Parameters: + ----------- + inputs : Dict + Dictionary of input values + output : Dict + Dictionary of output values + + Raises: + ------- + TypeError + If inputs or output is not a dictionary + """ + if (inputs, output) not in self.examples: + self.examples.append((inputs, output)) + + # Add to example store + if isinstance(self.example_store, ExampleStore) and (inputs, output) not in self.example_store._examples: + self.example_store.add_example(inputs, output) + + @classmethod + def from_prompt( + cls, + prompt: Prompt, + embedding_model: BaseEmbedding, + num_examples: int = 3 + ) -> "DynamicFewShotPrompt": + """Create a DynamicFewShotPrompt from a Prompt object.""" + example_store = InMemoryExampleStore(embedding_model=embedding_model) + + few_shot_prompt = cls( + prompt=prompt, + example_store=example_store, + num_examples=num_examples + ) + + return few_shot_prompt diff --git a/settings.ini b/settings.ini index 0215d37..3bd13b7 100644 --- a/settings.ini +++ b/settings.ini @@ -38,7 +38,7 @@ status = 3 user = explodinggradients ### Dependencies ### -requirements = notion-client fastcore tqdm langfuse openai instructor pydantic +requirements = notion-client fastcore tqdm langfuse instructor pydantic numpy dev_requirements = pytest # console_scripts = # conda_user =
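The changes above thread evaluation traces from MetricResult through Metric.score into Metric.train, which replays the '<metric name>_traces' column of past experiment rows into a DynamicFewShotPrompt. A minimal usage sketch of that loop, assuming an existing project, a NotionModel subclass, and experiments already scored with this metric (all names below are stand-ins):

    from openai import OpenAI
    from ragas_annotator.embedding import ragas_embedding

    embedder = ragas_embedding(provider="openai", model="text-embedding-3-small", client=OpenAI())

    # Upgrades my_metric.prompt in place to a DynamicFewShotPrompt and populates
    # it from the '<metric name>_traces' field of each experiment row.
    my_metric.train(
        project=project,
        experiment_names=["experiment-1", "experiment-2"],
        model=MyNotionModel,
        embedding_model=embedder,
        method={},  # accepted by the signature; not consumed in this diff
    )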