Skip to content

Commit cef1595

Browse files
authored
Merge pull request #1328 from hanhainebula/master
fix bugs for embedder finetune
2 parents 808b6c8 + ddf9ada commit cef1595

File tree

16 files changed

+58
-23
lines changed

16 files changed

+58
-23
lines changed

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,8 @@ def _concatenate_results_from_multi_process(self, results_list: List[Union[torch
416416
Union[torch.Tensor, np.ndarray]: return the embedding vectors in a numpy array or tensor.
417417
"""
418418
if isinstance(results_list[0], torch.Tensor):
419+
# move all tensors to the same device
420+
results_list = [res.to(self.target_devices[0]) for res in results_list]
419421
return torch.cat(results_list, dim=0)
420422
elif isinstance(results_list[0], np.ndarray):
421423
return np.concatenate(results_list, axis=0)

FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
5151
config = AutoConfig.from_pretrained(
5252
model_args.config_name,
5353
token=model_args.token,
54-
cache_dir=model_args.cache_dir
54+
cache_dir=model_args.cache_dir,
55+
trust_remote_code=model_args.trust_remote_code,
5556
)
5657
elif model_args.model_name_or_path:
5758
config = AutoConfig.from_pretrained(
5859
model_args.model_name_or_path,
5960
token=model_args.token,
60-
cache_dir=model_args.cache_dir
61+
cache_dir=model_args.cache_dir,
62+
trust_remote_code=model_args.trust_remote_code,
6163
)
6264
else:
6365
raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
7476
cache_dir=model_args.cache_dir,
7577
from_tf=bool(".ckpt" in model_args.model_name_or_path),
7678
config=config,
79+
trust_remote_code=model_args.trust_remote_code,
7780
)
7881
else:
7982
logger.info("Training new model from scratch")
@@ -129,13 +132,15 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
129132
config = AutoConfig.from_pretrained(
130133
model_args.config_name,
131134
token=model_args.token,
132-
cache_dir=model_args.cache_dir
135+
cache_dir=model_args.cache_dir,
136+
trust_remote_code=model_args.trust_remote_code,
133137
)
134138
elif model_args.model_name_or_path:
135139
config = AutoConfig.from_pretrained(
136140
model_args.model_name_or_path,
137141
token=model_args.token,
138-
cache_dir=model_args.cache_dir
142+
cache_dir=model_args.cache_dir,
143+
trust_remote_code=model_args.trust_remote_code,
139144
)
140145
else:
141146
raise ValueError(
@@ -152,6 +157,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
152157
cache_dir=model_args.cache_dir,
153158
from_tf=bool(".ckpt" in model_args.model_name_or_path),
154159
config=config,
160+
trust_remote_code=model_args.trust_remote_code,
155161
)
156162
else:
157163
model = model_args.from_config(config)
@@ -173,5 +179,5 @@ def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir:
173179

174180
model.save_pretrained(os.path.join(output_dir, 'merged_model'))
175181

176-
tokenizer = AutoTokenizer.from_pretrained(output_dir)
182+
tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
177183
tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))

FlagEmbedding/finetune/embedder/decoder_only/base/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
4141
token=self.model_args.token,
4242
cache_dir=self.model_args.cache_dir,
4343
use_fast=False,
44-
add_eos_token=True
44+
add_eos_token=True,
45+
trust_remote_code=self.model_args.trust_remote_code,
4546
)
4647

4748
if tokenizer.pad_token is None:

FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
5151
config = AutoConfig.from_pretrained(
5252
model_args.config_name,
5353
token=model_args.token,
54-
cache_dir=model_args.cache_dir
54+
cache_dir=model_args.cache_dir,
55+
trust_remote_code=model_args.trust_remote_code,
5556
)
5657
elif model_args.model_name_or_path:
5758
config = AutoConfig.from_pretrained(
5859
model_args.model_name_or_path,
5960
token=model_args.token,
60-
cache_dir=model_args.cache_dir
61+
cache_dir=model_args.cache_dir,
62+
trust_remote_code=model_args.trust_remote_code,
6163
)
6264
else:
6365
raise ValueError(
@@ -74,6 +76,7 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
7476
cache_dir=model_args.cache_dir,
7577
from_tf=bool(".ckpt" in model_args.model_name_or_path),
7678
config=config,
79+
trust_remote_code=model_args.trust_remote_code,
7780
)
7881
else:
7982
logger.info("Training new model from scratch")
@@ -152,6 +155,7 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
152155
cache_dir=model_args.cache_dir,
153156
from_tf=bool(".ckpt" in model_args.model_name_or_path),
154157
config=config,
158+
trust_remote_code=model_args.trust_remote_code,
155159
)
156160
else:
157161
model = model_args.from_config(config)
@@ -173,5 +177,5 @@ def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_d
173177

174178
model.save_pretrained(os.path.join(output_dir, 'merged_model'))
175179

176-
tokenizer = AutoTokenizer.from_pretrained(output_dir)
180+
tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=model_args.trust_remote_code)
177181
tokenizer.save_pretrained(os.path.join(output_dir, 'merged_model'))

FlagEmbedding/finetune/embedder/decoder_only/icl/runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
4545
token=self.model_args.token,
4646
cache_dir=self.model_args.cache_dir,
4747
use_fast=False,
48-
add_eos_token=True
48+
add_eos_token=True,
49+
trust_remote_code=self.model_args.trust_remote_code,
4950
)
5051

5152
if tokenizer.pad_token is None:

FlagEmbedding/finetune/embedder/encoder_only/m3/runner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
142142
if "position_embeddings" in k:
143143
logging.info(f"Freeze the parameters for {k}")
144144
v.requires_grad = False
145+
146+
if self.training_args.fix_encoder:
147+
for k, v in model.named_parameters():
148+
if "colbert_linear" in k or 'sparse_linear' in k:
149+
logging.info(f"train the parameters for {k}")
150+
else:
151+
v.requires_grad = False
152+
145153
return tokenizer, model
146154

147155
def load_trainer(self) -> EncoderOnlyEmbedderM3Trainer:

FlagEmbedding/inference/embedder/decoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def encode_single_device(
257257
flag = True
258258
except RuntimeError as e:
259259
batch_size = batch_size * 3 // 4
260-
except torch.OutofMemoryError as e:
260+
except torch.OutOfMemoryError as e:
261261
batch_size = batch_size * 3 // 4
262262

263263
# encode

FlagEmbedding/inference/embedder/decoder_only/icl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def encode_queries_single_device(
409409
flag = True
410410
except RuntimeError as e:
411411
batch_size = batch_size * 3 // 4
412-
except torch.OutofMemoryError as e:
412+
except torch.OutOfMemoryError as e:
413413
batch_size = batch_size * 3 // 4
414414

415415
# encode
@@ -519,7 +519,7 @@ def encode_single_device(
519519
flag = True
520520
except RuntimeError as e:
521521
batch_size = batch_size * 3 // 4
522-
except torch.OutofMemoryError as e:
522+
except torch.OutOfMemoryError as e:
523523
batch_size = batch_size * 3 // 4
524524

525525
# encode

FlagEmbedding/inference/embedder/encoder_only/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def encode_single_device(
238238
flag = True
239239
except RuntimeError as e:
240240
batch_size = batch_size * 3 // 4
241-
except torch.OutofMemoryError as e:
241+
except torch.OutOfMemoryError as e:
242242
batch_size = batch_size * 3 // 4
243243

244244
# encode

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def _process_colbert_vecs(colbert_vecs: np.ndarray, attention_mask: list):
406406
flag = True
407407
except RuntimeError as e:
408408
batch_size = batch_size * 3 // 4
409-
except torch.OutofMemoryError as e:
409+
except torch.OutOfMemoryError as e:
410410
batch_size = batch_size * 3 // 4
411411

412412
# encode

0 commit comments

Comments
 (0)