Skip to content

Commit 8dd4fbf

Browse files
committed
feat: TREC adhoc runs read/write
1 parent 8e541b7 commit 8dd4fbf

File tree

4 files changed

+45
-9
lines changed

4 files changed

+45
-9
lines changed

src/datamaestro_text/data/ir/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from pathlib import Path
77
from attrs import define
8-
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type
8+
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, TYPE_CHECKING
99
import random
1010
from experimaestro import Config
1111
from datamaestro.definitions import datatasks, Param, Meta
@@ -29,6 +29,9 @@
2929
AdhocAssessedTopic,
3030
)
3131

32+
#: An adhoc run dictionary (query id -> doc id -> score)
33+
AdhocRunDict = dict[str, dict[str, float]]
34+
3235

3336
class Documents(Base):
3437
"""A set of documents with identifiers
@@ -185,7 +188,10 @@ def iter(self) -> Iterator[AdhocAssessedTopic]:
185188
class AdhocRun(Base):
    """IR adhoc run

    Declares the interface for accessing a run: the only member is the
    abstract method :meth:`get_dict`.
    """

    @abstractmethod
    def get_dict(self) -> "AdhocRunDict":
        """Get the run as a dictionary query ID -> doc ID -> score

        :return: the run as an ``AdhocRunDict``
        """
        ...
189195

190196

191197
class AdhocResults(Base):

src/datamaestro_text/data/ir/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class AdhocAssessedTopic:
7171
"""List of assessments for this topic"""
7272

7373

74-
def create_record(*items: Item, id: str = None, text: str = None):
74+
def create_record(*items: Item, id: str = None, text: str = None) -> Record:
7575
"""Easy creation of a text/id item"""
7676
extra_items = []
7777
if id is not None:

src/datamaestro_text/data/ir/trec.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
import re
12
from typing import Dict, List, Optional
2-
from datamaestro.data import Base
33
from experimaestro import documentation, Param, Meta
44
from pathlib import Path
5-
from datamaestro.record import Record
65
from datamaestro_text.data.ir import (
6+
AdhocRunDict,
77
Documents,
88
Topics,
99
AdhocAssessments,
@@ -47,6 +47,11 @@ def iter(self):
4747
class TrecAdhocRun(AdhocRun):
    """Adhoc run read from the file at ``path`` with the TREC run parser"""

    # File handed to ``datamaestro_text.interfaces.trec.parse_run``
    path: Param[Path]

    def get_dict(self) -> AdhocRunDict:
        """Parse the file at ``path`` and return it as a run dictionary

        :return: whatever ``parse_run`` returns for ``self.path``
        """
        # NOTE(review): imported at call time rather than at module level;
        # ``datamaestro_text.interfaces.trec`` itself imports from
        # ``datamaestro_text.data.ir``, so this presumably avoids an import
        # cycle -- confirm before moving it to the top of the file.
        import datamaestro_text.interfaces.trec as trec

        return trec.parse_run(self.path)
54+
5055

5156
class TrecAdhocResults(AdhocResults):
5257
"""Adhoc results (TREC format)"""
@@ -62,8 +67,6 @@ class TrecAdhocResults(AdhocResults):
6267

6368
def get_results(self) -> Dict[str, float]:
6469
"""Returns the results as a dictionary {metric_name: value}"""
65-
import re
66-
6770
re_spaces = re.compile(r"\s+")
6871

6972
results = {}

src/datamaestro_text/interfaces/trec.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from attrs import define
21
from pathlib import Path
32
from typing import Iterator, Optional
43
import re
4+
from datamaestro_text.data.ir import AdhocRunDict
55
from datamaestro_text.data.ir.base import (
66
AdhocAssessedTopic,
77
TopicRecord,
@@ -10,6 +10,33 @@
1010
)
1111
from datamaestro_text.data.ir.formats import TrecTopicRecord, TrecTopic
1212

13+
# --- Runs
14+
15+
16+
def parse_run(path: Path) -> AdhocRunDict:
    """Parse a TREC adhoc run file into a run dictionary.

    Every non-blank line must hold exactly six whitespace-separated fields::

        query_id Q0 doc_id rank score run_id

    Only the query ID, the document ID and the score are kept. If the same
    (query ID, document ID) pair appears several times, the last score wins.

    :param path: path of the run file to read
    :return: a dictionary query ID -> doc ID -> score, with scores as floats
    :raises ValueError: if a non-blank line does not have exactly six fields,
        or if its score field is not a valid number
    """
    results = {}
    with path.open("rt") as f:
        for line in f:
            # split() with no argument splits on runs of whitespace and
            # ignores leading/trailing whitespace
            fields = line.split()
            if not fields:
                # Blank line (e.g. trailing newline at the end of the file)
                continue

            query_id, _q0, doc_id, _rank, score, _model_id = fields
            # Scores must be numbers (see AdhocRunDict), not the raw token
            results.setdefault(query_id, {})[doc_id] = float(score)

    return results
26+
27+
28+
def write_run_dict(run: AdhocRunDict, run_path: Path, run_id: str = "run"):
    """Write a run dictionary to a file in the TREC adhoc run format.

    One line is written per (query, document) pair::

        query_id Q0 doc_id rank score run_id

    For each query, documents are ordered by decreasing score and ranks start
    at 1. Documents with equal scores keep the order they have in the
    dictionary (the sort is stable).

    :param run: a dictionary query ID -> doc ID -> score
    :param run_path: path of the file to write (overwritten if it exists)
    :param run_id: value written in the last column of every line
    """
    with run_path.open("wt") as f:
        for query_id, scored_documents in run.items():
            ranked = sorted(
                scored_documents.items(), key=lambda item: item[1], reverse=True
            )
            f.writelines(
                f"{query_id} Q0 {doc_id} {rank} {score} {run_id}\n"
                for rank, (doc_id, score) in enumerate(ranked, start=1)
            )
38+
39+
1340
# --- Assessments
1441

1542

@@ -59,7 +86,7 @@ def parse_query_format(file, xml_prefix=None) -> Iterator[TopicRecord]:
5986
num = line[len("<num>") :].replace("Number:", "").strip()
6087
reading = None
6188
elif line.startswith(f"<{xml_prefix}title>"):
62-
title = line[len(f"<{xml_prefix}title>") :].strip()
89+
title = line[len(f"<{xml_prefix}title>") : ].strip()
6390
if title == "":
6491
reading = "title"
6592
else:

0 commit comments

Comments
 (0)