|
48 | 48 | "from beir.datasets.data_loader import GenericDataLoader\n", |
49 | 49 | "\n", |
50 | 50 | "dataset = \"fiqa\"\n", |
51 | | - "url = \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(dataset)\n", |
| 51 | + "url = (\n", |
| 52 | + " \"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip\".format(\n", |
| 53 | + " dataset\n", |
| 54 | + " )\n", |
| 55 | + ")\n", |
52 | 56 | "data_path = util.download_and_unzip(url, \"datasets\")" |
53 | 57 | ] |
54 | 58 | }, |
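As an aside, the `GenericDataLoader` imported in this cell can read the unzipped BEIR folder directly instead of parsing the JSONL/TSV files by hand later on; a minimal sketch, assuming the standard BEIR loader API:

```python
from beir.datasets.data_loader import GenericDataLoader

# Load corpus, queries, and relevance judgments for one split of the
# downloaded dataset; data_path is the folder returned by
# util.download_and_unzip above.
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
```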
|
218 | 222 | "source": [ |
219 | 223 | "with open(os.path.join(data_path, \"corpus.jsonl\")) as f:\n", |
220 | 224 | " cs = [pd.Series(json.loads(l)) for l in f.readlines()]\n", |
221 | | - " \n", |
| 225 | + "\n", |
222 | 226 | "corpus_df = pd.DataFrame(cs)\n", |
223 | 227 | "corpus_df" |
224 | 228 | ] |
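The cell above builds the corpus frame by parsing `corpus.jsonl` line by line into `pd.Series` objects. Since the file is newline-delimited JSON, pandas can also read it in one call; an equivalent sketch:

```python
import os
import pandas as pd

# lines=True tells pandas to treat the file as JSON Lines (one record per line).
corpus_df = pd.read_json(os.path.join(data_path, "corpus.jsonl"), lines=True)
```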
|
299 | 303 | } |
300 | 304 | ], |
301 | 305 | "source": [ |
302 | | - "corpus_df = corpus_df.rename(columns={\n", |
303 | | - " \"_id\": \"corpus-id\", \"text\": \"ground_truth\"\n", |
304 | | - "})\n", |
| 306 | + "corpus_df = corpus_df.rename(columns={\"_id\": \"corpus-id\", \"text\": \"ground_truth\"})\n", |
305 | 307 | "corpus_df = corpus_df.drop(columns=[\"title\", \"metadata\"])\n", |
306 | 308 | "corpus_df[\"corpus-id\"] = corpus_df[\"corpus-id\"].astype(int)\n", |
307 | 309 | "corpus_df.head()" |
|
387 | 389 | " qs = [pd.Series(json.loads(l)) for l in f.readlines()]\n", |
388 | 390 | "\n", |
389 | 391 | "queries_df = pd.DataFrame(qs)\n", |
390 | | - "queries_df = queries_df.rename(columns={\n", |
391 | | - " \"_id\": \"query-id\", \"text\": \"question\"\n", |
392 | | - "})\n", |
| 392 | + "queries_df = queries_df.rename(columns={\"_id\": \"query-id\", \"text\": \"question\"})\n", |
393 | 393 | "queries_df = queries_df.drop(columns=[\"metadata\"])\n", |
394 | 394 | "queries_df[\"query-id\"] = queries_df[\"query-id\"].astype(int)\n", |
395 | 395 | "queries_df.head()" |
|
474 | 474 | "splits = [\"dev\", \"test\", \"train\"]\n", |
475 | 475 | "split_df = {}\n", |
476 | 476 | "for s in splits:\n", |
477 | | - " split_df[s] = pd.read_csv(\n", |
478 | | - " os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\"\n", |
479 | | - " ).drop(columns=[\"score\"])\n", |
480 | | - " \n", |
| 477 | + " split_df[s] = pd.read_csv(os.path.join(data_path, f\"qrels/{s}.tsv\"), sep=\"\\t\").drop(\n", |
| 478 | + " columns=[\"score\"]\n", |
| 479 | + " )\n", |
| 480 | + "\n", |
481 | 481 | "split_df[\"dev\"].head()" |
482 | 482 | ] |
483 | 483 | }, |
|
515 | 515 | " df = queries_df.merge(split_df[split], on=\"query-id\")\n", |
516 | 516 | " df = df.merge(corpus_df, on=\"corpus-id\")\n", |
517 | 517 | " df = df.drop(columns=[\"corpus-id\"])\n", |
518 | | - " grouped = df.groupby('query-id').apply(lambda x: pd.Series({\n", |
519 | | - " 'question': x['question'].sample().values[0],\n", |
520 | | - " 'ground_truths': x['ground_truth'].tolist()\n", |
521 | | - " }))\n", |
| 518 | + " grouped = df.groupby(\"query-id\").apply(\n", |
| 519 | + " lambda x: pd.Series(\n", |
| 520 | + " {\n", |
| 521 | + " \"question\": x[\"question\"].sample().values[0],\n", |
| 522 | + " \"ground_truths\": x[\"ground_truth\"].tolist(),\n", |
| 523 | + " }\n", |
| 524 | + " )\n", |
| 525 | + " )\n", |
522 | 526 | "\n", |
523 | 527 | " grouped = grouped.reset_index()\n", |
524 | 528 | " grouped = grouped.drop(columns=\"query-id\")\n", |
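The `groupby(...).apply(...)` in this cell collects, per `query-id`, one copy of the question and the list of relevant passages as `ground_truths`. Since every row of a given `query-id` carries the same question after the merge, the same table can be produced with named aggregations instead of a row-wise lambda; a sketch under that assumption:

```python
# "first" is equivalent to sampling one question per query-id here,
# and list-aggregation gathers all relevant passages for that query.
grouped = (
    df.groupby("query-id")
    .agg(question=("question", "first"), ground_truths=("ground_truth", list))
    .reset_index(drop=True)
)
```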
|
797 | 801 | "assert os.path.exists(path_to_ds_repo), f\"{path_to_ds_repo} doesnot exist!\"\n", |
798 | 802 | "\n", |
799 | 803 | "for s in final_split_df:\n", |
800 | | - " final_split_df[s].to_csv(\n", |
801 | | - " os.path.join(path_to_ds_repo, f\"{s}.csv\"),\n", |
802 | | - " index=False\n", |
803 | | - " )\n", |
804 | | - " \n", |
| 804 | + " final_split_df[s].to_csv(os.path.join(path_to_ds_repo, f\"{s}.csv\"), index=False)\n", |
| 805 | + "\n", |
805 | 806 | "corpus_df.to_csv(os.path.join(path_to_ds_repo, \"corpus.csv\"), index=False)" |
806 | 807 | ] |
807 | 808 | }, |
|
1009 | 1010 | "from llama_index.node_parser import SimpleNodeParser\n", |
1010 | 1011 | "from langchain.text_splitter import TokenTextSplitter\n", |
1011 | 1012 | "\n", |
1012 | | - "spliter = TokenTextSplitter(\n", |
1013 | | - " chunk_size = 100,\n", |
1014 | | - " chunk_overlap = 50\n", |
1015 | | - ")\n", |
| 1013 | + "spliter = TokenTextSplitter(chunk_size=100, chunk_overlap=50)\n", |
1016 | 1014 | "\n", |
1017 | | - "parser = SimpleNodeParser(\n", |
1018 | | - " text_splitter=spliter\n", |
1019 | | - ")\n", |
| 1015 | + "parser = SimpleNodeParser(text_splitter=spliter)\n", |
1020 | 1016 | "\n", |
1021 | | - "nodes = parser.get_nodes_from_documents(\n", |
1022 | | - " documents=docs\n", |
1023 | | - ")" |
| 1017 | + "nodes = parser.get_nodes_from_documents(documents=docs)" |
1024 | 1018 | ] |
1025 | 1019 | }, |
1026 | 1020 | { |
|
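For context, the `TokenTextSplitter` configured above chunks each document into roughly 100-token windows with a 50-token overlap before the nodes are indexed. It can also be exercised on its own to inspect chunk boundaries; a small sketch (the sample string is just an illustration):

```python
from langchain.text_splitter import TokenTextSplitter

splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=50)

# split_text returns the list of overlapping token-window chunks.
chunks = splitter.split_text("Some long document text to be chunked ...")
print(len(chunks), chunks[:1])
```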
1088 | 1082 | "source": [ |
1089 | 1083 | "# create index\n", |
1090 | 1084 | "index = GPTVectorStoreIndex.from_documents(\n", |
1091 | | - " documents=docs, \n", |
| 1085 | + " documents=docs,\n", |
1092 | 1086 | " service_context=openai_sc,\n", |
1093 | 1087 | ")\n", |
1094 | 1088 | "\n", |
1095 | 1089 | "# query with embed_model specified\n", |
1096 | | - "qe = index.as_query_engine(\n", |
1097 | | - " mode=\"embedding\", \n", |
1098 | | - " verbose=True, \n", |
1099 | | - " service_context=openai_sc\n", |
1100 | | - ")" |
| 1090 | + "qe = index.as_query_engine(mode=\"embedding\", verbose=True, service_context=openai_sc)" |
1101 | 1091 | ] |
1102 | 1092 | }, |
1103 | 1093 | { |
|
1171 | 1161 | "\n", |
1172 | 1162 | "# query with embed_model specified\n", |
1173 | 1163 | "qe = index.as_query_engine(\n", |
1174 | | - " mode=\"embedding\", \n", |
1175 | | - " verbose=True, \n", |
1176 | | - " service_context=openai_sc,\n", |
1177 | | - " use_async = False\n", |
| 1164 | + " mode=\"embedding\", verbose=True, service_context=openai_sc, use_async=False\n", |
1178 | 1165 | ")" |
1179 | 1166 | ] |
1180 | 1167 | }, |
|
1195 | 1182 | "\n", |
1196 | 1183 | "# configure retriever\n", |
1197 | 1184 | "retriever = VectorIndexRetriever(\n", |
1198 | | - " index=index, \n", |
| 1185 | + " index=index,\n", |
1199 | 1186 | " similarity_top_k=3,\n", |
1200 | 1187 | ")\n", |
1201 | 1188 | "\n", |
1202 | 1189 | "# configure response synthesizer\n", |
1203 | 1190 | "response_synthesizer = ResponseSynthesizer.from_args(\n", |
1204 | | - " node_postprocessors=[\n", |
1205 | | - " SimilarityPostprocessor(similarity_cutoff=0.7)\n", |
1206 | | - " ]\n", |
| 1191 | + " node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)]\n", |
1207 | 1192 | ")\n", |
1208 | 1193 | "\n", |
1209 | 1194 | "# assemble query engine\n", |
|
1257 | 1242 | " r = qe.query(row[\"question\"])\n", |
1258 | 1243 | " row[\"answer\"] = r.response\n", |
1259 | 1244 | " row[\"contexts\"] = [sn.node.text for sn in r.source_nodes]\n", |
1260 | | - " \n", |
| 1245 | + "\n", |
1261 | 1246 | " return row\n", |
1262 | 1247 | "\n", |
| 1248 | + "\n", |
1263 | 1249 | "# generate_response(test_ds[0])" |
1264 | 1250 | ] |
1265 | 1251 | }, |
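`generate_response` above runs the query engine on a single row and attaches the generated answer plus the retrieved context passages. Presumably it is then mapped over the evaluation split to produce the `gen_ds` evaluated below; a hypothetical usage, assuming `test_ds` is a Hugging Face `Dataset` with a `question` column:

```python
# Hypothetical: build the evaluation dataset by applying the generator row-wise.
gen_ds = test_ds.map(generate_response)
```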
|
1530 | 1516 | "from ragas.metrics import factuality, answer_relevancy, context_relevancy\n", |
1531 | 1517 | "from ragas import evaluate\n", |
1532 | 1518 | "\n", |
1533 | | - "evaluate(\n", |
1534 | | - " gen_ds, \n", |
1535 | | - " metrics=[factuality, answer_relevancy, context_relevancy]\n", |
1536 | | - ")" |
| 1519 | + "evaluate(gen_ds, metrics=[factuality, answer_relevancy, context_relevancy])" |
1537 | 1520 | ] |
1538 | 1521 | }, |
1539 | 1522 | { |
|