Skip to content

Commit 7e93bd1

Browse files
authored
Merge pull request #600 from DagsHub/fix-local-auto-labeling
Fix local auto labeling
2 parents 7b3c1a4 + 3cdcdca commit 7e93bd1

File tree

1 file changed

+75
-11
lines changed

1 file changed

+75
-11
lines changed

dagshub/data_engine/model/query_result.py

Lines changed: 75 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -525,7 +525,7 @@ def annotate_with_mlflow_model(
525525
Default batch size is 1, but it is still being sent as a list for consistency.
526526
log_to_field: Field to store the resulting annotations in.
527527
"""
528-
res = self.predict_with_mlflow_model(
528+
res = self._infer_with_mlflow_on_files(
529529
repo,
530530
name,
531531
host=host,
@@ -534,6 +534,7 @@ def annotate_with_mlflow_model(
534534
post_hook=post_hook,
535535
batch_size=batch_size,
536536
log_to_field=log_to_field,
537+
is_prediction=False,
537538
)
538539
self.datasource.metadata_field(log_to_field).set_annotation().apply()
539540
return res
@@ -578,6 +579,62 @@ def predict_with_mlflow_model(
578579
Default batch size is 1, but it is still being sent as a list for consistency.
579580
log_to_field: If set, writes prediction results to this metadata field in the datasource.
580581
"""
582+
res = self._infer_with_mlflow_on_files(
583+
repo,
584+
name,
585+
host=host,
586+
version=version,
587+
pre_hook=pre_hook,
588+
post_hook=post_hook,
589+
batch_size=batch_size,
590+
log_to_field=log_to_field,
591+
is_prediction=True,
592+
)
593+
return res
594+
595+
def _infer_with_mlflow_on_files(
596+
self,
597+
repo: str,
598+
name: str,
599+
host: Optional[str] = None,
600+
version: str = "latest",
601+
pre_hook: Callable[[List[str]], Any] = identity_func,
602+
post_hook: Callable[[Any], Any] = identity_func,
603+
batch_size: int = 1,
604+
log_to_field: Optional[str] = None,
605+
is_prediction: bool = True,
606+
) -> Dict[str, Any]:
607+
"""
608+
Fetch an MLflow model from a specific repository and use it to predict on the datapoints in this QueryResult.
609+
610+
Any MLflow model that has a ``model.predict`` endpoint is supported.
611+
This includes, but is not limited to the following flavors:
612+
613+
* ``torch``
614+
* ``tensorflow``
615+
* ``pyfunc``
616+
* ``scikit-learn``
617+
618+
Keep in mind that by default ``mlflow.predict()`` will receive the list of downloaded datapoint paths as input,
619+
so additional "massaging" of the data might be required for prediction to work.
620+
Use the ``pre_hook`` function to do so.
621+
622+
Args:
623+
repo: repository to extract the model from
624+
625+
name: name of the model in the repository's MLflow registry.
626+
host: address of the DagsHub instance with the repo to load the model from.
627+
Set it if the model is hosted on a different DagsHub instance than the datasource.
628+
version: version of the model in the mlflow registry.
629+
pre_hook: function that runs before datapoints are sent to ``model.predict()``.
630+
The input argument is the list of paths to datapoint files in the current batch.
631+
post_hook: function that converts the model output to the desired format.
632+
batch_size: Size of the file batches that are sent to ``model.predict()``.
633+
Default batch size is 1, but it is still being sent as a list for consistency.
634+
log_to_field: If set, writes prediction results to this metadata field in the datasource.
635+
is_prediction: If True, log as a prediction (will need to be manually approved as annotation).
636+
If False, will be automatically approved as an annotation
637+
"""
581638
if not host:
582639
host = self.datasource.source.repoApi.host
583640

@@ -602,7 +659,9 @@ def predict_with_mlflow_model(
602659
if "torch" in loader_module:
603660
model.predict = model.__call__
604661

605-
return self.generate_predictions(lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field)
662+
return self.generate_predictions(
663+
lambda x: post_hook(model.predict(pre_hook(x))), batch_size, log_to_field, is_prediction
664+
)
606665

607666
def get_annotations(self, **kwargs) -> "QueryResult":
608667
"""
@@ -860,8 +919,9 @@ def to_voxel51_dataset(self, **kwargs) -> "fo.Dataset":
860919
return ds
861920

862921
@staticmethod
863-
def _get_predict_dict(predictions, remote_path, log_to_field):
864-
res = {log_to_field: json.dumps(predictions[remote_path][0]).encode("utf-8")}
922+
def _get_predict_dict(predictions, remote_path, log_to_field, is_prediction=False):
923+
ls_json_key = "annotations" if not is_prediction else "predictions"
924+
res = {log_to_field: json.dumps({ls_json_key: [predictions[remote_path][0]]}).encode("utf-8")}
865925
if len(predictions[remote_path]) == 2:
866926
res[f"{log_to_field}_score"] = predictions[remote_path][1]
867927

@@ -934,6 +994,7 @@ def generate_predictions(
934994
predict_fn: CustomPredictor,
935995
batch_size: int = 1,
936996
log_to_field: Optional[str] = None,
997+
is_prediction: Optional[bool] = False,
937998
) -> Dict[str, Tuple[str, Optional[float]]]:
938999
"""
9391000
Sends all the datapoints returned in this QueryResult as prediction targets for
@@ -943,6 +1004,7 @@ def generate_predictions(
9431004
predict_fn: function that handles batched input and returns predictions with an optional prediction score.
9441005
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously
9451006
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
1007+
is_prediction: (optional, default: False) whether we're creating predictions or annotations.
9461008
If None, just returns predictions.
9471009
(in addition to logging to a field, iff that parameter is set)
9481010
"""
@@ -955,8 +1017,12 @@ def generate_predictions(
9551017
for idx, local_paths in enumerate(
9561018
_Batcher(dset, batch_size) if batch_size != 1 else dset
9571019
): # encapsulates dataset with batcher if necessary and iterates over it
1020+
prediction_result = predict_fn(local_paths)
1021+
# for single batches the result is for one datapoint, so it needs to be wrapped in a list
1022+
if batch_size == 1:
1023+
prediction_result = [prediction_result]
9581024
for prediction, remote_path in zip(
959-
predict_fn(local_paths),
1025+
prediction_result,
9601026
[result.path for result in self[idx * batch_size : (idx + 1) * batch_size]],
9611027
):
9621028
predictions[remote_path] = prediction
@@ -965,7 +1031,9 @@ def generate_predictions(
9651031
if log_to_field:
9661032
with self.datasource.metadata_context() as ctx:
9671033
for remote_path in predictions:
968-
ctx.update_metadata(remote_path, self._get_predict_dict(predictions, remote_path, log_to_field))
1034+
ctx.update_metadata(
1035+
remote_path, self._get_predict_dict(predictions, remote_path, log_to_field, is_prediction)
1036+
)
9691037
return predictions
9701038

9711039
def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1, log_to_field: str = "annotation"):
@@ -978,11 +1046,7 @@ def generate_annotations(self, predict_fn: CustomPredictor, batch_size: int = 1,
9781046
batch_size: (optional, default: 1) number of datapoints to run inference on simultaneously.
9791047
log_to_field: (optional, default: 'prediction') write prediction results to metadata logged in data engine.
9801048
"""
981-
self.generate_predictions(
982-
predict_fn,
983-
batch_size=batch_size,
984-
log_to_field=log_to_field,
985-
)
1049+
self.generate_predictions(predict_fn, batch_size=batch_size, log_to_field=log_to_field, is_prediction=False)
9861050
self.datasource.metadata_field(log_to_field).set_annotation().apply()
9871051

9881052
def annotate(

0 commit comments

Comments (0)