@@ -125,7 +125,7 @@ def get_experiment(self: Project, experiment_name: str, model) -> Dataset:
125125
126126# %% ../../nbs/project/experiments.ipynb 17
127127def find_git_root (
128- start_path : t .Union [str , Path , None ] = None , # starting path to search from
128+ start_path : t .Union [str , Path , None ] = None # starting path to search from
129129) -> Path :
130130 """Find the root directory of a git repository by traversing up from the start path."""
131131 # Start from the current directory if no path is provided
@@ -447,7 +447,13 @@ async def langfuse_wrapped_func(*args, **kwargs):
447447
448448
449449@patch
450- def mlflow_experiment (self : Project , experiment_model , name_prefix : str = "" ):
450+ def mlflow_experiment (
451+ self : Project ,
452+ experiment_model ,
453+ name_prefix : str = "" ,
454+ save_to_git : bool = True ,
455+ stage_all : bool = True ,
456+ ):
451457 """Decorator for creating experiment functions with mlflow integration.
452458
453459 Args:
@@ -459,31 +465,22 @@ def mlflow_experiment(self: Project, experiment_model, name_prefix: str = ""):
459465 """
460466
461467 def decorator (func : t .Callable ) -> ExperimentProtocol :
462- # First, create a base experiment wrapper
463- base_experiment = self .experiment (experiment_model , name_prefix )(func )
464468
465- # Override the wrapped function to add mlflow observation
466469 @wraps (func )
467- async def wrapped_with_mlflow (* args , ** kwargs ):
468- # wrap the function with mlflow observation
469- observed_func = trace (name = f"{ name_prefix } -{ func .__name__ } " )(func )
470+ async def mlflow_wrapped_func (* args , ** kwargs ):
471+ # Apply mlflow observation directly here
472+ trace_name = (
473+ f"{ name_prefix } -{ func .__name__ } " if name_prefix else func .__name__
474+ )
475+ observed_func = trace (name = trace_name )(func )
470476 return await observed_func (* args , ** kwargs )
471477
472- # Replace the async function to use mlflow
473- original_run_async = base_experiment .run_async
474-
475- # Use the original run_async but with the mlflow-wrapped function
476- async def run_async_with_mlflow (dataset : Dataset , name : t .Optional [str ] = None ):
477- # Override the internal wrapped_experiment with our mlflow version
478- base_experiment .__wrapped__ = wrapped_with_mlflow
479-
480- # Call the original run_async which will now use our mlflow-wrapped function
481- return await original_run_async (dataset , name )
482-
483- # Replace the run_async method
484- base_experiment .__setattr__ ("run_async" , run_async_with_mlflow )
478+ # Now create the experiment wrapper with our already-observed function
479+ experiment_wrapper = self .experiment (
480+ experiment_model , name_prefix , save_to_git , stage_all
481+ )(mlflow_wrapped_func )
485482
486- return t .cast (ExperimentProtocol , base_experiment )
483+ return t .cast (ExperimentProtocol , experiment_wrapper )
487484
488485 return decorator
489486
0 commit comments