@@ -170,3 +170,64 @@ def test_local_classifier_from_to_parquet(setup):
170170 expected = classifier .predict (X )
171171 expected = np .stack ([1 - expected , expected ]).argmax (axis = 0 )
172172 np .testing .assert_array_equal (ret , expected )
173+
174+
175+ @pytest .mark .skipif (lightgbm is None , reason = "LightGBM not installed" )
176+ def test_classifier_on_multiple_machines (setup ):
177+ from .._train import LGBMTrain
178+
179+ class MockLGMBTrain (LGBMTrain ):
180+ @classmethod
181+ def execute (cls , ctx , op : "LGBMTrain" ):
182+ super ().execute (ctx , op )
183+ # Note: There may be a list result when running on multiple
184+ # machines, here just make an array of length 1 to simulate
185+ # this scenario.
186+ ctx [op .outputs [0 ].key ] = [ctx [op .outputs [0 ].key ]]
187+
188+ from ..core import LGBMModelType
189+ from .._train import train
190+ from ....utils import check_consistent_length
191+
192+ class MockLGBMClassifier (LGBMClassifier , lightgbm .LGBMClassifier ):
193+ def fit (
194+ self ,
195+ X ,
196+ y ,
197+ sample_weight = None ,
198+ init_score = None ,
199+ eval_set = None ,
200+ eval_sample_weight = None ,
201+ eval_init_score = None ,
202+ session = None ,
203+ run_kwargs = None ,
204+ ** kwargs ,
205+ ):
206+ check_consistent_length (X , y , session = session , run_kwargs = run_kwargs )
207+ params = self .get_params (True )
208+ model = train (
209+ params ,
210+ self ._wrap_train_tuple (X , y , sample_weight , init_score ),
211+ eval_sets = self ._wrap_eval_tuples (
212+ eval_set , eval_sample_weight , eval_init_score
213+ ),
214+ model_type = LGBMModelType .CLASSIFIER ,
215+ session = session ,
216+ run_kwargs = run_kwargs ,
217+ train_cls = MockLGMBTrain ,
218+ ** kwargs ,
219+ )
220+
221+ self .set_params (** model .get_params ())
222+ self ._copy_extra_params (model , self )
223+ return self
224+
225+ y_data = (y * 10 ).astype (mt .int32 )
226+ classifier = MockLGBMClassifier (n_estimators = 2 )
227+ classifier .fit (X , y_data , eval_set = [(X , y_data )], verbose = True )
228+ prediction = classifier .predict (X )
229+
230+ assert prediction .ndim == 1
231+ assert prediction .shape [0 ] == len (X )
232+
233+ assert isinstance (prediction , mt .Tensor )
0 commit comments