@@ -141,52 +141,23 @@ def num_trees
141141 out . read_int
142142 end
143143
144- def predict ( data , start_iteration : 0 , num_iteration : -1 , raw_score : false , pred_leaf : false , pred_contrib : false , **kwargs )
144+ def predict ( data , start_iteration : 0 , num_iteration : nil , raw_score : false , pred_leaf : false , pred_contrib : false , **kwargs )
145+ predictor = InnerPredictor . from_booster ( self , kwargs . transform_values ( &:dup ) )
145146 if num_iteration . nil?
146147 if start_iteration <= 0
147148 num_iteration = best_iteration
148149 else
149150 num_iteration = -1
150151 end
151152 end
152-
153- if data . is_a? ( Dataset )
154- raise TypeError , "Cannot use Dataset instance for prediction, please use raw data instead"
155- end
156-
157- predict_type = FFI ::C_API_PREDICT_NORMAL
158- if raw_score
159- predict_type = FFI ::C_API_PREDICT_RAW_SCORE
160- end
161- if pred_leaf
162- predict_type = FFI ::C_API_PREDICT_LEAF_INDEX
163- end
164- if pred_contrib
165- predict_type = FFI ::C_API_PREDICT_CONTRIB
166- end
167-
168- preds , nrow , singular =
169- preds_for_data (
170- data ,
171- start_iteration ,
172- num_iteration ,
173- predict_type ,
174- **kwargs
175- )
176-
177- if pred_leaf
178- preds = preds . map ( &:to_i )
179- end
180-
181- if preds . size != nrow
182- if preds . size % nrow == 0
183- preds = preds . each_slice ( preds . size / nrow ) . to_a
184- else
185- raise Error , "Length of predict result (#{ preds . size } ) cannot be divide nrow (#{ nrow } )"
186- end
187- end
188-
189- singular ? preds . first : preds
153+ predictor . predict (
154+ data ,
155+ start_iteration : start_iteration ,
156+ num_iteration : num_iteration ,
157+ raw_score : raw_score ,
158+ pred_leaf : pred_leaf ,
159+ pred_contrib : pred_contrib
160+ )
190161 end
191162
192163 def save_model ( filename , num_iteration : nil , start_iteration : 0 )
@@ -261,61 +232,6 @@ def num_class
261232 out . read_int
262233 end
263234
264- def preds_for_data ( input , start_iteration , num_iteration , predict_type , **params )
265- input =
266- if daru? ( input )
267- input [ *cached_feature_name ] . map_rows ( &:to_a )
268- elsif input . is_a? ( Hash ) # sort feature.values to match the order of model.feature_name
269- sorted_feature_values ( input )
270- elsif input . is_a? ( Array ) && input . first . is_a? ( Hash ) # on multiple elems, if 1st is hash, assume they all are
271- input . map ( &method ( :sorted_feature_values ) )
272- elsif rover? ( input )
273- # TODO improve performance
274- input [ cached_feature_name ] . to_numo . to_a
275- else
276- input . to_a
277- end
278-
279- singular = !input . first . is_a? ( Array )
280- input = [ input ] if singular
281-
282- nrow = input . count
283- n_preds =
284- num_preds (
285- start_iteration ,
286- num_iteration ,
287- nrow ,
288- predict_type
289- )
290-
291- flat_input = input . flatten
292- handle_missing ( flat_input )
293- data = ::FFI ::MemoryPointer . new ( :double , input . count * input . first . count )
294- data . write_array_of_double ( flat_input )
295-
296- out_len = ::FFI ::MemoryPointer . new ( :int64 )
297- out_result = ::FFI ::MemoryPointer . new ( :double , n_preds )
298- check_result FFI . LGBM_BoosterPredictForMat ( handle_pointer , data , 1 , input . count , input . first . count , 1 , predict_type , start_iteration , num_iteration , params_str ( params ) , out_len , out_result )
299-
300- if n_preds != out_len . read_int64
301- raise Error , "Wrong length for predict results"
302- end
303-
304- preds = out_result . read_array_of_double ( out_len . read_int64 )
305-
306- [ preds , nrow , singular ]
307- end
308-
309- def num_preds ( start_iteration , num_iteration , nrow , predict_type )
310- out = ::FFI ::MemoryPointer . new ( :int64 )
311- check_result FFI . LGBM_BoosterCalcNumPredict ( handle_pointer , nrow , predict_type , start_iteration , num_iteration , out )
312- out . read_int64
313- end
314-
315- def sorted_feature_values ( input_hash )
316- input_hash . transform_keys ( &:to_s ) . fetch_values ( *cached_feature_name )
317- end
318-
319235 def cached_feature_name
320236 @cached_feature_name ||= feature_name
321237 end
0 commit comments