From aabb0151595a930d966e40e34f5e0b9258f96f97 Mon Sep 17 00:00:00 2001 From: TejasGupta-27 Date: Sun, 27 Oct 2024 14:05:00 +0530 Subject: [PATCH] RAG --- RAG/README | 2 + .../__pycache__/client.cpython-312.pyc | Bin 0 -> 1754 bytes RAG/VectorDB/client.py | 26 ++++ .../database_handler.cpython-312.pyc | Bin 0 -> 3661 bytes RAG/__pycache__/preprocessing.cpython-312.pyc | Bin 0 -> 1873 bytes RAG/__pycache__/rag_pipeline.cpython-312.pyc | Bin 0 -> 4152 bytes RAG/database_handler.py | 112 +++++++++++++++ RAG/preprocessing.py | 29 ++++ RAG/rag_pipeline.py | 64 +++++++++ RAG/system_prompt.py | 58 ++++++++ RAG/ui.py | 134 ++++++++++++++++++ 11 files changed, 425 insertions(+) create mode 100644 RAG/README create mode 100644 RAG/VectorDB/__pycache__/client.cpython-312.pyc create mode 100644 RAG/VectorDB/client.py create mode 100644 RAG/__pycache__/database_handler.cpython-312.pyc create mode 100644 RAG/__pycache__/preprocessing.cpython-312.pyc create mode 100644 RAG/__pycache__/rag_pipeline.cpython-312.pyc create mode 100644 RAG/database_handler.py create mode 100644 RAG/preprocessing.py create mode 100644 RAG/rag_pipeline.py create mode 100644 RAG/system_prompt.py create mode 100644 RAG/ui.py diff --git a/RAG/README b/RAG/README new file mode 100644 index 0000000..34131ff --- /dev/null +++ b/RAG/README @@ -0,0 +1,2 @@ +Get GEMINI API KEY from google console +For Database query you can alter the prompts according to your schema diff --git a/RAG/VectorDB/__pycache__/client.cpython-312.pyc b/RAG/VectorDB/__pycache__/client.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a71abe7e8fe9ba13d53da5f47a7b1a4ba320924 GIT binary patch literal 1754 zcmbtU&u<$=6rNe{+S{adY&YxtXmJgSs&1%bCB#WU2q;Oz0W|`Ga)?x|*0bZ(VQq(* zO;I8T<n_@yn=npXnn$&Vm=fSy;EM#FDEsI{VaamfHFtS2-(Q?R&Sn^%G zEL#eU6dJa|pIt?WsS={fUMmSk8+Zfj>Yrk~pkwa>Ef=U|l`1PnalJyEQnkWK3+#6x zu;M$N0COKX*q;Tn8Wa{xKyTrFX|ILYLN&CA^w7fQY_3eLl9_Yps>ySOiM&OorY5r! 
zCMDU)iAj>3DNd1Wapui@p+Lzr&0Dz}tVB4kR*X-se~`0G$AnVlHPgAROP)$@vufG2 z3SRsp%6N%8wBpR#C9r@B6W`;aip-0T|L4|AlsHv3_kOOKb!=_JORP}GXqI3!XPF>R zoDm-cvyGldwMO(xJ$l8x_W5CyZO9)CsBKitG8LTqeVB=3@kSPbLRum8Q)X6p zvC7Jn-3a74-%_U+w}EkkRBao$5prelf5=M^Moouy1-zKgR2!(POclt!&RV|=78SDt zytTREbappr(d+HNxrd1t?~MKB-0SJgaIr(C)tER)Z%9;8WyVYOaW(|;2z>lYFrb_C zcq6@3PcOLz`!HR0<%>tj!A4TAC-sBm>)Rh5U47%{fyv$RZ}Iox2Vr+$(v|zc{uW8O z>9Klb>`1%V(6V(c`}kc~%N}U6J3Yr?bUx+E+HnY_E;Y2#x;FZ_&(%f`wD)#;j(Rl+ z{e2QegLy9Gw0Gh9HJO0ZZemt&bGi?vTcT~L2_bZgy0x?&$2B~MZskFO(%GV|NoNT& zkUB7EDgFcP8`*6Bpf!1<&Nzx>K^bf5n==E@bCdDKd!*Xc#{rKehj3 zFVwf;j;`DE9rhX^#Y4(-V2(o=<7cSvS2T7aNjQ0`p2Mk=l!9lDV@N)~v-u=6^ouTOXy!U%=<}a~W1VOWxKQG>d(Kn26o4-3a{xOHpQ=}k;(@>%F=L(!d`>Or|&oQ1) z6RLqifI~h+0x!Bt9ST+AuFC#j@xK7AtQA7Ag0GBssKK6uzTjOcM-yL}7@Y5n6~UD$ zM3q2c#L@OWawJ+g`?{tQ`XboV6({k^JFm--5my3E9eRig2_>q8A995;B?8oVgG-M5 zlP$|7IXkKuk|HXGMf5v%d}V!oWi=zFHnQT4%ttUClxnK@Thh3nD3PkFtJ8Bc3)4kS zGRxAR1%x75T$b(-Ea}+1t*I8lieyRKl1cDg)hc7FOmNZAG~=$SmoUs|!~|gl%Z663 z>gJV5WQ#GeE>+1DoR=)Jp-Flq?M}O#tb((-t;|X$k2mt!wN(Bi3dqhHjO$#xfyzd38Ot6-i5)#>%j+T#pN~P@gmn>x3w{Vbrl@;4;x_7_2T4oq%BkZ~k%> z->ws?0(N1(L@cMe;UYDvxZ3EIx|zs^Mp(&%e6GPcu%T;>e*I%A8d!x2sa9)H>MU~4 zf}%*Pk~M>u46GX#u2JKT3aKmU4NUIUpaeReF|k3hq$mz{B+Xem6NzMtPGPN^;C;(A zuWBf2vEfW8ZVh$~a%%B=nS6dde+A2u&bI8zE8FiySe-AcQ0!0J&)pB)x^)X~FdNP z;PaXtqy*w56I-ZJRkxCUI{;Cv)+{^Jal1+EFsKxr_?ygQ7~}8qvsI#~(yT=)k~zCQ zzo5vMFJGFTmxw%lY5tNt{eJPXJY9VMgN5yFA}^8!Wi|sf%QOvL%%@gn-Pnp{D1REE zGqr|2?m|}V?MF|+e!qdwEI`*pUyYtR8pVgB`1$G9=*8x>cH-@)u}@=1iK)ZHR4egr zGuuv_IZ9kOOk8+=@sC$qiP@vX;$dR(OW{k}N?bikTsur$YbCBXvtRMyz2r-Nq8*($ z;3tkjJCeHrVq*laf(~f@1p@$04^WjsDJBVRhjT;W0D!F>YCVAXcl;*-s1G{`rvD^} z^fd4=?>r^FTP&e{LH3p46gL4tyTXp3aA!c|Q{NxLJlYKa$5)AY*D%V}3;PCk_#yyF z_Hpp-3U^V`m!s^s$)FwVJdA#32cgF4M7HdJT&Jc%?I5{FN!d%0ZQwkrOqK_LHP$G5v)bd% z1|q|b$f}p9Y|{%cA;M>B&;dZA~I&YRAU*KYsl2CpVkRuL4MjKFU4F zJ^A1zf36*$*f$;l2-Cz4ZQ(1sDRpnw_Zyxb`}pdElK3ZC(lrI8&5m=x7gof3O=rE{f$9_NW<3sC%t zX2;BiX_2bfty>XzK~hU(g^#jQOj)uxg~S1i---Uh-dg>ba1)T>;7ANyF 
z=wA6bj{63M|AFFPqp7da`M;y+!&o~sw&&XqKMo&^|ELw3Y6jbAv?)9qeK7hg_*aD6 zp>un}QRv)Z=-k2ipSD7CP2p7_!;SIH*sBd63ZEaIZ3jn=f|G~A$>+X<;AAT})#TfV fKSE=NXza=KON77W1@5imc$oX)@o7Id=Ggrot+v9@ literal 0 HcmV?d00001 diff --git a/RAG/__pycache__/preprocessing.cpython-312.pyc b/RAG/__pycache__/preprocessing.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac22df38f9f024a677d303434bf4edca87511355 GIT binary patch literal 1873 zcmZ`)UrZcD7@wJ)h5NVVpn=d<2(=GpF@UA~O^qpktgR+jQhec(dMrD8aL3*5b#^I# za>Ua#5n>}WHG%kqi4Q71)F)p}^ufdzkXp%Dk|w^`w+2j_^5i$m-3hhK=Dzu6=9}4X zf8TF@_p7Fb5VXHPzL)G#5c*3FN)5(-+na-K1u?{M2E}DhGB``(1WRAZC|Me(zE3l1 zwjy5P`)Z~#8;A!mBFIANn&Mf9I##ic4~7t5c^kz;EC4plf?y*o%tE(uT${yu(91GMdK>&yze(6Mx4MGl*`~BQ&q%NnS~-N0$7-dU1*oiQAk8vEfkV2&WfqW~F>bjcV$FEmFyYg{(WwaJSTd1JWh_y_{jU>Lt_bRf%biWo z_OG>NEtWFcJS%OuZDZ}7%I$7J>JY~AB zV<)cnz0;QCR*pNS<+>?*yfrr~Y7gg;NOC8efK0e7UJ!n6AD9Aqq}46YY-vq}%iCeC zFt{D9TZ!C_>>{kbzCgDFk!AWv^z3Hz>{j&L8rzIEZw8td=~D%kABU?7@9n`M=tw~- zVB(m9FwDqz%6gHCooF6IQA)&=2Nhch^mzi5glD~2IPAlI4xL9Xo(95>APmokP>LC4 z1WjX|#subb>Xesk>5;^-Y|~)?hfV}_a+WQqtbncv&71N3nj~_DHz9D6SHZhonG4i{ zS400Y{9qS918Y$k+zAHo3u+75)#0zZ9|SKv1=Q;A-`uW>F5OzZ_36w4eFi|)`#?QY zOWwX4fSpU#&C0iGN97UhBpH0{J_F}CG9O3eC1R3sWRe>vBuVomEl+!)-6eOT6zC4f z;&kN;i!jV9;z3aDO*=XkU8ee z0pAsHMX(&m8(`xB(DJ4M&?Ttr{!dkT3V#*CYT@Tz0s~lwYLt2&WpLRGS07!x^kenU1HS}+ifz3A!PdpZ!|>>RV_{(F%Hoyf zk=2`P)obxD8#ij-_&sv|xA5o&9eoa+kMdWIeJa)Q6N!|5LQp&7O$trg){@>Id`QL6 z?9kxl_5=MSDTm7)g8IexX%=A4w4HGb=IfS!`FvLRX@7(GNr?Ba>we(Ptiz@<)?54n z%*nq2w;K#d6~=#}@D6%y2fea`>UPkn$4D#CPqdG;h4BZdZkMV!x_d$a6QVf!PlR9t H`v?CEWoV+~ literal 0 HcmV?d00001 diff --git a/RAG/__pycache__/rag_pipeline.cpython-312.pyc b/RAG/__pycache__/rag_pipeline.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57969fe830544f7407d51b2d585ba999d42fba22 GIT binary patch literal 4152 zcma)9O>7&-6`tkp@@M(0A4>FZ)<1E=*rH=amJ%aLA=$DOr;(f3NfDwhvEmNNrI)+x z?9#H>Wy%x*VxtC<00pX|4XU2%z&`lsTZ^WL0=-aS60sAx2~b~hbL?K?0!7~}xuo3q 
z&;fRS-+S|B=Dqj5{aZ9DAt?X;$L|Z3Fhc)i1F!JC&i%a{Lcc;PQaK$}xGcx9zEAgM zeQwKFgskY!@w&ef$Od}r!E6xL1wB-evQj0S4Ob%BNF|z$y6a+BRwb5=xnsW`uOzYw z4*3vHT<>k|j(p|4o=SsO4O|DUSji3phgd$)9c``g@Qt8HJtGm5!F&B)uDX;^82Ch;YkD0w?)Lnl`Nrd(AmP$^%k85i71I`&->)y!Xd z6svv7PWuRZIl7H$VUk>)V8zC{ys7KBw<|}(x~Zr+)wHp3k;YD)Idkg!^SRj{yp(%+ z{!P#|pj5To1-w#E6mdl}w232`qZ0*Pu}W;Qq*=B}R`SA@s|rA$ZF~Qij~R>IdmMzk zr)&teXHduy9M0jd_r?v;7ClL2zt0Zf5J?6^liplXDSBS0baD25(MAp?cjqZoN4Mcht$^Qzwo{syUm4|z|M zpL6!Ox9Wr(-*wQ8m4;+*>-G?`cX&OgZ+;&C3G`kd?<=Cb@BCxzO!h91)`oGnlx43s zXua3)l2l#f$fPU#ljO8CqLq-R54Yo=M1HS>p})Rq%LNd z6Tkp$zCQlCVb!Wtlh|053!08)d!>qx$@RxG^4b0Zo3mtvU>O7)k{SOY8Q3#v9}U8& zRsmn7qU(BTp7NS%(U3O}-3qo9Rk0N-J?#1z3BmaYa|62nEU;oDkIxasunH!rU^4GH z{K=lbQ=e&3p{m#=8nmjq#$28-+b7I@2(xS|y`kteJWq&82y+H1g1{oCyp9dZTQ;FW zt){6|Wc3u~6;(~gXw=g(=Sq(cyxqMrFPvCf9TR5I)ZYzEu+WbSbja$RAqNr3uyVLS zxg3qL>&$_s*svgGQGsa<{473MWRwZxG#c#RO$=t_hsuM+!xgM*%3&Lq73=WQI&V);tz)K)qSyCKeAfnFzn8J+WHSvO|%)^-Hyu5lXr%mx;B0Ddw)9J9y-%Jy*ZZN z7+YK)TWn=3pN<(VA-S2_{=2}h0~@I)*Hcek+tp6ZtP9EJZ1b&7Ah|mBN#L=~`2Mx? 
z?f5fSgidIDb>@@M-p%;qYn66!5hbK`mev-KJe`4(sOG=SD(3RzgNFf|Aq5ecxOv` zu6Z1)&DoQGvHm*!@lV^cXFrdhyEVId{^qOg#7rxGuDJl!S3>OF#kUvVx7G@`g3ZOx zgxPQJ1yK0dw={mviog~8d20WO5#i&fBPUWqz87R*8X;!2Zotk$XKTPNZ>3_CP#y}w zK8Gv!19o?v)eb6hDt|-hhce#b&!V)rNZ5Oj1XN!$Ht@CZPimN~xZnm}W#kC9YPxNu zdDoW+i;PrsY0?6jV%!Yf$rEfo&?C0k0hpPQ44WP%*d`jjIIs$fOvA$C@L5ko)kJr` z^Xgy4QX8?|>#^O}=C94%Qd+Uy?bvMdR3|*#T)YoF5SUoqE6f|+8~df34?P0ne2%Z& z12vC{wO-I|@G9pZ00Q5KJisA4fLqGU`~Rmnf)}F8eZU)#l*5C**X!>hs(i!m@F&pu z{Vp{0A<+??K>u|Z2nsCb10A?*Bjkjj$AY@#_#LSZTn#@9>Q@u+9T1Qzy1M%vzYAqz zBWyEQcT2eslHLs1L)be3fI{R^r)Wg&K6DPia*Nclfi$nD2AR&DWUu9r69rofE!O`q zf2pb%s%#S34Pq=3%XSIN1&DLzGGt$}Ygn>v%9@eaYburr)*(YuWk`6IO`-xqfPiX= z0QgwItJw0grt7j+sS@)dAPY##mlTNgmTVRPZ48#wF)8?>V%Rc^k04uC7{IuYA~Von zSawX3US*~az%(RDT~%RYSEFMT5tcF3Bd=SS^ycJxN|IjlBx?pcj>1yMV^aN(k~9Z0 z;jD(Mld{CL+d^|luId;lWTv+U2L+7)!62NgpsiJf1QcgZn4mI%2$Qm6)`)B_gX+eL zU1EnS*DPDMY%PC*T^Kt#lP?#wi#}B-%o>O^R%EMIEMnk>jo5<_ zoHr`~v2Yt~>T8xGrX$^KHK9XNxiv|pqKo(xGEf7PZU)N`)+SE@8x;WwwW{mlRDjp8 zX#~(q%~h~vDFByZm1uzF0)u?QJRlwGMv{HdCoC$^@BsQOn!wC%7{KM=wklPd#@rbd zt2y^zG?0gzVu0<&BKA^bm{o_EEPsy)atFn>XwbV}i@6AkQiIZsSMZO}JC;XEP6H_u zW<7vv09~EL*hb=s^~4kHM7sI%T_4(hoWBch%fG~$!C5&z+w{Ij5 zt|t$+4n5yazOa#;TTjkyBp23`3+?3T=3*z7fFL)v=N=No9hare(F1F5-7L39XD^33 zsXc24*Hf8$$QRmid7+csbq{gT9q)6U#Q5sa2hsP-5rm{OU|= z?+dqzx84LLPH_tYG_jWiz};|sBfN7xymR%9c6fht{;T*X!?@NwN8TQE{ZqmhS3B8*@%*8M$Gmohb5)2%M?ef!;a>~>qzFS?8UjNpXvG@AG=8f4%AmGXlx76~g`y6`OpIiDpTQ zYs#(-e1zO9aP8zl@Of~9nvP#2KLjG?F;*R_yFABnx6#BGD0Ca`yNxnmpzxQd)JCN* z(ZX%C`|l|HPP7voy6k%|bS2ap-rElDYX&+f-W1=Bza3u2~l)Q|tsH8-ell!1$W46&P;^_BDn3!U#9V-8&?5Bmatxa-*(p{{={&CJz7r literal 0 HcmV?d00001 diff --git a/RAG/database_handler.py b/RAG/database_handler.py new file mode 100644 index 0000000..003d138 --- /dev/null +++ b/RAG/database_handler.py @@ -0,0 +1,112 @@ +# database_handler.py +import os +import sqlite3 +import google.generativeai as genai +from dotenv import load_dotenv +load_dotenv() +api_key=os.getenv("GOOGLE_API_KEY") +genai.configure(api_key=api_key) 
+model = genai.GenerativeModel('gemini-1.5-flash') + +# System prompt for the Gemini model +system_prompt = """ +I have an sqlite database with the following tables and columns: + +Table name: RatePlan +Columns: +RatePlanId INTEGER PRIMARY KEY +Name VARCHAR(255) +MonthlyFee FLOAT +CallRate FLOAT +SmsRate FLOAT +DataRate FLOAT + + +Table name: Customer +Columns: +CustomerId INTEGER PRIMARY KEY +FirstName VARCHAR(255) +LastName VARCHAR(255) +Address VARCHAR(255) +City VARCHAR(255) +State VARCHAR(255) +Country VARCHAR(255) +PostalCode VARCHAR(255) +Phone VARCHAR(255) +Email VARCHAR(255) +RatePlanId INT +ContractStart DATE +ContractEnd DATE + +Foreign Keys: +Foreign key: RatePlanId references RatePlanId(NO ACTION) + +Table name: Phone +Columns: +PhoneId INTEGER PRIMARY KEY +Brand VARCHAR(255) +Model VARCHAR(255) +OS VARCHAR(255) +Price FLOAT + +Table name: CustomerPhone +Columns: +CustomerPhoneId INTEGER PRIMARY KEY +CustomerId INT +PhoneId INT +PhoneAcquisitionDate DATE + +Foreign Keys: +Foreign key: PhoneId references PhoneId(NO ACTION) +Foreign key: CustomerId references CustomerId(NO ACTION) + +Table name: CDR +Columns: +CdrId INTEGER PRIMARY KEY +CustomerId INT +PhoneNumber VARCHAR(255) +CallDateTime DATETIME +CallType VARCHAR(255) +DurationInSeconds INT +DataUsageKb INT +SmsCount INT + +Foreign Keys: +Foreign key: CustomerId references CustomerId(NO ACTION) + +I will need you to help me generate SQL queries to get data from my database. +Please respond only with the query in simple text format. Do not provide any explanations or additional text. 
+ +If the user tries to modify the database respond with 'ERROR: cannot modify db' +""" + +# Initialize chat with system prompt +model = genai.GenerativeModel('gemini-1.5-flash') +chat = model.start_chat(history=[]) +chat.send_message(system_prompt) + +def generate_sql_query(prompt): + response = chat.send_message(prompt) + sql_query = response.text.strip() + if(sql_query=="ERROR: cannot modify db"):return "ERROR: cannot modify db" + sql_query = sql_query.replace('```sql', '').replace('```', '').strip() + print(sql_query) + return sql_query + +def fetch_data_from_db(sql_query, db_path='/media/tejas/b25dc664-2aec-424c-8f6c-f895bbec7e5d/Ericsson_RAG/call_db.sqlite'): + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + + cursor.execute(sql_query) + results = cursor.fetchall() + columns = [description[0] for description in cursor.description] + conn.close() + return columns, results + +def format_results_as_table(columns, results): + table = [columns] + table.extend(results) + return table + + diff --git a/RAG/preprocessing.py b/RAG/preprocessing.py new file mode 100644 index 0000000..6eff7d1 --- /dev/null +++ b/RAG/preprocessing.py @@ -0,0 +1,29 @@ +import nltk +from nltk.tokenize import sent_tokenize +from PyPDF2 import PdfReader +from io import BytesIO +from docx import Document + +nltk.download('punkt') + +def extract_text_from_pdf(pdf_file): + reader = PdfReader(pdf_file) + text = '' + for page in reader.pages: + text += page.extract_text() + '\n' + return text + +def extract_text_from_txt(txt_file_path): + with open(txt_file_path, 'r', encoding='utf-8') as f: + text = f.read() + return text + +def extract_text_from_docx(docx_file): + doc = Document(docx_file) + paragraphs = [paragraph.text for paragraph in doc.paragraphs] + text = '\n'.join(paragraphs) + return text + +def chunk_text(text, chunk_size=5): + sentences = sent_tokenize(text) + return [' '.join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)] diff --git 
# ============================================================================
# RAG/rag_pipeline.py
# Document ingestion into Chroma and retrieval-augmented answering via Gemini.
# ============================================================================
import os
import uuid

import google.generativeai as genai
from dotenv import load_dotenv

from preprocessing import (extract_text_from_pdf, extract_text_from_docx,
                           extract_text_from_txt, chunk_text)
# Original imported get_chroma_client twice; collapsed to one import.
from VectorDB.client import get_chroma_client, get_or_create_collections

load_dotenv()
api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=api_key)
model = genai.GenerativeModel('gemini-1.5-flash')
chat = model.start_chat(history=[])

# Dispatch table: file extension (lower-case) -> text extractor.
_EXTRACTORS = {
    '.pdf': extract_text_from_pdf,
    '.docx': extract_text_from_docx,
    '.txt': extract_text_from_txt,
}


def store_document_embeddings(document_path, collection_name='document', chunk_size=5, model_name='all-mpnet-base-v2'):
    """Extract, chunk, and store *document_path* in a Chroma collection.

    Raises ValueError for unsupported file types. The collection's embedding
    function is configured by get_or_create_collections; the original also
    built an unused SentenceTransformer here (an expensive no-op) — removed.
    """
    client = get_chroma_client()
    collection = get_or_create_collections(client, collection_name, model_name)

    _, file_extension = os.path.splitext(document_path)
    # .lower() so .PDF / .TXT etc. are accepted too (the original rejected them).
    extractor = _EXTRACTORS.get(file_extension.lower())
    if extractor is None:
        raise ValueError(f"Unsupported file type: {file_extension}. Supported types are pdf, docx, txt.")
    text = extractor(document_path)

    chunks = chunk_text(text, chunk_size)
    ids = [str(uuid.uuid4()) for _ in range(len(chunks))]
    metadata = [{"document": document_path, "chunk": i} for i in range(len(chunks))]
    collection.add(ids=ids, documents=chunks, metadatas=metadata)


def retrieve_documents(query, collection):
    """Return the top-5 chunks of *collection* matching *query*."""
    # The embedding function the original instantiated here was never used.
    return collection.query(query_texts=[query], n_results=5)


def generate_response(query, collection_name='document'):
    """Answer *query* from context retrieved across all Chroma collections.

    *collection_name* is kept for interface compatibility; retrieval
    currently scans every collection (as the original did).
    """
    client = get_chroma_client()
    collections = client.list_collections()
    # Query expansion: ask the model for related keywords to improve recall.
    expansion_prompt = f"Expand or transform the following query to include related keywords and phrases that will improve the chances of finding relevant text in a document database:\n\nQuery: {query}\n\nExpanded query:"
    expanded_query = model.generate_content(expansion_prompt).text.strip()

    context_parts = []
    for collection in collections:
        documents = retrieve_documents(expanded_query, collection)
        for document in documents["documents"]:
            context_parts.extend(document)
    context = "".join(context_parts)

    prompt = f"User query: {query}\n\nRelevant information:\n{context}\n\nBased on the above information, please provide a detailed response.Dont add on your own anything just stick to the info given and answer the query without any suggestions or recommendations"
    # Leftover debug print of chat.history removed.
    return chat.send_message(prompt).text


# ============================================================================
# RAG/system_prompt.py
# Builds the schema-describing system prompt directly from a SQLite database.
# ============================================================================
import sqlite3


def get_database_schema(db_path):
    """Return a human-readable description of every table, column and FK."""
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        schema_details = []
        for (table_name,) in tables:
            # table_name comes from sqlite_master itself, not user input.
            cursor.execute(f"PRAGMA table_info({table_name});")
            columns = cursor.fetchall()

            column_details = []
            for col in columns:
                # PRAGMA table_info row: (cid, name, type, notnull, dflt, pk)
                column_info = f"{col[1]} {col[2]}"
                if col[5]:  # non-zero pk flag -> part of the primary key
                    column_info += " PRIMARY KEY"
                column_details.append(column_info)

            cursor.execute(f"PRAGMA foreign_key_list({table_name});")
            foreign_keys = cursor.fetchall()
            # NOTE(review): fk[2] (the referenced table) is omitted, matching
            # the hand-written prompt format in database_handler.py — confirm
            # this ambiguity is intentional.
            foreign_key_details = [
                f"Foreign key: {fk[3]} references {fk[4]}({fk[5]})"
                for fk in foreign_keys
            ]

            schema_details.append(
                f"Table name: {table_name}\nColumns:\n" + "\n".join(column_details))
            if foreign_key_details:
                schema_details.append("Foreign Keys:\n" + "\n".join(foreign_key_details))
    finally:
        # The original leaked the connection if any query raised.
        conn.close()

    return "\n\n".join(schema_details)


def generate_system_prompt(db_path):
    """Compose the full NL->SQL system prompt from the live database schema."""
    schema_details = get_database_schema(db_path)
    system_prompt = f"""I have an sqlite database with the following tables and columns:

{schema_details}

I will need you to help me generate SQL queries to get data from my database.
Please respond only with the query. Do not provide any explanations or additional text.

If the user tries to modify the database respond with 'ERROR: cannot modify db'
"""
    return system_prompt


if __name__ == "__main__":
    # Example usage; guarded so importing this module no longer connects to
    # the database and prints at import time (it did in the original).
    db_path = 'call_db.sqlite'  # Replace with your actual database path
    print(generate_system_prompt(db_path))


# ============================================================================
# RAG/ui.py
# Streamlit chat UI: document Q&A (RAG) plus optional NL->SQL database mode.
# ============================================================================
import streamlit as st
from rag_pipeline import generate_response, store_document_embeddings
from preprocessing import extract_text_from_pdf, chunk_text
import os
import tempfile
import re
from database_handler import generate_sql_query, fetch_data_from_db, format_results_as_table


def sanitize_collection_name(name):
    """Make *name* a valid Chroma collection name (alnum edges, <=63 chars)."""
    name = re.sub(r'[^\w-]', '', name)
    if not name:
        # The original indexed name[0] and crashed with IndexError when the
        # filename contained no word characters at all.
        return '_collection_'
    if not name[0].isalnum():
        name = '_' + name
    if not name[-1].isalnum():
        name = name + '_'
    return name[:63]


# Page title and configuration
st.set_page_config(page_title="Chat Interface", page_icon="🤖")

# Session state: chat history and the pending uploaded-file path.
if 'history' not in st.session_state:
    st.session_state.history = []
if 'uploaded_file' not in st.session_state:
    st.session_state.uploaded_file = None
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the conversation so far.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Sidebar: file upload + DB-query toggle.
with st.sidebar:
    st.title('Ericsson Chat Interface')

    uploaded_file = st.file_uploader("Upload a File", type=["pdf", "docx", "txt"])
    if uploaded_file is not None:
        # Save the upload under data/temp so the ingestion pipeline can read it.
        temp_dir = os.path.join('data', 'temp')
        os.makedirs(temp_dir, exist_ok=True)
        temp_file_path = os.path.join(temp_dir, uploaded_file.name)
        with open(temp_file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        st.success("File uploaded successfully!")
        st.session_state.uploaded_file = temp_file_path
    db_query = st.checkbox("DB Query")

# Embed a freshly uploaded file, then delete the temporary copy.
if st.session_state.uploaded_file:
    # Derive the collection name from the stored path instead of relying on a
    # variable assigned inside the sidebar branch (the original raised
    # NameError whenever that branch was skipped with a file still pending).
    collection_name = sanitize_collection_name(
        os.path.basename(st.session_state.uploaded_file))
    store_document_embeddings(st.session_state.uploaded_file,
                              collection_name=collection_name, chunk_size=5)
    os.remove(st.session_state.uploaded_file)
    st.session_state.uploaded_file = None

# Container for the conversation
conversation_area = st.empty()

# User input for query (lives in the sidebar)
with st.sidebar:
    query = st.chat_input('Enter your query:', key='query_input')

# Handle user query submission
if db_query and query:
    query_sql = generate_sql_query(query)
    with st.chat_message("user"):
        st.markdown(query)
    st.session_state.messages.append({'role': 'user', 'content': query})

    if query_sql == "ERROR: cannot modify db":
        with st.chat_message("assistant"):
            st.markdown(query_sql)
        st.session_state.messages.append({"role": "assistant", "content": query_sql})
    else:
        try:
            columns, results = fetch_data_from_db(query_sql)
        except Exception as exc:
            # Bad LLM-generated SQL crashed the whole app in the original;
            # surface the failure as a chat response instead.
            with st.chat_message("assistant"):
                st.error(f"Query failed: {exc}")
        else:
            table = format_results_as_table(columns, results)
            with st.chat_message("assistant"):
                st.markdown('Table')
                st.table(table)

elif query:
    with st.chat_message("user"):
        st.markdown(query)
    st.session_state.messages.append({'role': 'user', 'content': query})
    response = generate_response(query)
    with st.chat_message("assistant"):
        st.markdown(response)
    st.session_state.messages.append({"role": "assistant", "content": response})