|
106 | 106 | "source": [ |
107 | 107 | "import os\n", |
108 | 108 | "import openai\n", |
| 109 | + "\n", |
109 | 110 | "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n", |
110 | 111 | "\n", |
111 | 112 | "completion = openai.ChatCompletion.create(\n", |
112 | | - " model=\"gpt-3.5-turbo\",\n", |
113 | | - " messages=[\n", |
114 | | - " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", |
115 | | - " ]\n", |
| 113 | + " model=\"gpt-3.5-turbo\",\n", |
| 114 | + " messages=[\n", |
| 115 | + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", |
| 116 | + " ],\n", |
116 | 117 | ")\n", |
117 | 118 | "\n", |
118 | | - "print(completion.choices[0].message)\n" |
| 119 | + "print(completion.choices[0].message)" |
119 | 120 | ] |
120 | 121 | }, |
121 | 122 | { |
|
125 | 126 | "metadata": {}, |
126 | 127 | "outputs": [], |
127 | 128 | "source": [ |
128 | | - "\n", |
129 | 129 | "def llm2(prompt, **kwargs):\n", |
130 | 130 | " response = openai.ChatCompletion.create(\n", |
131 | | - " model=kwargs.get(\"model\",\"gpt-3.5-turbo-16k\"),\n", |
132 | | - " messages=[{\"role\": \"system\", \"content\":prompt}],\n", |
| 131 | + " model=kwargs.get(\"model\", \"gpt-3.5-turbo-16k\"),\n", |
| 132 | + " messages=[{\"role\": \"system\", \"content\": prompt}],\n", |
133 | 133 | " temperature=kwargs.get(\"temperature\", 0),\n", |
134 | 134 | " top_p=kwargs.get(\"top_p\", 1),\n", |
135 | 135 | " frequency_penalty=kwargs.get(\"frequency_penalty\", 0.0),\n", |
|
139 | 139 | " )\n", |
140 | 140 | " return response\n", |
141 | 141 | "\n", |
| 142 | + "\n", |
142 | 143 | "def llm(prompt, **kwargs):\n", |
143 | 144 | " response = openai.Completion.create(\n", |
144 | 145 | " model=kwargs.get(\"model\", \"text-davinci-003\"),\n", |
|
375 | 376 | } |
376 | 377 | ], |
377 | 378 | "source": [ |
378 | | - "llm2([Question_generation.format(2,answer)])" |
| 379 | + "llm2([Question_generation.format(2, answer)])" |
379 | 380 | ] |
380 | 381 | }, |
381 | 382 | { |
|
1039 | 1040 | ], |
1040 | 1041 | "source": [ |
1041 | 1042 | "def get_all_facts(item):\n", |
1042 | | - " all_facts = item['context']['sentences']\n", |
| 1043 | + " all_facts = item[\"context\"][\"sentences\"]\n", |
1043 | 1044 | " all_facts = [sent for para in all_facts for sent in para]\n", |
1044 | | - " return {\"full_context\":''.join(all_facts)}\n", |
1045 | | - "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False) " |
| 1045 | + " return {\"full_context\": \"\".join(all_facts)}\n", |
| 1046 | + "\n", |
| 1047 | + "\n", |
| 1048 | + "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False)" |
1046 | 1049 | ] |
1047 | 1050 | }, |
1048 | 1051 | { |
|
1090 | 1093 | "metadata": {}, |
1091 | 1094 | "outputs": [], |
1092 | 1095 | "source": [ |
1093 | | - "i=15\n", |
1094 | | - "q,c = hotpot_qa[i]['question'],hotpot_qa[i]['full_context']" |
| 1096 | + "i = 15\n", |
| 1097 | + "q, c = hotpot_qa[i][\"question\"], hotpot_qa[i][\"full_context\"]" |
1095 | 1098 | ] |
1096 | 1099 | }, |
1097 | 1100 | { |
|
1112 | 1115 | "outputs": [], |
1113 | 1116 | "source": [ |
1114 | 1117 | "q = \"what is general relativity?\"\n", |
1115 | | - "n=2" |
| 1118 | + "n = 2" |
1116 | 1119 | ] |
1117 | 1120 | }, |
1118 | 1121 | { |
|
1123 | 1126 | "outputs": [], |
1124 | 1127 | "source": [ |
1125 | 1128 | "import wikipediaapi\n", |
| 1129 | + "\n", |
1126 | 1130 | "wiki_wiki = wikipediaapi.Wikipedia(\n", |
1127 | | - " language='en',\n", |
1128 | | - " extract_format=wikipediaapi.ExtractFormat.WIKI\n", |
| 1131 | + " language=\"en\", extract_format=wikipediaapi.ExtractFormat.WIKI\n", |
1129 | 1132 | ")\n", |
1130 | 1133 | "\n", |
1131 | 1134 | "p_wiki = wiki_wiki.page(\"Black hole\")\n", |
1132 | 1135 | "\n", |
| 1136 | + "\n", |
1133 | 1137 | "def get_page_section(page, section):\n", |
1134 | 1138 | " all_text = \"\"\n", |
1135 | 1139 | " p_wiki = wiki_wiki.page(page)\n", |
1136 | 1140 | " sections = p_wiki.sections_by_title(section)\n", |
1137 | 1141 | " for s in sections:\n", |
1138 | 1142 | " all_text += s.full_text()\n", |
1139 | | - " return all_text\n" |
| 1143 | + " return all_text" |
1140 | 1144 | ] |
1141 | 1145 | }, |
1142 | 1146 | { |
|
1152 | 1156 | "\n", |
1153 | 1157 | "cross_encoder = CrossEncoder(\"cross-encoder/stsb-TinyBERT-L-4\")\n", |
1154 | 1158 | "\n", |
1155 | | - " \n", |
| 1159 | + "\n", |
1156 | 1160 | "def sent_tokenize(sent):\n", |
1157 | | - " return [s[:-1] if s.endswith('.') else s for s in sent.strip().split('. ')]\n", |
| 1161 | + " return [s[:-1] if s.endswith(\".\") else s for s in sent.strip().split(\". \")]\n", |
| 1162 | + "\n", |
1158 | 1163 | "\n", |
1159 | 1164 | "class SentenceAgreement:\n", |
1160 | | - " \n", |
1161 | 1165 | " def __init__(self, scoring=\"bert_score\"):\n", |
1162 | | - " \n", |
1163 | 1166 | " self.scoring = scoring\n", |
1164 | 1167 | "\n", |
1165 | | - " \n", |
1166 | 1168 | " @staticmethod\n", |
1167 | 1169 | " def bert_score(para1, para2):\n", |
1168 | | - " \n", |
1169 | 1170 | " sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n", |
1170 | 1171 | " scores = cross_encoder.predict(list(itertools.product(sentences1, sentences2)))\n", |
1171 | 1172 | " scores = scores.reshape(len(sentences1), len(sentences2))\n", |
1172 | 1173 | " return scores.max(axis=1).mean()\n", |
1173 | 1174 | "\n", |
1174 | 1175 | " @staticmethod\n", |
1175 | 1176 | " def jaccard_score(para1, para2):\n", |
1176 | | - " \n", |
1177 | 1177 | " sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n", |
1178 | 1178 | " intersect = len(np.intersect1d(sentences1, sentences2))\n", |
1179 | 1179 | " union = len(np.union1d(sentences1, sentences2))\n", |
1180 | | - " return intersect/union\n", |
1181 | | - " \n", |
1182 | | - " def evaluate(self,answers:List[List[str]]):\n", |
1183 | | - " \n", |
| 1180 | + " return intersect / union\n", |
| 1181 | + "\n", |
| 1182 | + " def evaluate(self, answers: List[List[str]]):\n", |
1184 | 1183 | " \"\"\"\n", |
1185 | 1184 | " eval nC2 combinations\n", |
1186 | 1185 | " \"\"\"\n", |
1187 | 1186 | " scores = []\n", |
1188 | | - " groups = combinations(answers,2)\n", |
| 1187 | + " groups = combinations(answers, 2)\n", |
1189 | 1188 | " for group in groups:\n", |
1190 | 1189 | " if self.scoring == \"jaccard\":\n", |
1191 | 1190 | " score = self.jaccard_score(*group)\n", |
1192 | 1191 | " elif self.scoring == \"bert_score\":\n", |
1193 | 1192 | " score = self.bert_score(*group)\n", |
1194 | 1193 | " scores.append(score)\n", |
1195 | | - " return np.mean(scores)\n", |
1196 | | - " " |
| 1194 | + " return np.mean(scores)" |
1197 | 1195 | ] |
1198 | 1196 | }, |
1199 | 1197 | { |
|
1204 | 1202 | "outputs": [], |
1205 | 1203 | "source": [ |
1206 | 1204 | "class ContextRelevacy:\n", |
1207 | | - " \n", |
1208 | | - " def __init__(self, strictness = 2, agreement_metric=\"bert_score\"):\n", |
1209 | | - " \n", |
| 1205 | + " def __init__(self, strictness=2, agreement_metric=\"bert_score\"):\n", |
1210 | 1206 | " self.strictness = strictness\n", |
1211 | 1207 | " self.sent_agreement = SentenceAgreement(agreement_metric)\n", |
1212 | | - " \n", |
1213 | | - " def score(self,question,context):\n", |
| 1208 | + "\n", |
| 1209 | + " def score(self, question, context):\n", |
1214 | 1210 | " scores = []\n", |
1215 | | - " outputs = llm(Context_relevency.format(q,c),n=self.strictness,temperature=1)\n", |
1216 | | - " outputs = [outputs['choices'][i]['text'].strip() for i in range(self.strictness)]\n", |
| 1211 | + "        outputs = llm(Context_relevency.format(question, context), n=self.strictness, temperature=1)\n", |
| 1212 | + " outputs = [\n", |
| 1213 | + " outputs[\"choices\"][i][\"text\"].strip() for i in range(self.strictness)\n", |
| 1214 | + " ]\n", |
1217 | 1215 | " context_sents = sent_tokenize(context)\n", |
1218 | 1216 | " for output in outputs:\n", |
1219 | | - " indices = [context.find(sent) for sent in sent_tokenize(output) if context.find(sent)!=-1]\n", |
1220 | | - " scores.append(len(indices)/len(context_sents))\n", |
1221 | | - " \n", |
| 1217 | + " indices = [\n", |
| 1218 | + " context.find(sent)\n", |
| 1219 | + " for sent in sent_tokenize(output)\n", |
| 1220 | + " if context.find(sent) != -1\n", |
| 1221 | + " ]\n", |
| 1222 | + " scores.append(len(indices) / len(context_sents))\n", |
| 1223 | + "\n", |
1222 | 1224 | " if self.strictness > 1:\n", |
1223 | 1225 | " agr_score = self.sent_agreement.evaluate(outputs)\n", |
1224 | 1226 | " else:\n", |
1225 | | - " agr_score =1 \n", |
1226 | | - " return agr_score * np.mean(scores)\n" |
| 1227 | + " agr_score = 1\n", |
| 1228 | + " return agr_score * np.mean(scores)" |
1227 | 1229 | ] |
1228 | 1230 | }, |
1229 | 1231 | { |
|
1234 | 1236 | "outputs": [], |
1235 | 1237 | "source": [ |
1236 | 1238 | "c = get_page_section(\"HIV/AIDS\", \"Prevention\")\n", |
1237 | | - "c = ' '.join(c.split(' ')[:500])\n", |
| 1239 | + "c = \" \".join(c.split(\" \")[:500])\n", |
1238 | 1240 | "q = \"When was the first HIV case detected?\"" |
1239 | 1241 | ] |
1240 | 1242 | }, |
|
1245 | 1247 | "metadata": {}, |
1246 | 1248 | "outputs": [], |
1247 | 1249 | "source": [ |
1248 | | - "output = llm([Context_relevency.format(q,c), Context_relevency.format(\"How to prevent AIDS?\",c)],n=n,temperature=1)" |
| 1250 | + "output = llm(\n", |
| 1251 | + " [\n", |
| 1252 | + " Context_relevency.format(q, c),\n", |
| 1253 | + " Context_relevency.format(\"How to prevent AIDS?\", c),\n", |
| 1254 | + " ],\n", |
| 1255 | + " n=n,\n", |
| 1256 | + " temperature=1,\n", |
| 1257 | + ")" |
1249 | 1258 | ] |
1250 | 1259 | }, |
1251 | 1260 | { |
|
1397 | 1406 | } |
1398 | 1407 | ], |
1399 | 1408 | "source": [ |
1400 | | - "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))" |
| 1409 | + "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))" |
1401 | 1410 | ] |
1402 | 1411 | }, |
1403 | 1412 | { |
|
1491 | 1500 | } |
1492 | 1501 | ], |
1493 | 1502 | "source": [ |
1494 | | - "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))" |
| 1503 | + "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))" |
1495 | 1504 | ] |
1496 | 1505 | }, |
1497 | 1506 | { |
|