Skip to content

Commit 75641dd

Browse files
authored
fix: Added checks for DataParallel and WrappedDataParallel (#3366)
* Added checks for DataParallel and WrappedDataParallel
* Updated isinstance checks according to pylint recommendation
* Used isinstance over type comparisons
* Added test for DPR training
1 parent db6e575 commit 75641dd

File tree

3 files changed

+48
-9
lines changed

3 files changed

+48
-9
lines changed

haystack/modeling/evaluation/eval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import logging
44
import numbers
55
import torch
6+
from torch.nn import DataParallel
67
import numpy as np
78
from tqdm import tqdm
89

910
from haystack.modeling.evaluation.metrics import compute_metrics, compute_report_metrics
1011
from haystack.modeling.model.adaptive_model import AdaptiveModel
1112
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
13+
from haystack.modeling.model.optimization import WrappedDataParallel
1214
from haystack.utils.experiment_tracking import Tracker as tracker
1315
from haystack.modeling.visual import BUSH_SEP
1416

@@ -70,17 +72,21 @@ def eval(
7072
for step, batch in enumerate(tqdm(self.data_loader, desc="Evaluating", mininterval=10)):
7173
batch = {key: batch[key].to(self.device) for key in batch}
7274

73-
with torch.no_grad():
75+
if isinstance(model, (DataParallel, WrappedDataParallel)):
76+
module = model.module
77+
else:
78+
module = model
7479

75-
if isinstance(model, AdaptiveModel):
80+
with torch.no_grad():
81+
if isinstance(module, AdaptiveModel):
7682
logits = model.forward(
7783
input_ids=batch.get("input_ids", None),
7884
segment_ids=batch.get("segment_ids", None),
7985
padding_mask=batch.get("padding_mask", None),
8086
output_hidden_states=batch.get("output_hidden_states", False),
8187
output_attentions=batch.get("output_attentions", False),
8288
)
83-
elif isinstance(model, BiAdaptiveModel):
89+
elif isinstance(module, BiAdaptiveModel):
8490
logits = model.forward(
8591
query_input_ids=batch.get("query_input_ids", None),
8692
query_segment_ids=batch.get("query_segment_ids", None),

haystack/modeling/training/base.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from haystack.modeling.evaluation.eval import Evaluator
1919
from haystack.modeling.model.adaptive_model import AdaptiveModel
2020
from haystack.modeling.model.biadaptive_model import BiAdaptiveModel
21-
from haystack.modeling.model.optimization import get_scheduler
21+
from haystack.modeling.model.optimization import get_scheduler, WrappedDataParallel
2222
from haystack.modeling.utils import GracefulKiller
2323
from haystack.utils.experiment_tracking import Tracker as tracker
2424
from haystack.utils.early_stopping import EarlyStopping
@@ -292,12 +292,17 @@ def train(self):
292292

293293
def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
294294
# Forward & backward pass through model
295-
if isinstance(self.model, AdaptiveModel):
295+
if isinstance(self.model, (DataParallel, WrappedDataParallel)):
296+
module = self.model.module
297+
else:
298+
module = self.model
299+
300+
if isinstance(module, AdaptiveModel):
296301
logits = self.model.forward(
297302
input_ids=batch["input_ids"], segment_ids=None, padding_mask=batch["padding_mask"]
298303
)
299304

300-
elif isinstance(self.model, BiAdaptiveModel):
305+
elif isinstance(module, BiAdaptiveModel):
301306
logits = self.model.forward(
302307
query_input_ids=batch["query_input_ids"],
303308
query_segment_ids=batch["query_segment_ids"],

test/modeling/test_dpr.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import os
12
from typing import Tuple
23

3-
import os
44
import logging
55
from pathlib import Path
66

@@ -17,8 +17,11 @@
1717
from haystack.modeling.model.language_model import get_language_model, DPREncoder
1818
from haystack.modeling.model.prediction_head import TextSimilarityHead
1919
from haystack.modeling.model.tokenization import get_tokenizer
20+
from haystack.nodes.retriever.dense import DensePassageRetriever
2021
from haystack.modeling.utils import set_all_seeds, initialize_device_settings
2122

23+
from ..conftest import SAMPLES_PATH
24+
2225

2326
def test_dpr_modules(caplog=None):
2427
if caplog:
@@ -970,6 +973,33 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
970973
assert np.array_equal(all_embeddings["query"][0], all_embeddings3["query"][0])
971974

972975

976+
@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
977+
def test_dpr_training(document_store, tmp_path):
978+
retriever = DensePassageRetriever(
979+
document_store=document_store,
980+
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
981+
passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
982+
max_seq_len_query=8,
983+
max_seq_len_passage=8,
984+
)
985+
986+
save_dir = f"{tmp_path}/test_dpr_training"
987+
retriever.train(
988+
data_dir=str(SAMPLES_PATH / "dpr"),
989+
train_filename="sample.json",
990+
dev_filename="sample.json",
991+
test_filename="sample.json",
992+
n_epochs=1,
993+
batch_size=1,
994+
grad_acc_steps=1,
995+
save_dir=save_dir,
996+
evaluate_every=10,
997+
embed_title=True,
998+
num_positives=1,
999+
num_hard_negatives=1,
1000+
)
1001+
1002+
9731003
# TODO fix CI errors (test pass locally or on AWS, next steps: isolate PyTorch versions once FARM dependency is removed)
9741004
# def test_dpr_training():
9751005
# batch_size = 1
@@ -982,8 +1012,6 @@ def test_dpr_processor_save_load_non_bert_tokenizer(tmp_path: Path, query_and_pa
9821012
# use_fast = True
9831013
# similarity_function = "dot_product"
9841014
#
985-
#
986-
#
9871015
# device, n_gpu = initialize_device_settings(use_cuda=False)
9881016
#
9891017
# query_tokenizer = get_tokenizer(pretrained_model_name_or_path=question_lang_model,

0 commit comments

Comments (0)