Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 65 additions & 15 deletions h2o-algos/src/main/java/hex/api/MakeGLMModelHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,38 +56,88 @@ public GLMModelV3 make_model(int version, MakeGLMModelV3 args){

public GLMModelV3 make_unrestricted_model(int version, MakeUnrestrictedGLMModelV3 args){
GLMModel model = DKV.getGet(args.model.key());
if(model == null)
if (model == null)
throw new IllegalArgumentException("Missing source model " + args.model);
if(model._parms._control_variables == null){
throw new IllegalArgumentException("Source model is not trained with control variables.");
if (model._parms._control_variables == null && !model._parms._remove_offset_effects) {
throw new IllegalArgumentException("Source model is not trained with control variables or remove offset effects.");
}
Key generatedKey = Key.make(model._key.toString()+"_unrestricted_model");
Key generatedKey;
if (args.control_variables_enabled && args.remove_offset_effects_enabled) {
throw new IllegalArgumentException("The control_variables_enabled and remove_offset_effects_enabled feature " +
"cannot be used together. It produces the same model as the main model.");
} else if((args.control_variables_enabled || args.remove_offset_effects_enabled) &&
(model._parms._control_variables == null || !model._parms._remove_offset_effects)) {
throw new IllegalArgumentException("You can set control_variables_enabled to true or " +
"remove_offset_effects_enabled to true only if control_variables and remove_offset_effects are both set.");
} else if (args.remove_offset_effects_enabled) {
generatedKey = Key.make(model._key.toString() + "_remove_offset_effects_enabled");
} else if (args.control_variables_enabled) {
generatedKey = Key.make(model._key.toString() + "_control_variables_enabled");
} else {
generatedKey = Key.make(model._key.toString()+"_unrestricted_model");
}
Key key = args.dest != null ? Key.make(args.dest) : generatedKey;
GLMModel modelContrVars = DKV.getGet(key);
if(modelContrVars != null) {
GLMModel modelUnrestricted = DKV.getGet(key);
if (modelUnrestricted != null) {
throw new IllegalArgumentException("Model with "+key+" already exists.");
}
GLMModel.GLMParameters parms = (GLMModel.GLMParameters) model._parms.clone();
GLMModel.GLMParameters inputParms = (GLMModel.GLMParameters) model._input_parms.clone();
GLMModel m = new GLMModel(key, parms,null, model._ymu,
Double.NaN, Double.NaN, -1);
m.setInputParms(inputParms);
m._input_parms._control_variables = null;
m._parms._control_variables = null;
if (args.control_variables_enabled){
m._input_parms._control_variables = model._parms._control_variables;
m._parms._control_variables = model._parms._control_variables;
m._input_parms._remove_offset_effects = false;
m._parms._remove_offset_effects = false;
} else if(args.remove_offset_effects_enabled){
m._input_parms._remove_offset_effects = true;
m._parms._remove_offset_effects = true;
m._input_parms._control_variables = null;
m._parms._control_variables = null;
} else {
m._input_parms._control_variables = null;
m._parms._control_variables = null;
m._input_parms._remove_offset_effects = false;
m._parms._remove_offset_effects = false;
}
DataInfo dinfo = model.dinfo();
dinfo.setPredictorTransform(TransformType.NONE);
m._output = new GLMOutput(model.dinfo(), model._output._names, model._output._column_types, model._output._domains,
model._output.coefficientNames(), model._output.beta(), model._output._binomial, model._output._multinomial,
model._output._ordinal, null);
ModelMetrics mt = model._output._training_metrics_unrestricted_model;
ModelMetrics mv = model._output._validation_metrics_unrestricted_model;
m._output._training_metrics = mt;
m._output._validation_metrics = mv;
m._output._scoring_history = model._output._scoring_history_unrestricted_model;
if (args.control_variables_enabled) {
ModelMetrics mt = model._output._training_metrics_restricted_model_cv;
ModelMetrics mv = model._output._validation_metrics_restricted_model_cv;
m._output._training_metrics = mt;
m._output._validation_metrics = mv;
m._output._scoring_history = model._output._scoring_history_restricted_model_cv;
m.resetThreshold(model.defaultThreshold());
m._output._variable_importances = model._output._variable_importances;
m._output.setAndMapControlVariablesNames(model._parms._control_variables);
} else if (args.remove_offset_effects_enabled) {
ModelMetrics mt = model._output._training_metrics_restricted_model_ro;
ModelMetrics mv = model._output._validation_metrics_restricted_model_ro;
m._output._training_metrics = mt;
m._output._validation_metrics = mv;
m._output._scoring_history = model._output._scoring_history_restricted_model_ro;
m.resetThreshold(model.defaultThreshold());
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
} else {
ModelMetrics mt = model._output._training_metrics_unrestricted_model;
ModelMetrics mv = model._output._validation_metrics_unrestricted_model;
m._output._training_metrics = mt;
m._output._validation_metrics = mv;
m._output._scoring_history = model._output._scoring_history_unrestricted_model;
m.resetThreshold(model.defaultThreshold());
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
}
m._output._model_summary = model._output._model_summary;
m.resetThreshold(model.defaultThreshold());
m._output._variable_importances = model._output._variable_importances_unrestricted_model;
m._key = key;
// setting these flags is important for right scoring
m._useControlVariables = args.control_variables_enabled;
m._useRemoveOffsetEffects = args.remove_offset_effects_enabled;

DKV.put(key, m);
GLMModelV3 res = new GLMModelV3();
Expand Down
Loading
Loading