Skip to content

Commit a684230

Browse files
committed
add source
1 parent 30e6a32 commit a684230

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+9016
-0
lines changed

entity_linkings/__init__.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import os
2+
from typing import Optional, Union
3+
4+
import datasets
5+
from datasets import Dataset, DatasetDict
6+
7+
from .dataset import DATASET_ID2CLS
8+
from .entity_dictionary import DICTIONARY_ID2CLS, EntityDictionaryBase
9+
from .models import (
10+
ED_ID2CLS,
11+
EL_ID2CLS,
12+
RETRIEVER_ID2CLS,
13+
EntityRetrieverBase,
14+
PipelineBase,
15+
)
16+
17+
18+
def get_dataset_ids() -> list[str]:
    """Return the registered dataset ids (class names in lower case)."""
    return [*DATASET_ID2CLS]
23+
24+
25+
def get_dictionary_ids() -> list[str]:
    """Return the registered entity-dictionary ids (class names in lower case)."""
    return list(DICTIONARY_ID2CLS)
30+
31+
32+
def get_retriever_ids() -> list[str]:
    """Return the registered retriever ids (class names in lower case)."""
    return [*RETRIEVER_ID2CLS]
37+
38+
39+
def get_el_ids() -> list[str]:
    """Return the registered entity-linking model ids (class names in lower case)."""
    return list(EL_ID2CLS)
44+
45+
46+
def get_ed_ids() -> list[str]:
    """Return the registered entity-disambiguation model ids (class names in lower case)."""
    return list(ED_ID2CLS)
51+
52+
53+
def get_model_ids() -> list[str]:
    """Return every registered model id: retrievers, then EL models, then ED models."""
    return [*RETRIEVER_ID2CLS, *EL_ID2CLS, *ED_ID2CLS]
58+
59+
60+
def load_dataset(
    name: str = "json",
    data_files: Optional[Union[str, dict[str, str]]] = None,
    split: Optional[str] = None,
    cache_dir: Optional[str] = None
) -> Union[DatasetDict, Dataset]:
    '''Load a registered benchmark dataset by id, or a custom JSON dataset.

    Args:
        name: Registered dataset id, optionally suffixed ``-<subset>`` to
            select a builder configuration (e.g. ``"foo-bar"`` loads subset
            ``bar`` of dataset ``foo``), or ``"json"`` for a custom dataset.
        data_files: JSON file path(s); required when ``name == "json"``.
        split: If given, return only that split instead of the full
            ``DatasetDict``.
        cache_dir: Cache directory forwarded to ``datasets``.

    Raises:
        ValueError: If ``name == "json"`` but ``data_files`` is missing, or
            if the dataset id is not registered.
    '''
    if name == "json":
        if not data_files:
            raise ValueError("Either name or data_files must be provided.")
        dataset = datasets.load_dataset("json", data_files=data_files, cache_dir=cache_dir)
    else:
        # "<id>-<subset>" selects a configuration of the dataset builder.
        # NOTE(review): only the first '-'-separated token after the id is
        # used; a name with two dashes silently drops the rest — confirm ids
        # never contain '-' themselves.
        subset = str(name.split('-')[1]) if '-' in name else None
        name = name.split('-')[0]
        if name not in get_dataset_ids():
            raise ValueError(f"The id should be one of {get_dataset_ids()}.")
        dataset_cls = DATASET_ID2CLS[name]
        # Fix: construct the builder once and reuse it. The original built a
        # second, independent builder for as_dataset(), repeating setup work.
        if subset:
            builder = dataset_cls(config_name=subset, cache_dir=cache_dir)
        else:
            builder = dataset_cls(cache_dir=cache_dir)
        builder.download_and_prepare()
        dataset = builder.as_dataset()

    if split is not None:
        return dataset[split]
    return dataset
90+
91+
92+
def load_dictionary(
    dictionary_name_or_path: str,
    nil_id: str = "-1",
    nil_name: str = "[NIL]",
    nil_description: str = "[NIL] is an entity that does not exist in this dictionary.",
    default_description: str = """{name} is an entity in this dictionary.""",
    cache_dir: Optional[str|os.PathLike] = None,
) -> EntityDictionaryBase:
    '''Load an entity dictionary from a registered id or a local JSON file.

    Args:
        dictionary_name_or_path: Either an existing file path (loaded as a
            JSON dataset, ``train`` split) or a registered dictionary id.
        nil_id: Id assigned to the NIL (out-of-dictionary) entity.
        nil_name: Display name of the NIL entity.
        nil_description: Description text of the NIL entity.
        default_description: Template used for entities without a description.
        cache_dir: Cache directory forwarded to ``datasets``.

    Raises:
        ValueError: If the id is not registered and no such file exists.
    '''
    if os.path.isfile(dictionary_name_or_path):
        # A concrete file path: read it directly as a JSON dictionary dump.
        raw_dictionary = datasets.load_dataset(
            "json", data_files=dictionary_name_or_path, cache_dir=cache_dir, split="train"
        )
    else:
        if dictionary_name_or_path not in get_dictionary_ids():
            raise ValueError(f"The id should be one of {get_dictionary_ids()}.")
        builder_cls = DICTIONARY_ID2CLS[dictionary_name_or_path]
        builder_cls(cache_dir=cache_dir).download_and_prepare()
        raw_dictionary = builder_cls(cache_dir=cache_dir).as_dataset()['dictionary']

    config = EntityDictionaryBase.Config(
        nil_id=nil_id,
        nil_name=nil_name,
        nil_description=nil_description,
        default_description=default_description,
        cache_dir=cache_dir,
    )
    return EntityDictionaryBase(dictionary=raw_dictionary, config=config)
118+
119+
120+
def get_ed_models(name: str) -> type[PipelineBase]:
    '''Look up an entity-disambiguation model class by its registry id.

    Raises:
        ValueError: If ``name`` is not a registered ED model id.
    '''
    if name not in get_ed_ids():
        # Fix: the original message listed get_el_ids() — the EL registry —
        # instead of the ED ids actually accepted here (copy-paste bug).
        raise ValueError(f"The id should be one of {get_ed_ids()}.")
    return ED_ID2CLS[name]
126+
127+
128+
def get_el_models(name: str) -> type[PipelineBase]:
    '''Look up an entity-linking model class by its registry id.

    Raises:
        ValueError: If ``name`` is not a registered EL model id.
    '''
    valid_ids = get_el_ids()
    if name not in valid_ids:
        raise ValueError(f"The id should be one of {valid_ids}.")
    return EL_ID2CLS[name]
134+
135+
136+
def get_retrievers(name: str) -> type[EntityRetrieverBase]:
    '''Look up a retriever class by its registry id.

    Raises:
        ValueError: If ``name`` is not a registered retriever id.
    '''
    if name in RETRIEVER_ID2CLS:
        return RETRIEVER_ID2CLS[name]
    raise ValueError(f"The id should be one of {get_retriever_ids()}.")
144+
145+
146+
def get_models(name: str) -> type[PipelineBase]:
    '''Look up a pipeline model class (EL or ED) by its registry id.

    EL ids take precedence if an id were ever registered in both tables.

    Raises:
        ValueError: If ``name`` is neither a registered EL nor ED model id.
    '''
    if name in get_el_ids():
        return EL_ID2CLS[name]
    if name in get_ed_ids():
        return ED_ID2CLS[name]
    # Fix: the original message also advertised get_retriever_ids(), but a
    # retriever id passed here would still raise — list only accepted ids.
    raise ValueError(f"The id should be one of {get_el_ids() + get_ed_ids()}.")

entity_linkings/cli/evaluate.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import json
2+
import logging
3+
import os
4+
from argparse import ArgumentParser, Namespace
5+
6+
import torch
7+
8+
from entity_linkings import get_models, get_retrievers, load_dataset, load_dictionary
9+
from entity_linkings.utils import read_yaml
10+
11+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): `device` is never referenced in this module — presumably the
# model classes select their own device; confirm before removing.
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
14+
15+
def evaluate(args: Namespace) -> None:
    """Evaluate a full (retriever + EL/ED model) pipeline on a test set.

    Loads the entity dictionary and the test split (registered dataset id or
    custom JSON file), optionally filters NIL entities, builds the retriever
    and model from YAML configs plus CLI overrides, runs evaluation, and
    reports metrics to the log, an optional JSON file, and optionally W&B.
    """
    dictionary = load_dictionary(args.dictionary_id_or_path, cache_dir=args.cache_dir)
    dataset_id = args.dataset_id if args.dataset_id else "json"
    if dataset_id != "json":
        test_dataset = load_dataset(dataset_id, split='test', cache_dir=args.cache_dir)
    else:
        test_dataset = load_dataset("json", data_files={"test": args.test_file}, cache_dir=args.cache_dir)['test']
    if args.remove_nil:
        from entity_linkings.data_utils import filter_nil_entities
        test_dataset = filter_nil_entities(test_dataset, dictionary)

    if args.retriever_config is not None:
        # Fix: use .get(..., {}) like the model config below — the original
        # indexed with [...] and raised KeyError when the YAML file had no
        # entry for this retriever id.
        retriever_config = read_yaml(args.retriever_config).get(args.retriever_id.lower(), {})
    else:
        retriever_config = {}
    # Fix: --retriever_model_name_or_path was parsed by cli_main but never
    # applied; honor it the same way --model_name_or_path is handled below.
    if args.retriever_model_name_or_path is not None:
        retriever_config["model_name_or_path"] = args.retriever_model_name_or_path

    if args.model_config is not None:
        model_config = read_yaml(args.model_config).get(args.model_id, {})
    else:
        model_config = {}
    if args.model_name_or_path is not None:
        model_config["model_name_or_path"] = args.model_name_or_path

    if args.wandb:
        import wandb
        wandb.init(
            project=os.environ.get("WANDB_PROJECT", "entity_linking_benchmark"),
            name=args.model_id, tags=["evaluation"]
        )
        # Record the run configuration alongside the metrics.
        wandb.log({
            "model_type": args.model_type,
            "retriever_id": args.retriever_id,
            "model_name_or_path": model_config.get("model_name_or_path", None),
            "retriever_model_name_or_path": retriever_config.get("model_name_or_path", None),
            "remove_nil": args.remove_nil,
            "dataset_id": dataset_id,
            "test_file": args.test_file,
            "dictionary_id_or_path": args.dictionary_id_or_path,
        })

    retriever_cls = get_retrievers(args.retriever_id)
    retriever = retriever_cls(dictionary, config=retriever_cls.Config(**retriever_config))
    model_cls = get_models(args.model_id)
    model = model_cls(retriever, config=model_cls.Config(**model_config))

    metrics = model.evaluate(test_dataset, num_candidates=args.num_candidates, batch_size=args.test_batch_size)
    logger.info(f"Evaluation results: {metrics}")
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        output_path = f"{args.output_dir}/eval_results.json"
        with open(output_path, "w") as f:
            json.dump(metrics, f, indent=4)
        logger.info(f"Saved evaluation results to {output_path}")
    if args.wandb:
        for key, value in metrics.items():
            wandb.log({key: value})
72+
73+
def cli_main() -> None:
    """Command-line entry point: declare evaluation options and run them."""
    parser = ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--model_type', type=str, default='ed', help='Task to perform. "ed" (entity disambiguation) and "el" (entity linking) are supported.')
    add_arg('--model_id', type=str, required=True, help='Name of the model to use.')
    add_arg('--model_name_or_path', type=str, default=None, help='Name of the model to use.')
    add_arg('--retriever_id', type=str, required=True, help='Name of the retriever model to use.')
    add_arg('--retriever_model_name_or_path', type=str, default=None, help='Name of the retriever model to use.')
    add_arg('--dictionary_id_or_path', '-d', type=str, default=None, help='Path to the entity dictionary file.')
    add_arg('--dataset_id', '-D', type=str, default=None, help='Name of the dataset to use.')
    add_arg('--test_file', type=str, default=None, help='Path to the dataset file.')
    add_arg('--num_candidates', type=int, default=5, help='Number of candidate entities to consider during evaluation.')
    add_arg('--test_batch_size', type=int, default=32, help='Batch size for evaluation.')
    add_arg('--remove_nil', action='store_true', default=False, help='Whether to remove nil entities from the dataset.')
    add_arg('--output_dir', type=str, default=None, help='Path to the output directory.')
    add_arg("--cache_dir", type=str, default=None, help='Path to the cache directory.')
    add_arg('--model_config', type=str, default=None, help='YAML-based config file.')
    add_arg('--retriever_config', type=str, default=None, help='YAML-based retriever config file.')
    add_arg('--wandb', action='store_true', default=False, help='Whether to use wandb for logging.')
    evaluate(parser.parse_args())


if __name__ == "__main__":
    cli_main()
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import json
2+
import logging
3+
import os
4+
from argparse import ArgumentParser, Namespace
5+
6+
import torch
7+
8+
from entity_linkings import get_retrievers, load_dataset, load_dictionary
9+
from entity_linkings.utils import read_yaml
10+
11+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE(review): `device` is never referenced in this module — presumably the
# retriever classes select their own device; confirm before removing.
device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'
14+
15+
def evaluate(args: Namespace) -> None:
    """Evaluate a retriever on the test split of a dataset.

    Loads the entity dictionary and test data, optionally filters NIL
    entities, builds the retriever from its YAML config plus CLI overrides,
    runs evaluation, and reports metrics to the log, an optional JSON file,
    and optionally Weights & Biases.
    """
    dictionary = load_dictionary(args.dictionary_id_or_path, cache_dir=args.cache_dir)

    dataset_id = args.dataset_id if args.dataset_id else "json"
    if dataset_id == "json":
        eval_data = load_dataset("json", data_files={"test": args.test_file}, cache_dir=args.cache_dir)['test']
    else:
        eval_data = load_dataset(dataset_id, split='test', cache_dir=args.cache_dir)
    if args.remove_nil:
        from entity_linkings.data_utils import filter_nil_entities
        eval_data = filter_nil_entities(eval_data, dictionary)

    if args.wandb:
        import wandb
        wandb.init(
            project=os.environ.get("WANDB_PROJECT", "entity_linkings"),
            name=args.retriever_id, tags=["evaluation"]
        )
        # Record the run configuration alongside the metrics.
        wandb.log({
            "retriever_id": args.retriever_id,
            "dataset_id": dataset_id,
            "dictionary_id_or_path": args.dictionary_id_or_path,
            "model_name_or_path": args.retriever_model_name_or_path,
            "remove_nil": args.remove_nil
        })

    cfg: dict = {}
    if args.retriever_config is not None:
        cfg = read_yaml(args.retriever_config).get(args.retriever_id, {})
    if args.retriever_model_name_or_path is not None:
        cfg["model_name_or_path"] = args.retriever_model_name_or_path

    retriever_cls = get_retrievers(args.retriever_id)
    retriever = retriever_cls(dictionary=dictionary, config=retriever_cls.Config(**cfg))
    metrics = retriever.evaluate(eval_data, batch_size=args.test_batch_size)
    logger.info(f"Evaluation results: {metrics}")

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        output_path = f"{args.output_dir}/eval_results.json"
        with open(output_path, "w") as fp:
            json.dump(metrics, fp, indent=4)
        logger.info(f"Saved evaluation results to {output_path}")
    if args.wandb:
        for metric_name, metric_value in metrics.items():
            wandb.log({metric_name: metric_value})
60+
61+
62+
def cli_main() -> None:
    """Command-line entry point for retriever evaluation."""
    parser = ArgumentParser()
    add_arg = parser.add_argument
    add_arg('--retriever_id', type=str, required=True, help='Name of the retriever model to use.')
    add_arg('--retriever_model_name_or_path', type=str, default=None, help='Name of the model to use.')
    add_arg('--dictionary_id_or_path', type=str, default=None, help='Path to the entity dictionary file.')
    add_arg('--dataset_id', type=str, default=None, help='Name of the dataset to use.')
    add_arg('--test_file', type=str, default=None, help='Path to the dataset file.')
    add_arg('--test_batch_size', type=int, default=32, help='Batch size for evaluation.')
    add_arg('--remove_nil', action='store_true', default=False, help='Whether to remove nil entities from the dataset.')
    add_arg('--output_dir', type=str, default=None, help='Path to the output directory.')
    add_arg("--cache_dir", type=str, default=None, help='Path to the cache directory.')
    add_arg('--retriever_config', type=str, default=None, help='YAML-based config file.')
    add_arg('--wandb', action='store_true', default=False, help='Whether to use wandb for logging.')
    evaluate(parser.parse_args())


if __name__ == "__main__":
    cli_main()

0 commit comments

Comments
 (0)