Skip to content

Commit 0f36e62

Browse files
committed
add git support for mlflow
1 parent b3a305b commit 0f36e62

File tree

2 files changed

+31
-48
lines changed

2 files changed

+31
-48
lines changed

nbs/project/experiments.ipynb

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@
909909
"\n",
910910
"@patch\n",
911911
"def mlflow_experiment(\n",
912-
" self: Project, experiment_model, name_prefix: str = \"\"\n",
912+
" self: Project, experiment_model, name_prefix: str = \"\",save_to_git: bool = True, stage_all: bool = True\n",
913913
"):\n",
914914
" \"\"\"Decorator for creating experiment functions with mlflow integration.\n",
915915
"\n",
@@ -922,33 +922,19 @@
922922
" \"\"\"\n",
923923
"\n",
924924
" def decorator(func: t.Callable) -> ExperimentProtocol:\n",
925-
" # First, create a base experiment wrapper\n",
926-
" base_experiment = self.experiment(experiment_model, name_prefix)(func)\n",
927-
"\n",
928-
" # Override the wrapped function to add mlflow observation\n",
925+
" \n",
929926
" @wraps(func)\n",
930-
" async def wrapped_with_mlflow(*args, **kwargs):\n",
931-
" # wrap the function with mlflow observation\n",
932-
" observed_func = trace(name=f\"{name_prefix}-{func.__name__}\")(func)\n",
927+
" async def mlflow_wrapped_func(*args, **kwargs):\n",
928+
" # Apply mlflow observation directly here\n",
929+
" trace_name = f\"{name_prefix}-{func.__name__}\" if name_prefix else func.__name__\n",
930+
" observed_func = trace(name=trace_name)(func)\n",
933931
" return await observed_func(*args, **kwargs)\n",
934-
"\n",
935-
" # Replace the async function to use mlflow\n",
936-
" original_run_async = base_experiment.run_async\n",
937-
"\n",
938-
" # Use the original run_async but with the mlflow-wrapped function\n",
939-
" async def run_async_with_mlflow(\n",
940-
" dataset: Dataset, name: t.Optional[str] = None\n",
941-
" ):\n",
942-
" # Override the internal wrapped_experiment with our mlflow version\n",
943-
" base_experiment.__wrapped__ = wrapped_with_mlflow\n",
944-
"\n",
945-
" # Call the original run_async which will now use our mlflow-wrapped function\n",
946-
" return await original_run_async(dataset, name)\n",
947-
"\n",
948-
" # Replace the run_async method\n",
949-
" base_experiment.__setattr__(\"run_async\", run_async_with_mlflow)\n",
950-
"\n",
951-
" return t.cast(ExperimentProtocol, base_experiment)\n",
932+
" \n",
933+
" # Now create the experiment wrapper with our already-observed function\n",
934+
" experiment_wrapper = self.experiment(experiment_model, name_prefix, save_to_git, stage_all)(mlflow_wrapped_func)\n",
935+
" \n",
936+
" return t.cast(ExperimentProtocol, experiment_wrapper)\n",
937+
" \n",
952938
"\n",
953939
" return decorator"
954940
]

ragas_experimental/project/experiments.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def get_experiment(self: Project, experiment_name: str, model) -> Dataset:
125125

126126
# %% ../../nbs/project/experiments.ipynb 17
127127
def find_git_root(
128-
start_path: t.Union[str, Path, None] = None, # starting path to search from
128+
start_path: t.Union[str, Path, None] = None # starting path to search from
129129
) -> Path:
130130
"""Find the root directory of a git repository by traversing up from the start path."""
131131
# Start from the current directory if no path is provided
@@ -447,7 +447,13 @@ async def langfuse_wrapped_func(*args, **kwargs):
447447

448448

449449
@patch
450-
def mlflow_experiment(self: Project, experiment_model, name_prefix: str = ""):
450+
def mlflow_experiment(
451+
self: Project,
452+
experiment_model,
453+
name_prefix: str = "",
454+
save_to_git: bool = True,
455+
stage_all: bool = True,
456+
):
451457
"""Decorator for creating experiment functions with mlflow integration.
452458
453459
Args:
@@ -459,31 +465,22 @@ def mlflow_experiment(self: Project, experiment_model, name_prefix: str = ""):
459465
"""
460466

461467
def decorator(func: t.Callable) -> ExperimentProtocol:
462-
# First, create a base experiment wrapper
463-
base_experiment = self.experiment(experiment_model, name_prefix)(func)
464468

465-
# Override the wrapped function to add mlflow observation
466469
@wraps(func)
467-
async def wrapped_with_mlflow(*args, **kwargs):
468-
# wrap the function with mlflow observation
469-
observed_func = trace(name=f"{name_prefix}-{func.__name__}")(func)
470+
async def mlflow_wrapped_func(*args, **kwargs):
471+
# Apply mlflow observation directly here
472+
trace_name = (
473+
f"{name_prefix}-{func.__name__}" if name_prefix else func.__name__
474+
)
475+
observed_func = trace(name=trace_name)(func)
470476
return await observed_func(*args, **kwargs)
471477

472-
# Replace the async function to use mlflow
473-
original_run_async = base_experiment.run_async
474-
475-
# Use the original run_async but with the mlflow-wrapped function
476-
async def run_async_with_mlflow(dataset: Dataset, name: t.Optional[str] = None):
477-
# Override the internal wrapped_experiment with our mlflow version
478-
base_experiment.__wrapped__ = wrapped_with_mlflow
479-
480-
# Call the original run_async which will now use our mlflow-wrapped function
481-
return await original_run_async(dataset, name)
482-
483-
# Replace the run_async method
484-
base_experiment.__setattr__("run_async", run_async_with_mlflow)
478+
# Now create the experiment wrapper with our already-observed function
479+
experiment_wrapper = self.experiment(
480+
experiment_model, name_prefix, save_to_git, stage_all
481+
)(mlflow_wrapped_func)
485482

486-
return t.cast(ExperimentProtocol, base_experiment)
483+
return t.cast(ExperimentProtocol, experiment_wrapper)
487484

488485
return decorator
489486

0 commit comments

Comments
 (0)