Skip to content

Commit c77dfed

Browse files
authored
Merge pull request #127 from slaclab/mlflow-floats
Adjust logging to support non-torch custom models
2 parents: c9bd815 + 464dbb9; commit: c77dfed

File tree

4 files changed: 34 additions and 48 deletions.

lume_model/base.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import yaml
1010
import numpy as np
1111
from pydantic import BaseModel, ConfigDict, field_validator
12-
import torch # TODO: for torch.Tensor type hinting, but may need to make more general in mlflow class
1312

1413
from lume_model.variables import ScalarVariable, get_variable, ConfigEnum
1514
from lume_model.utils import (
@@ -456,7 +455,6 @@ def from_yaml(cls, yaml_obj: [str, TextIOWrapper]):
456455

457456
def register_to_mlflow(
458457
self,
459-
input_dict: dict[str, Union[float, torch.Tensor]],
460458
artifact_path: str,
461459
registered_model_name: str | None = None,
462460
tags: dict[str, Any] | None = None,
@@ -477,7 +475,6 @@ def register_to_mlflow(
477475
https://mlflow.org/docs/latest/getting-started/intro-quickstart/ for more info.
478476
479477
Args:
480-
input_dict: Input dictionary to infer the model signature.
481478
artifact_path: Path to store the model in MLflow.
482479
registered_model_name: Name of the registered model in MLflow. Optional.
483480
tags: Tags to add to the MLflow model. Optional.
@@ -493,7 +490,6 @@ def register_to_mlflow(
493490
"""
494491
return register_model(
495492
self,
496-
input_dict,
497493
artifact_path,
498494
registered_model_name,
499495
tags,

lume_model/mlflow_utils.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
22
import warnings
3-
from typing import Any, Union
3+
from typing import Any
44
from contextlib import nullcontext
55

6-
from torch import Tensor, nn
6+
from torch import nn
77

88
try:
99
import mlflow
@@ -20,7 +20,6 @@
2020

2121
def register_model(
2222
lume_model,
23-
input: dict[str, Union[float, Tensor]] | Tensor,
2423
artifact_path: str,
2524
registered_model_name: str | None = None,
2625
tags: dict[str, Any] | None = None,
@@ -40,9 +39,10 @@ def register_model(
4039
a tracking server, set the environment variable MLFLOW_TRACKING_URI, e.g. a local port/path. See
4140
https://mlflow.org/docs/latest/getting-started/intro-quickstart/ for more info.
4241
42+
Note that at the moment, this does not log artifacts for custom models other than the YAML dump file.
43+
4344
Args:
4445
lume_model: LumeModel to register.
45-
input: Input dictionary to infer the model signature.
4646
artifact_path: Path to store the model in MLflow.
4747
registered_model_name: Name of the registered model in MLflow.
4848
tags: Tags to add to the MLflow model.
@@ -70,30 +70,19 @@ def register_model(
7070
else nullcontext()
7171
)
7272
with ctx:
73-
# Define the signature of the model
7473
if isinstance(lume_model, nn.Module):
75-
signature = mlflow.models.infer_signature(
76-
input.numpy(), lume_model(Tensor(input)).detach().numpy()
77-
)
7874
model_info = mlflow.pytorch.log_model(
7975
pytorch_model=lume_model,
8076
artifact_path=artifact_path,
81-
signature=signature,
8277
registered_model_name=registered_model_name,
8378
**kwargs,
8479
)
8580
else:
8681
# Create pyfunc model for MLflow to be able to log/load the model
8782
pf_model = create_mlflow_model(lume_model)
88-
# Adjust the input to match the expected input format
89-
# Must be one of `numpy.ndarray`, `List[numpy.ndarray]`, `Dict[str, numpy.ndarray]` or `pandas.DataFrame`
90-
input = {key: value.numpy() for key, value in input.items()}
91-
signature = mlflow.models.infer_signature(input, pf_model.predict(input))
9283
model_info = mlflow.pyfunc.log_model(
9384
python_model=pf_model,
9485
artifact_path=artifact_path,
95-
signature=signature,
96-
input_example=input,
9786
registered_model_name=registered_model_name,
9887
**kwargs,
9988
)
@@ -106,23 +95,35 @@ def register_model(
10695

10796
lume_model.dump(f"{name}.yml", save_jit=save_jit)
10897
mlflow.log_artifact(f"{name}.yml", artifact_path)
109-
mlflow.log_artifact(f"{name}_model.pt", artifact_path)
11098
os.remove(f"{name}.yml")
111-
os.remove(f"{name}_model.pt")
112-
if save_jit:
113-
mlflow.log_artifact(f"{name}_model.jit", artifact_path)
114-
os.remove(f"{name}_model.jit")
115-
116-
# Get and log the input and output transformers
117-
lume_model = (
118-
lume_model._model if isinstance(lume_model, nn.Module) else lume_model
119-
)
120-
for i in range(len(lume_model.input_transformers)):
121-
mlflow.log_artifact(f"{name}_input_transformers_{i}.pt", artifact_path)
122-
os.remove(f"{name}_input_transformers_{i}.pt")
123-
for i in range(len(lume_model.output_transformers)):
124-
mlflow.log_artifact(f"{name}_output_transformers_{i}.pt", artifact_path)
125-
os.remove(f"{name}_output_transformers_{i}.pt")
99+
100+
from lume_model.models import registered_models
101+
102+
if type(lume_model) in registered_models:
103+
# all registered models are torch models at the moment
104+
# may change in the future
105+
mlflow.log_artifact(f"{name}_model.pt", artifact_path)
106+
os.remove(f"{name}_model.pt")
107+
if save_jit:
108+
mlflow.log_artifact(f"{name}_model.jit", artifact_path)
109+
os.remove(f"{name}_model.jit")
110+
111+
# Get and log the input and output transformers
112+
lume_model = (
113+
lume_model._model
114+
if isinstance(lume_model, nn.Module)
115+
else lume_model
116+
)
117+
for i in range(len(lume_model.input_transformers)):
118+
mlflow.log_artifact(
119+
f"{name}_input_transformers_{i}.pt", artifact_path
120+
)
121+
os.remove(f"{name}_input_transformers_{i}.pt")
122+
for i in range(len(lume_model.output_transformers)):
123+
mlflow.log_artifact(
124+
f"{name}_output_transformers_{i}.pt", artifact_path
125+
)
126+
os.remove(f"{name}_output_transformers_{i}.pt")
126127

127128
if (tags or alias or version_tags) and registered_model_name:
128129
from mlflow import MlflowClient
@@ -167,18 +168,13 @@ class PyFuncModel(mlflow.pyfunc.PythonModel):
167168

168169
# Disable type hint validation for the predict method to avoid annoying warnings
169170
# since we have type validation in the lume-model itself.
170-
# If we need to implement this, this may be helpful:
171-
# g
172171
_skip_type_hint_validation = True
173172

174173
def __init__(self, model):
175174
self.model = model
176175

177176
def predict(self, model_input):
178177
"""Evaluate the model with the given input."""
179-
# Convert input to the format expected by the model
180-
# TODO: this isn't very general but type validation in torch modules requires this. May need to adjust.
181-
model_input = {key: Tensor(value) for key, value in model_input.items()}
182178
return self.model.evaluate(model_input)
183179

184180
def save_model(self):

lume_model/models/torch_module.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def _validate_input(x: torch.Tensor) -> torch.Tensor:
198198

199199
def register_to_mlflow(
200200
self,
201-
input: torch.Tensor,
202201
artifact_path: str,
203202
registered_model_name: str | None = None,
204203
tags: dict[str, Any] | None = None,
@@ -219,7 +218,6 @@ def register_to_mlflow(
219218
https://mlflow.org/docs/latest/getting-started/intro-quickstart/ for more info.
220219
221220
Args:
222-
input: Input tensor to infer the model signature.
223221
artifact_path: Path to store the model in MLflow.
224222
registered_model_name: Name of the registered model in MLflow. Optional.
225223
tags: Tags to add to the MLflow model. Optional.
@@ -235,7 +233,6 @@ def register_to_mlflow(
235233
"""
236234
return register_model(
237235
self,
238-
input,
239236
artifact_path,
240237
registered_model_name,
241238
tags,

pyproject.toml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,16 @@ dependencies = [
2424
"pydantic",
2525
"numpy",
2626
"pyyaml",
27+
"torch",
28+
"botorch",
2729
"mlflow"
2830
]
2931
dynamic = ["version"]
3032
[tool.setuptools_scm]
3133
version_file = "lume_model/_version.py"
3234

3335
[project.optional-dependencies]
34-
torch = [
35-
"botorch",
36-
"torch"
37-
]
3836
dev = [
39-
"lume-model[torch]",
4037
"pre-commit",
4138
"pytest"
4239
]

0 commit comments

Comments
 (0)