diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java index 1da90ae9..aeacccce 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParser.java @@ -25,8 +25,6 @@ public class XGBoostRawJsonParser implements LtrRankerParser { public static final String TYPE = "model/xgboost+json+raw"; - private static final Integer MISSING_NODE_ID = Integer.MAX_VALUE; - @Override public NaiveAdditiveDecisionTree parse(FeatureSet set, String model) { XGBoostRawJsonParser.XGBoostDefinition modelDefinition; @@ -439,8 +437,16 @@ private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) { } if (isSplit(nodeId)) { - return new NaiveAdditiveDecisionTree.Split(asLibTree(leftChildren.get(nodeId)), asLibTree(rightChildren.get(nodeId)), - splitIndices.get(nodeId), splitConditions.get(nodeId), splitIndices.get(nodeId), MISSING_NODE_ID); + Integer missingNodeId = + defaultLeft.get(nodeId) == 1 ? leftChildren.get(nodeId) : rightChildren.get(nodeId); + return new NaiveAdditiveDecisionTree.Split( + asLibTree(leftChildren.get(nodeId)), + asLibTree(rightChildren.get(nodeId)), + splitIndices.get(nodeId), + splitConditions.get(nodeId), + leftChildren.get(nodeId), + missingNodeId + ); } else { return new NaiveAdditiveDecisionTree.Leaf(baseWeights.get(nodeId)); }