
Commit f18806c

Merge branch 'FlagOpen:master' into master
2 parents: 0874413 + 3c40623

99 files changed, +15499 -907 lines changed

.github/workflows/documentation.yml

Lines changed: 8 additions & 2 deletions
@@ -11,12 +11,18 @@ jobs:
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
-      - name: Install dependencies
+      - name: Install doc dependencies
         run: |
-          pip install . sphinx sphinx_rtd_theme myst_parser myst-nb furo
+          pip install . sphinx myst_parser myst-nb sphinx-design pydata-sphinx-theme sphinxcontrib-googleanalytics
+      - name: Install content dependencies
+        run: |
+          pip install faiss-cpu mteb air-benchmark beir
       - name: Sphinx build
         run: |
           sphinx-build docs/source docs/build
+      - name: Add CNAME
+        run: |
+          echo bge-model.com > docs/build/CNAME
       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@v3
         if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}

FlagEmbedding/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from .inference import *
+from .evaluation import *

FlagEmbedding/abc/finetune/reranker/AbsDataset.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ class AbsRerankerCollator(DataCollatorWithPadding):
     query_max_len: int = 32
     passage_max_len: int = 128

-    def __call__(self, features) -> list[BatchEncoding]:
+    def __call__(self, features) -> List[BatchEncoding]:
         teacher_scores = [f[1] for f in features]
         if teacher_scores[0] is None:
             teacher_scores = None
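
The only change here is the return annotation, swapping the built-in generic `list[BatchEncoding]` for `typing.List[BatchEncoding]`. A plausible motivation (an assumption, not stated in the commit) is interpreter compatibility: subscripting the built-in `list` in evaluated annotations requires Python 3.9+, while the `typing.List` form works on older versions too. A minimal sketch of the difference:

```python
from typing import List

# typing.List is subscriptable on every supported Python version.
def collate_typing(features) -> List[dict]:
    return list(features)

# Built-in generics (PEP 585) only work in evaluated annotations on Python 3.9+;
# on 3.8 this `def` raises "TypeError: 'type' object is not subscriptable".
def collate_builtin(features) -> list[dict]:
    return list(features)
```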

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 2 additions & 0 deletions
@@ -462,6 +462,8 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch
             Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
         """
         if isinstance(results_list[0], torch.Tensor):
+            # move all tensors to the same device
+            results_list = [res.to(self.target_devices[0]) for res in results_list]
             return torch.cat(results_list, dim=0)
         elif isinstance(results_list[0], np.ndarray):
             return np.concatenate(results_list, axis=0)
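
The two inserted lines move each worker's result tensor onto the first target device before concatenation. This matters because `torch.cat` cannot join tensors that live on different GPUs, which is the usual situation when embeddings are produced by one process per device. A minimal sketch (hypothetical shapes and device names, not the library's code) of the failure mode and the fix:

```python
import torch

# torch.cat refuses to concatenate tensors on different devices, which is
# exactly what happens when each worker returns its shard on its own GPU.
if torch.cuda.device_count() >= 2:
    shard_0 = torch.randn(4, 8, device="cuda:0")
    shard_1 = torch.randn(4, 8, device="cuda:1")
    try:
        torch.cat([shard_0, shard_1], dim=0)  # raises RuntimeError
    except RuntimeError as err:
        print(f"cat across devices failed: {err}")

    # Moving every shard to one target device first makes the concat valid.
    merged = torch.cat([t.to("cuda:0") for t in [shard_0, shard_1]], dim=0)
    print(merged.shape)  # torch.Size([8, 8])
```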

FlagEmbedding/evaluation/beir/data_loader.py

Lines changed: 3 additions & 3 deletions
@@ -145,7 +145,7 @@ def _load_remote_qrels(
         if dataset_name != 'cqadupstack':
             qrels = datasets.load_dataset(
                 'BeIR/{d}-qrels'.format(d=dataset_name),
-                split=split,
+                split=split if split != 'dev' else 'validation',
                 trust_remote_code=True,
                 cache_dir=self.cache_dir,
                 download_mode=self.hf_download_mode
@@ -409,7 +409,7 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
         Returns:
             datasets.DatasetDict: A dict of relevance of query and document.
         """
-        checked_split = self.check_splits(split)
+        checked_split = self.check_splits(split, dataset_name=dataset_name)
         if len(checked_split) == 0:
             raise ValueError(f"Split {split} not found in the dataset.")
         split = checked_split[0]
@@ -450,7 +450,7 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
         Returns:
             datasets.DatasetDict: A dict of queries with id as key, query text as value.
         """
-        checked_split = self.check_splits(split)
+        checked_split = self.check_splits(split, dataset_name=dataset_name)
         if len(checked_split) == 0:
             raise ValueError(f"Split {split} not found in the dataset.")
         split = checked_split[0]
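
The first hunk remaps a requested `dev` split to `validation` when pulling qrels from the `BeIR/{dataset}-qrels` Hub repositories, which name their development split `validation` rather than `dev`; the other two hunks pass `dataset_name` through to `check_splits` so per-dataset split availability is also respected for local files. A standalone sketch of the remapping (assuming the Hugging Face `datasets` package; `load_beir_qrels` is a hypothetical helper, not part of FlagEmbedding):

```python
from datasets import load_dataset

def load_beir_qrels(dataset_name: str, split: str):
    # The BeIR qrels repos on the Hub expose a 'validation' split instead of
    # 'dev', so a caller asking for 'dev' is redirected before loading.
    hub_split = "validation" if split == "dev" else split
    return load_dataset(
        f"BeIR/{dataset_name}-qrels",
        split=hub_split,
        trust_remote_code=True,
    )

# Hypothetical usage:
# qrels = load_beir_qrels("scifact", "dev")  # actually loads 'validation'
```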

FlagEmbedding/evaluation/mteb/runner.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from .arguments import MTEBEvalArgs
 from .searcher import MTEBEvalDenseRetriever, MTEBEvalReranker
 from .prompts import get_task_def_by_task_name_and_type
-from .examples import examples_dict
+

 logger = logging.getLogger(__name__)

FlagEmbedding/finetune/embedder/decoder_only/base/arguments.py

Lines changed: 5 additions & 0 deletions
@@ -69,3 +69,8 @@ class DecoderOnlyEmbedderModelArguments(AbsEmbedderModelArguments):
         default=False,
         metadata={"help": "If passed, will merge the lora modules and save the entire model."}
     )
+
+    only_merge_lora_model: bool = field(
+        default=False,
+        metadata={"help": "If passed, will only merge the lora modules and save the entire model."}
+    )

FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py

Lines changed: 15 additions & 7 deletions
@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
@@ -129,13 +132,15 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     elif model_args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             model_args.model_name_or_path,
             token=model_args.token,
-            cache_dir=model_args.cache_dir
+            cache_dir=model_args.cache_dir,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         raise ValueError(
@@ -152,6 +157,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
             cache_dir=model_args.cache_dir,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         model = model_args.from_config(config)
@@ -171,7 +177,9 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
     model = PeftModel.from_pretrained(model, find_largest_checkpoint(output_dir))
     model = model.merge_and_unload()

-    model.save_pretrained(os.path.join(output_dir, 'merged_model'))
-
-    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
     tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))
+
+    # modify the vocab size in the model configuration
+    model.config.vocab_size = len(tokenizer)
+    model.save_pretrained(os.path.join(output_dir, 'merged_model'))
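
Beyond threading `trust_remote_code` through every `from_pretrained` call, the last hunk reorders `save_merged_model` so the tokenizer is written first and the merged model's `config.vocab_size` is synced to the tokenizer length before the model is saved. A condensed sketch of that merge-and-save sequence (assuming `peft` and `transformers`; paths and the function name are placeholders, not the library's API):

```python
import os
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def merge_and_save(base_model_path: str, adapter_dir: str, output_dir: str,
                   trust_remote_code: bool = False) -> None:
    # Load the base model, attach the trained LoRA adapter, and fold it in.
    model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=trust_remote_code)
    model = PeftModel.from_pretrained(model, adapter_dir)
    model = model.merge_and_unload()

    # Save the tokenizer first, then align config.vocab_size with it before
    # saving the model, mirroring the order introduced in this commit.
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir, trust_remote_code=trust_remote_code)
    merged_dir = os.path.join(output_dir, "merged_model")
    tokenizer.save_pretrained(merged_dir)
    model.config.vocab_size = len(tokenizer)
    model.save_pretrained(merged_dir)
```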

FlagEmbedding/finetune/embedder/decoder_only/base/runner.py

Lines changed: 10 additions & 5 deletions
@@ -29,6 +29,9 @@ def __init__(
         training_args: AbsEmbedderTrainingArguments
     ):
         super().__init__(model_args, data_args, training_args)
+        self.model_args: DecoderOnlyEmbedderModelArguments
+        self.data_args: AbsEmbedderDataArguments
+        self.training_args: AbsEmbedderTrainingArguments

     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
         """Load tokenizer and model.
@@ -41,7 +44,8 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
             token=self.model_args.token,
             cache_dir=self.model_args.cache_dir,
             use_fast=False,
-            add_eos_token=True
+            add_eos_token=True,
+            trust_remote_code=self.model_args.trust_remote_code,
         )

         if tokenizer.pad_token is None:
@@ -116,11 +120,12 @@ def run(self):
         """
         Run the finetune.
         """
-        Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
+        if not self.model_args.only_merge_lora_model:
+            Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)

-        # Training
-        self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
-        self.trainer.save_model()
+            # Training
+            self.trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
+            self.trainer.save_model()

         # save merged model
         if self.model_args.save_merged_lora_model and self.training_args.process_index == 0:
FlagEmbedding/finetune/embedder/decoder_only/icl/arguments.py

Lines changed: 5 additions & 0 deletions
@@ -73,6 +73,11 @@ class DecoderOnlyEmbedderICLModelArguments(AbsEmbedderModelArguments):
         metadata={"help": "If passed, will merge the lora modules and save the entire model."}
     )

+    only_merge_lora_model: bool = field(
+        default=False,
+        metadata={"help": "If passed, will only merge the lora modules and save the entire model."}
+    )
+

 @dataclass
 class DecoderOnlyEmbedderICLDataArguments(AbsEmbedderDataArguments):