Skip to content

Commit d326a3d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 04ee1a8 commit d326a3d

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

ml/train_model.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def normalize(df, input_names, input_transform, output_names, output_transform):
129129
)
130130
return norm_df, norm_exp_inputs, norm_exp_outputs, norm_sim_inputs, norm_sim_outputs
131131

132+
132133
def split_data(df_exp, df_sim, variables, model_type):
133134
if model_type == "GP":
134135
if len(df_exp) > 0:
@@ -151,6 +152,7 @@ def split_data(df_exp, df_sim, variables, model_type):
151152
else:
152153
return (sim_train_df[variables], sim_val_df[variables])
153154

155+
154156
def build_transforms(n_inputs, X_data, n_outputs, y_data):
155157
input_transform = AffineInputTransform(
156158
len(input_names), coefficient=X_train.std(axis=0), offset=X_train.mean(axis=0)
@@ -161,6 +163,7 @@ def build_transforms(n_inputs, X_data, n_outputs, y_data):
161163
output_transform = AffineInputTransform(n_outputs, coefficient=y_std, offset=y_mean)
162164
return input_transform, output_transform
163165

166+
164167
def train_nn_ensemble(
165168
model_type,
166169
n_inputs,
@@ -257,10 +260,10 @@ def build_torch_model_from_nn(
257260
],
258261
)
259262

260-
def train_gp(
261-
norm_df_train, input_names, output_names,
262-
input_transform, output_transform, device):
263263

264+
def train_gp(
265+
norm_df_train, input_names, output_names, input_transform, output_transform, device
266+
):
264267
gp_models = []
265268

266269
for i, output_name in enumerate(output_names):
@@ -345,19 +348,24 @@ def train_gp(
345348
]
346349

347350
return GPModel(
348-
model=model.cpu(),
349-
input_variables=[ScalarVariable(**input_variables[k]) for k in input_variables.keys()],
350-
output_variables=output_variables,
351-
input_transform=[input_transform],
352-
output_transform=[output_transform])
351+
model=model.cpu(),
352+
input_variables=[
353+
ScalarVariable(**input_variables[k]) for k in input_variables.keys()
354+
],
355+
output_variables=output_variables,
356+
input_transform=[input_transform],
357+
output_transform=[output_transform],
358+
)
353359

354360

355361
def write_model(model, model_type, experiment, db):
356362
with tempfile.TemporaryDirectory() as temp_dir:
357363
if model_type != "GP":
358364
model.dump(file=os.path.join(temp_dir, experiment + ".yml"), save_jit=True)
359365
else:
360-
model.dump(file=os.path.join(temp_dir, experiment + ".yml"), save_models=True)
366+
model.dump(
367+
file=os.path.join(temp_dir, experiment + ".yml"), save_models=True
368+
)
361369
# Upload the model to the database
362370
# - Load the files that were just created into a dictionary
363371
with open(os.path.join(temp_dir, experiment + ".yml")) as f:
@@ -412,6 +420,7 @@ def write_model(model, model_type, experiment, db):
412420
db["models"].insert_one(document)
413421
print("Model uploaded to database")
414422

423+
415424
experiment, model_type = parse_arguments()
416425
config_dict = load_config(experiment)
417426
db = connect_to_db(config_dict)
@@ -513,7 +522,13 @@ def write_model(model, model_type, experiment, db):
513522
else:
514523
# Create separate GP models for each output to handle NaN values
515524

516-
model = train_gp(norm_df_train, input_names, output_names,
517-
input_transform, output_transform, device)
525+
model = train_gp(
526+
norm_df_train,
527+
input_names,
528+
output_names,
529+
input_transform,
530+
output_transform,
531+
device,
532+
)
518533

519534
write_model(model, model_type, experiment, db)

0 commit comments

Comments
 (0)