11import os
22import warnings
3- from typing import Any , Union
3+ from typing import Any
44from contextlib import nullcontext
55
6- from torch import Tensor , nn
6+ from torch import nn
77
88try :
99 import mlflow
2020
2121def register_model (
2222 lume_model ,
23- input : dict [str , Union [float , Tensor ]] | Tensor ,
2423 artifact_path : str ,
2524 registered_model_name : str | None = None ,
2625 tags : dict [str , Any ] | None = None ,
@@ -40,9 +39,10 @@ def register_model(
4039 a tracking server, set the environment variable MLFLOW_TRACKING_URI, e.g. a local port/path. See
4140 https://mlflow.org/docs/latest/getting-started/intro-quickstart/ for more info.
4241
42+ Note that at the moment, this does not log artifacts for custom models other than the YAML dump file.
43+
4344 Args:
4445 lume_model: LumeModel to register.
45- input: Input dictionary to infer the model signature.
4646 artifact_path: Path to store the model in MLflow.
4747 registered_model_name: Name of the registered model in MLflow.
4848 tags: Tags to add to the MLflow model.
@@ -70,30 +70,19 @@ def register_model(
7070 else nullcontext ()
7171 )
7272 with ctx :
73- # Define the signature of the model
7473 if isinstance (lume_model , nn .Module ):
75- signature = mlflow .models .infer_signature (
76- input .numpy (), lume_model (Tensor (input )).detach ().numpy ()
77- )
7874 model_info = mlflow .pytorch .log_model (
7975 pytorch_model = lume_model ,
8076 artifact_path = artifact_path ,
81- signature = signature ,
8277 registered_model_name = registered_model_name ,
8378 ** kwargs ,
8479 )
8580 else :
8681 # Create pyfunc model for MLflow to be able to log/load the model
8782 pf_model = create_mlflow_model (lume_model )
88- # Adjust the input to match the expected input format
89- # Must be one of `numpy.ndarray`, `List[numpy.ndarray]`, `Dict[str, numpy.ndarray]` or `pandas.DataFrame`
90- input = {key : value .numpy () for key , value in input .items ()}
91- signature = mlflow .models .infer_signature (input , pf_model .predict (input ))
9283 model_info = mlflow .pyfunc .log_model (
9384 python_model = pf_model ,
9485 artifact_path = artifact_path ,
95- signature = signature ,
96- input_example = input ,
9786 registered_model_name = registered_model_name ,
9887 ** kwargs ,
9988 )
@@ -106,23 +95,35 @@ def register_model(
10695
10796 lume_model .dump (f"{ name } .yml" , save_jit = save_jit )
10897 mlflow .log_artifact (f"{ name } .yml" , artifact_path )
109- mlflow .log_artifact (f"{ name } _model.pt" , artifact_path )
11098 os .remove (f"{ name } .yml" )
111- os .remove (f"{ name } _model.pt" )
112- if save_jit :
113- mlflow .log_artifact (f"{ name } _model.jit" , artifact_path )
114- os .remove (f"{ name } _model.jit" )
115-
116- # Get and log the input and output transformers
117- lume_model = (
118- lume_model ._model if isinstance (lume_model , nn .Module ) else lume_model
119- )
120- for i in range (len (lume_model .input_transformers )):
121- mlflow .log_artifact (f"{ name } _input_transformers_{ i } .pt" , artifact_path )
122- os .remove (f"{ name } _input_transformers_{ i } .pt" )
123- for i in range (len (lume_model .output_transformers )):
124- mlflow .log_artifact (f"{ name } _output_transformers_{ i } .pt" , artifact_path )
125- os .remove (f"{ name } _output_transformers_{ i } .pt" )
99+
100+ from lume_model .models import registered_models
101+
102+ if type (lume_model ) in registered_models :
103+ # all registered models are torch models at the moment
104+ # may change in the future
105+ mlflow .log_artifact (f"{ name } _model.pt" , artifact_path )
106+ os .remove (f"{ name } _model.pt" )
107+ if save_jit :
108+ mlflow .log_artifact (f"{ name } _model.jit" , artifact_path )
109+ os .remove (f"{ name } _model.jit" )
110+
111+ # Get and log the input and output transformers
112+ lume_model = (
113+ lume_model ._model
114+ if isinstance (lume_model , nn .Module )
115+ else lume_model
116+ )
117+ for i in range (len (lume_model .input_transformers )):
118+ mlflow .log_artifact (
119+ f"{ name } _input_transformers_{ i } .pt" , artifact_path
120+ )
121+ os .remove (f"{ name } _input_transformers_{ i } .pt" )
122+ for i in range (len (lume_model .output_transformers )):
123+ mlflow .log_artifact (
124+ f"{ name } _output_transformers_{ i } .pt" , artifact_path
125+ )
126+ os .remove (f"{ name } _output_transformers_{ i } .pt" )
126127
127128 if (tags or alias or version_tags ) and registered_model_name :
128129 from mlflow import MlflowClient
@@ -167,18 +168,13 @@ class PyFuncModel(mlflow.pyfunc.PythonModel):
167168
168169 # Disable type hint validation for the predict method to avoid annoying warnings
169170 # since we have type validation in the lume-model itself.
170- # If we need to implement this, this may be helpful:
171- # g
172171 _skip_type_hint_validation = True
173172
174173 def __init__ (self , model ):
175174 self .model = model
176175
177176 def predict (self , model_input ):
178177 """Evaluate the model with the given input."""
179- # Convert input to the format expected by the model
180- # TODO: this isn't very general but type validation in torch modules requires this. May need to adjust.
181- model_input = {key : Tensor (value ) for key , value in model_input .items ()}
182178 return self .model .evaluate (model_input )
183179
184180 def save_model (self ):
0 commit comments