@@ -13,7 +13,7 @@ def data_handler():
1313
1414
1515def test_nli_transformer_predict_without_trained_head (data_handler ):
16- model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L-6 -v2" , "train_head" : True })
16+ model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L6 -v2" , "train_head" : True })
1717 with pytest .raises (ValueError , match = "Classifier is not trained yet" ):
1818 model .predict (data_handler .train_utterances (0 ))
1919
@@ -48,7 +48,7 @@ def check_ranking(ranked, labels):
4848
4949
5050def test_nli_transformer_predict_with_train_head (data_handler ):
51- model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L-6 -v2" , "train_head" : True })
51+ model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L6 -v2" , "train_head" : True })
5252 texts = data_handler .train_utterances (0 )
5353 labels = data_handler .train_labels (0 )
5454 model .fit (texts , labels )
@@ -60,7 +60,7 @@ def test_nli_transformer_predict_with_train_head(data_handler):
6060
6161
6262def test_nli_transformer_predict_default (data_handler ):
63- model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L-6 -v2" , "train_head" : False })
63+ model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L6 -v2" , "train_head" : False })
6464 texts = data_handler .train_utterances (0 )
6565 labels = data_handler .train_labels (0 )
6666 predicted = model .predict (build_pairs (texts ))
@@ -71,7 +71,7 @@ def test_nli_transformer_predict_default(data_handler):
7171
7272
7373def test_nli_transformer_predict_default_with_fit (data_handler ):
74- model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L-6 -v2" , "train_head" : False })
74+ model = Ranker (cross_encoder_config = {"model_name" : "cross-encoder/ms-marco-MiniLM-L6 -v2" , "train_head" : False })
7575 texts = data_handler .train_utterances (0 )
7676 labels = data_handler .train_labels (0 )
7777 model .fit (texts , labels )
0 commit comments