@@ -390,28 +390,6 @@ def setUp(self):
390
390
"expected_prediction_with_base_margin" ,
391
391
],
392
392
)
393
- self .ranker_df_train = self .session .createDataFrame (
394
- [
395
- (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 0 ),
396
- (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 0 ),
397
- (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 0 ),
398
- (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 1 ),
399
- (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 1 ),
400
- (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 1 ),
401
- ],
402
- ["features" , "label" , "qid" ],
403
- )
404
- self .ranker_df_test = self .session .createDataFrame (
405
- [
406
- (Vectors .dense (1.5 , 2.0 , 3.0 ), 0 , - 1.87988 ),
407
- (Vectors .dense (4.5 , 5.0 , 6.0 ), 0 , 0.29556 ),
408
- (Vectors .dense (9.0 , 4.5 , 8.0 ), 0 , 2.36570 ),
409
- (Vectors .sparse (3 , {1 : 1.0 , 2 : 6.0 }), 1 , - 1.87988 ),
410
- (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.0 }), 1 , - 0.30612 ),
411
- (Vectors .sparse (3 , {1 : 8.0 , 2 : 10.5 }), 1 , 2.44826 ),
412
- ],
413
- ["features" , "qid" , "expected_prediction" ],
414
- )
415
393
416
394
self .reg_df_sparse_train = self .session .createDataFrame (
417
395
[
@@ -1039,15 +1017,6 @@ def test_classifier_with_sparse_optim(self):
1039
1017
for row1 , row2 in zip (pred_result , pred_result2 ):
1040
1018
self .assertTrue (np .allclose (row1 .probability , row2 .probability , rtol = 1e-3 ))
1041
1019
1042
- def test_ranker (self ):
1043
- ranker = SparkXGBRanker (qid_col = "qid" )
1044
- assert ranker .getOrDefault (ranker .objective ) == "rank:pairwise"
1045
- model = ranker .fit (self .ranker_df_train )
1046
- pred_result = model .transform (self .ranker_df_test ).collect ()
1047
-
1048
- for row in pred_result :
1049
- assert np .isclose (row .prediction , row .expected_prediction , rtol = 1e-3 )
1050
-
1051
1020
def test_empty_validation_data (self ) -> None :
1052
1021
for tree_method in [
1053
1022
"hist" ,
@@ -1130,3 +1099,63 @@ def test_early_stop_param_validation(self):
1130
1099
def test_unsupported_params (self ):
1131
1100
with pytest .raises (ValueError , match = "evals_result" ):
1132
1101
SparkXGBClassifier (evals_result = {})
1102
+
1103
+
1104
+ class XgboostRankerLocalTest (SparkTestCase ):
1105
+ def setUp (self ):
1106
+ self .session .conf .set ("spark.sql.execution.arrow.maxRecordsPerBatch" , "8" )
1107
+ self .ranker_df_train = self .session .createDataFrame (
1108
+ [
1109
+ (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 0 ),
1110
+ (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 0 ),
1111
+ (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 0 ),
1112
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 1 ),
1113
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 1 ),
1114
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 1 ),
1115
+ ],
1116
+ ["features" , "label" , "qid" ],
1117
+ )
1118
+ self .ranker_df_test = self .session .createDataFrame (
1119
+ [
1120
+ (Vectors .dense (1.5 , 2.0 , 3.0 ), 0 , - 1.87988 ),
1121
+ (Vectors .dense (4.5 , 5.0 , 6.0 ), 0 , 0.29556 ),
1122
+ (Vectors .dense (9.0 , 4.5 , 8.0 ), 0 , 2.36570 ),
1123
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 6.0 }), 1 , - 1.87988 ),
1124
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.0 }), 1 , - 0.30612 ),
1125
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 10.5 }), 1 , 2.44826 ),
1126
+ ],
1127
+ ["features" , "qid" , "expected_prediction" ],
1128
+ )
1129
+ self .ranker_df_train_1 = self .session .createDataFrame (
1130
+ [
1131
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 9 ),
1132
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 9 ),
1133
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 9 ),
1134
+ (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 8 ),
1135
+ (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 8 ),
1136
+ (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 8 ),
1137
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 7 ),
1138
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 7 ),
1139
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 7 ),
1140
+ (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 6 ),
1141
+ (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 6 ),
1142
+ (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 6 ),
1143
+ ]
1144
+ * 4 ,
1145
+ ["features" , "label" , "qid" ],
1146
+ )
1147
+
1148
+ def test_ranker (self ):
1149
+ ranker = SparkXGBRanker (qid_col = "qid" )
1150
+ assert ranker .getOrDefault (ranker .objective ) == "rank:pairwise"
1151
+ model = ranker .fit (self .ranker_df_train )
1152
+ pred_result = model .transform (self .ranker_df_test ).collect ()
1153
+
1154
+ for row in pred_result :
1155
+ assert np .isclose (row .prediction , row .expected_prediction , rtol = 1e-3 )
1156
+
1157
+ def test_ranker_qid_sorted (self ):
1158
+ ranker = SparkXGBRanker (qid_col = "qid" , num_workers = 4 )
1159
+ assert ranker .getOrDefault (ranker .objective ) == "rank:pairwise"
1160
+ model = ranker .fit (self .ranker_df_train_1 )
1161
+ model .transform (self .ranker_df_test ).collect ()
0 commit comments