Skip to content

Commit 7ffb48a

Browse files
authored
Merge pull request #119 from slaclab/fix-gp-input
Adjust inputs order for prob models/gps, add mlflow to main requs
2 parents 6bef2a8 + bc6a6f8 commit 7ffb48a

File tree

7 files changed

+22
-3
lines changed

7 files changed

+22
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ LUME-model holds data structures used in the LUME modeling toolset. Variables an
88
* pydantic
99
* numpy
1010
* pyyaml
11+
* mlflow
1112

1213
## Install
1314

dev-environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies:
88
- numpy
99
- pyyaml
1010
- botorch
11+
- mlflow
1112

1213
# dev requirements
1314
- pytest

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ dependencies:
77
- pydantic>2.3
88
- numpy
99
- pyyaml
10+
- mlflow

lume_model/models/gp_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def _get_predictions(
142142
Returns:
143143
Dictionary of output variable names to distributions.
144144
"""
145+
# Reorder the input dictionary to match the model's input order
146+
input_dict = super()._arrange_inputs(input_dict)
145147
# Create tensor from input_dict
146148
x = super()._create_tensor_from_dict(input_dict)
147149
# Transform the input

lume_model/models/prob_model_base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,19 @@ def dtype(
5858
f"expected one of ['double', 'single']."
5959
)
6060

61+
def _arrange_inputs(
62+
self, d: dict[str, Union[float, torch.Tensor]]
63+
) -> dict[str, Union[float, torch.Tensor]]:
64+
"""Enforces order of input variables before creating a tensor.
65+
66+
Args:
67+
d: Dictionary of input variable names to tensors.
68+
69+
Returns:
70+
Ordered input tensor.
71+
"""
72+
return {k: d[k] for k in self.input_names}
73+
6174
@staticmethod
6275
def _create_tensor_from_dict(
6376
d: dict[str, Union[float, torch.Tensor]],

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ requires-python = ">=3.10"
2323
dependencies = [
2424
"pydantic",
2525
"numpy",
26-
"pyyaml"
26+
"pyyaml",
27+
"mlflow"
2728
]
2829
dynamic = ["version"]
2930
[tool.setuptools_scm]
@@ -34,8 +35,7 @@ dev = [
3435
"botorch",
3536
"torch",
3637
"pre-commit",
37-
"pytest",
38-
"mlflow"
38+
"pytest"
3939
]
4040
docs = [
4141
"mkdocs",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pydantic
22
numpy
33
pyyaml
4+
mlflow

0 commit comments

Comments
 (0)