@@ -525,7 +525,7 @@ def annotate_with_mlflow_model(
525525 Default batch size is 1, but it is still being sent as a list for consistency.
526526 log_to_field: Field to store the resulting annotations in.
527527 """
528- res = self .predict_with_mlflow_model (
528+ res = self ._infer_with_mlflow_on_files (
529529 repo ,
530530 name ,
531531 host = host ,
@@ -534,6 +534,7 @@ def annotate_with_mlflow_model(
534534 post_hook = post_hook ,
535535 batch_size = batch_size ,
536536 log_to_field = log_to_field ,
537+ is_prediction = False ,
537538 )
538539 self .datasource .metadata_field (log_to_field ).set_annotation ().apply ()
539540 return res
@@ -578,6 +579,62 @@ def predict_with_mlflow_model(
578579 Default batch size is 1, but it is still being sent as a list for consistency.
579580 log_to_field: If set, writes prediction results to this metadata field in the datasource.
580581 """
582+ res = self ._infer_with_mlflow_on_files (
583+ repo ,
584+ name ,
585+ host = host ,
586+ version = version ,
587+ pre_hook = pre_hook ,
588+ post_hook = post_hook ,
589+ batch_size = batch_size ,
590+ log_to_field = log_to_field ,
591+ is_prediction = True ,
592+ )
593+ return res
594+
595+ def _infer_with_mlflow_on_files (
596+ self ,
597+ repo : str ,
598+ name : str ,
599+ host : Optional [str ] = None ,
600+ version : str = "latest" ,
601+ pre_hook : Callable [[List [str ]], Any ] = identity_func ,
602+ post_hook : Callable [[Any ], Any ] = identity_func ,
603+ batch_size : int = 1 ,
604+ log_to_field : Optional [str ] = None ,
605+ is_prediction : bool = True ,
606+ ) -> Dict [str , Any ]:
607+ """
608+ Fetch an MLflow model from a specific repository and use it to run inference (predictions or annotations) on the datapoints in this QueryResult.
609+
610+ Any MLflow model that has a ``model.predict`` endpoint is supported.
611+ This includes, but is not limited to the following flavors:
612+
613+ * ``torch``
614+ * ``tensorflow``
615+ * ``pyfunc``
616+ * ``scikit-learn``
617+
618+ Keep in mind that by default ``mlflow.predict()`` will receive the list of downloaded datapoint paths as input,
619+ so additional "massaging" of the data might be required for prediction to work.
620+ Use the ``pre_hook`` function to do so.
621+
622+ Args:
623+ repo: repository to extract the model from
624+
625+ name: name of the model in the repository's MLflow registry.
626+ host: address of the DagsHub instance with the repo to load the model from.
627+ Set it if the model is hosted on a different DagsHub instance than the datasource.
628+ version: version of the model in the mlflow registry.
629+ pre_hook: function that runs before datapoints are sent to ``model.predict()``.
630+ The input argument is the list of paths to datapoint files in the current batch.
631+ post_hook: function that converts the model output to the desired format.
632+ batch_size: Size of the file batches that are sent to ``model.predict()``.
633+ Default batch size is 1, but it is still being sent as a list for consistency.
634+ log_to_field: If set, writes prediction results to this metadata field in the datasource.
635+ is_prediction: If True, log as a prediction (will need to be manually approved as annotation).
636+ If False, will be automatically approved as an annotation
637+ """
581638 if not host :
582639 host = self .datasource .source .repoApi .host
583640
@@ -602,7 +659,9 @@ def predict_with_mlflow_model(
602659 if "torch" in loader_module :
603660 model .predict = model .__call__
604661
605- return self .generate_predictions (lambda x : post_hook (model .predict (pre_hook (x ))), batch_size , log_to_field )
662+ return self .generate_predictions (
663+ lambda x : post_hook (model .predict (pre_hook (x ))), batch_size , log_to_field , is_prediction
664+ )
606665
607666 def get_annotations (self , ** kwargs ) -> "QueryResult" :
608667 """
@@ -860,8 +919,9 @@ def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
860919 return ds
861920
862921 @staticmethod
863- def _get_predict_dict (predictions , remote_path , log_to_field ):
864- res = {log_to_field : json .dumps (predictions [remote_path ][0 ]).encode ("utf-8" )}
922+ def _get_predict_dict (predictions , remote_path , log_to_field , is_prediction = False ):
923+ ls_json_key = "annotations" if not is_prediction else "predictions"
924+ res = {log_to_field : json .dumps ({ls_json_key : [predictions [remote_path ][0 ]]}).encode ("utf-8" )}
865925 if len (predictions [remote_path ]) == 2 :
866926 res [f"{ log_to_field } _score" ] = predictions [remote_path ][1 ]
867927
@@ -934,6 +994,7 @@ def generate_predictions(
934994 predict_fn : CustomPredictor ,
935995 batch_size : int = 1 ,
936996 log_to_field : Optional [str ] = None ,
997+ is_prediction : Optional [bool ] = False ,
937998 ) -> Dict [str , Tuple [str , Optional [float ]]]:
938999 """
9391000 Sends all the datapoints returned in this QueryResult as prediction targets for
@@ -943,6 +1004,7 @@ def generate_predictions(
9431004 predict_fn: function that handles batched input and returns predictions with an optional prediction score.
9441005 batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
9451006 log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
9461007 If None, just returns predictions.
9471008 (in addition to logging to a field, iff that parameter is set)
1009+ is_prediction: (optional, default: False) if True, results are logged as predictions; if False, as annotations.
9481010 """
@@ -955,8 +1017,12 @@ def generate_predictions(
9551017 for idx , local_paths in enumerate (
9561018 _Batcher (dset , batch_size ) if batch_size != 1 else dset
9571019 ): # encapsulates dataset with batcher if necessary and iterates over it
1020+ prediction_result = predict_fn (local_paths )
1021+ # for single batches the result is for one datapoint, so it needs to be wrapped in a list
1022+ if batch_size == 1 :
1023+ prediction_result = [prediction_result ]
9581024 for prediction , remote_path in zip (
959- predict_fn ( local_paths ) ,
1025+ prediction_result ,
9601026 [result .path for result in self [idx * batch_size : (idx + 1 ) * batch_size ]],
9611027 ):
9621028 predictions [remote_path ] = prediction
@@ -965,7 +1031,9 @@ def generate_predictions(
9651031 if log_to_field :
9661032 with self .datasource .metadata_context () as ctx :
9671033 for remote_path in predictions :
968- ctx .update_metadata (remote_path , self ._get_predict_dict (predictions , remote_path , log_to_field ))
1034+ ctx .update_metadata (
1035+ remote_path , self ._get_predict_dict (predictions , remote_path , log_to_field , is_prediction )
1036+ )
9691037 return predictions
9701038
9711039 def generate_annotations (self , predict_fn : CustomPredictor , batch_size : int = 1 , log_to_field : str = "annotation" ):
@@ -978,11 +1046,7 @@ def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1,
9781046 batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
9791047 log_to_field: (optional, default: 'annotation') write annotation results to metadata logged in data engine.
9801048 """
981- self .generate_predictions (
982- predict_fn ,
983- batch_size = batch_size ,
984- log_to_field = log_to_field ,
985- )
1049+ self .generate_predictions (predict_fn , batch_size = batch_size , log_to_field = log_to_field , is_prediction = False )
9861050 self .datasource .metadata_field (log_to_field ).set_annotation ().apply ()
9871051
9881052 def annotate (
0 commit comments