Skip to content

Commit 240025c

Browse files
committed
Added NGBClassifier and NGBRegressor to Encodable class hierarchy
1 parent 9e3bbf9 commit 240025c

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

pmml-sklearn-extension/src/main/java/ngboost/NGBClassifier.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.dmg.pmml.MiningFunction;
2828
import org.dmg.pmml.Model;
2929
import org.dmg.pmml.OpType;
30+
import org.dmg.pmml.PMML;
3031
import org.dmg.pmml.mining.MiningModel;
3132
import org.dmg.pmml.mining.Segmentation.MissingPredictionTreatment;
3233
import org.dmg.pmml.regression.RegressionModel;
@@ -40,16 +41,25 @@
4041
import org.jpmml.converter.Schema;
4142
import org.jpmml.converter.mining.MiningModelUtil;
4243
import org.jpmml.converter.regression.RegressionModelUtil;
44+
import org.jpmml.sklearn.Encodable;
45+
import org.jpmml.sklearn.HasSkLearnOptions;
4346
import org.jpmml.sklearn.SkLearnException;
4447
import sklearn.Classifier;
48+
import sklearn.EstimatorUtil;
49+
import sklearn.HasFeatureNamesIn;
4550
import sklearn.Regressor;
4651

47-
public class NGBClassifier extends Classifier {
52+
public class NGBClassifier extends Classifier implements HasFeatureNamesIn, HasSkLearnOptions, Encodable {
4853

4954
public NGBClassifier(String module, String name){
5055
super(module, name);
5156
}
5257

58+
@Override
59+
public int getNumberOfFeatures(){
60+
return getInteger("n_features");
61+
}
62+
5363
@Override
5464
public List<?> getClasses(){
5565
Integer k = getK();
@@ -139,6 +149,11 @@ public MiningModel encodeCategoricalModel(Schema schema){
139149
}
140150
}
141151

152+
@Override
153+
public PMML encodePMML(){
154+
return EstimatorUtil.encodePMML(this);
155+
}
156+
142157
public List<List<Regressor>> getBaseModels(){
143158
return NGBoostUtil.getBaseModels(this);
144159
}

pmml-sklearn-extension/src/main/java/ngboost/NGBRegressor.java

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.dmg.pmml.OpType;
3838
import org.dmg.pmml.Output;
3939
import org.dmg.pmml.OutputField;
40+
import org.dmg.pmml.PMML;
4041
import org.dmg.pmml.PMMLFunctions;
4142
import org.dmg.pmml.ParameterField;
4243
import org.dmg.pmml.ResultFeature;
@@ -52,16 +53,24 @@
5253
import org.jpmml.converter.mining.MiningModelUtil;
5354
import org.jpmml.python.ClassDictConstructorUtil;
5455
import org.jpmml.python.ClassDictUtil;
56+
import org.jpmml.sklearn.Encodable;
5557
import org.jpmml.sklearn.SkLearnException;
58+
import sklearn.EstimatorUtil;
59+
import sklearn.HasFeatureNamesIn;
5660
import sklearn.HasRegressorOptions;
5761
import sklearn.Regressor;
5862

59-
public class NGBRegressor extends Regressor implements HasRegressorOptions {
63+
public class NGBRegressor extends Regressor implements HasFeatureNamesIn, HasRegressorOptions, Encodable {
6064

6165
public NGBRegressor(String module, String name){
6266
super(module, name);
6367
}
6468

69+
@Override
70+
public int getNumberOfFeatures(){
71+
return getInteger("n_features");
72+
}
73+
6574
@Override
6675
public MiningModel encodeModel(Schema schema){
6776
String distName = getDistName();
@@ -198,6 +207,11 @@ public MiningModel encodePoissonModel(Schema schema){
198207
return MiningModelUtil.createModelChain(Arrays.asList(locModel, regressionModel), Segmentation.MissingPredictionTreatment.RETURN_MISSING);
199208
}
200209

210+
@Override
211+
public PMML encodePMML(){
212+
return EstimatorUtil.encodePMML(this);
213+
}
214+
201215
public List<List<Regressor>> getBaseModels(){
202216
return NGBoostUtil.getBaseModels(this);
203217
}

0 commit comments

Comments
 (0)