Skip to content

Commit f2ac63a

Browse files
authored
Merge pull request #16 from explodinggradients/feat/git-experiment
feat: git tracking with some helpful utils
2 parents a49fe0b + 605a37c commit f2ac63a

File tree

4 files changed

+478
-18
lines changed

4 files changed

+478
-18
lines changed

nbs/project/experiments.ipynb

Lines changed: 284 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,278 @@
299299
"p.get_experiment(\"test-exp\", TestModel)"
300300
]
301301
},
302+
{
303+
"cell_type": "markdown",
304+
"metadata": {},
305+
"source": [
306+
"## Git Versioning for Experiments"
307+
]
308+
},
309+
{
310+
"cell_type": "code",
311+
"execution_count": null,
312+
"metadata": {},
313+
"outputs": [],
314+
"source": [
315+
"# | export\n",
316+
"import git\n",
317+
"import os\n",
318+
"from pathlib import Path"
319+
]
320+
},
321+
{
322+
"cell_type": "code",
323+
"execution_count": null,
324+
"metadata": {},
325+
"outputs": [],
326+
"source": [
327+
"# | export\n",
328+
"def find_git_root(start_path: t.Union[str, Path, None] = None) -> Path:\n",
329+
" \"\"\"Find the root directory of a git repository by traversing up from the start path.\n",
330+
" \n",
331+
" Args:\n",
332+
" start_path: Path to start searching from (defaults to current working directory)\n",
333+
" \n",
334+
" Returns:\n",
335+
" Path: Absolute path to the git repository root\n",
336+
" \n",
337+
" Raises:\n",
338+
" ValueError: If no git repository is found\n",
339+
" \"\"\"\n",
340+
" # Start from the current directory if no path is provided\n",
341+
" if start_path is None:\n",
342+
" start_path = Path.cwd()\n",
343+
" else:\n",
344+
" start_path = Path(start_path).resolve()\n",
345+
" \n",
346+
" # Check if the current directory is a git repository\n",
347+
" current_path = start_path\n",
348+
" while current_path != current_path.parent: # Stop at filesystem root\n",
349+
" if (current_path / '.git').exists() and (current_path / '.git').is_dir():\n",
350+
" return current_path\n",
351+
" \n",
352+
" # Move up to the parent directory\n",
353+
" current_path = current_path.parent\n",
354+
" \n",
355+
" # Final check for the root directory\n",
356+
" if (current_path / '.git').exists() and (current_path / '.git').is_dir():\n",
357+
" return current_path\n",
358+
" \n",
359+
" # No git repository found\n",
360+
" raise ValueError(f\"No git repository found in or above {start_path}\")"
361+
]
362+
},
363+
{
364+
"cell_type": "code",
365+
"execution_count": null,
366+
"metadata": {},
367+
"outputs": [
368+
{
369+
"data": {
370+
"text/plain": [
371+
"Path('/Users/jjmachan/workspace/eglabs/ragas_annotator')"
372+
]
373+
},
374+
"execution_count": null,
375+
"metadata": {},
376+
"output_type": "execute_result"
377+
}
378+
],
379+
"source": [
380+
"find_git_root()"
381+
]
382+
},
383+
{
384+
"cell_type": "code",
385+
"execution_count": null,
386+
"metadata": {},
387+
"outputs": [
388+
{
389+
"data": {
390+
"text/plain": [
391+
"<git.repo.base.Repo '/Users/jjmachan/workspace/eglabs/ragas_annotator/.git'>"
392+
]
393+
},
394+
"execution_count": null,
395+
"metadata": {},
396+
"output_type": "execute_result"
397+
}
398+
],
399+
"source": [
400+
"git.Repo(find_git_root())"
401+
]
402+
},
403+
{
404+
"cell_type": "code",
405+
"execution_count": null,
406+
"metadata": {},
407+
"outputs": [],
408+
"source": [
409+
"# | export\n",
410+
"\n",
411+
"def version_experiment(\n",
412+
" experiment_name: str,\n",
413+
" commit_message: t.Optional[str] = None,\n",
414+
" repo_path: t.Union[str, Path, None] = None,\n",
415+
" create_branch: bool = True,\n",
416+
" stage_all: bool = False,\n",
417+
") -> str:\n",
418+
" \"\"\"\n",
419+
" Version control the current state of the codebase for an experiment.\n",
420+
" \"\"\"\n",
421+
" # Default to current directory if no repo path is provided\n",
422+
" if repo_path is None:\n",
423+
" repo_path = find_git_root()\n",
424+
" \n",
425+
" # Initialize git repo object\n",
426+
" repo = git.Repo(repo_path)\n",
427+
"\n",
428+
" # check if there are any changes to the repo\n",
429+
" has_changes = False\n",
430+
" if stage_all and repo.is_dirty(untracked_files=True):\n",
431+
" print(\"Staging all changes\")\n",
432+
" repo.git.add('.')\n",
433+
" has_changes = True\n",
434+
" elif repo.is_dirty(untracked_files=False):\n",
435+
" print(\"Staging changes to tracked files\")\n",
436+
" repo.git.add('-u')\n",
437+
" has_changes = True\n",
438+
" \n",
439+
" # Check if there are uncommitted changes\n",
440+
" if has_changes:\n",
441+
" # Default commit message if none provided\n",
442+
" if commit_message is None:\n",
443+
" commit_message = f\"Experiment: {experiment_name}\"\n",
444+
" \n",
445+
" # Commit changes\n",
446+
" commit = repo.index.commit(commit_message)\n",
447+
" commit_hash = commit.hexsha\n",
448+
" print(f\"Changes committed with hash: {commit_hash[:8]}\")\n",
449+
" else:\n",
450+
" # No changes to commit, use current HEAD\n",
451+
" commit_hash = repo.head.commit.hexsha\n",
452+
" print(\"No changes detected, nothing to commit\")\n",
453+
" \n",
454+
" # Format the branch/tag name\n",
455+
" version_name = f\"ragas/{experiment_name}\"\n",
456+
" \n",
457+
" # Create branch if requested\n",
458+
" if create_branch:\n",
459+
" branch = repo.create_head(version_name, commit_hash)\n",
460+
" print(f\"Created branch: {version_name}\")\n",
461+
" \n",
462+
" return commit_hash"
463+
]
464+
},
465+
{
466+
"cell_type": "code",
467+
"execution_count": null,
468+
"metadata": {},
469+
"outputs": [],
470+
"source": [
471+
"# | export\n",
472+
"def cleanup_experiment_branches(\n",
473+
" prefix: str = \"ragas/\", \n",
474+
" repo_path: t.Union[str, Path, None] = None,\n",
475+
" interactive: bool = True,\n",
476+
" dry_run: bool = False\n",
477+
") -> t.List[str]:\n",
478+
" \"\"\"Clean up git branches with the specified prefix.\"\"\"\n",
479+
" # Find the git repository root if not provided\n",
480+
" if repo_path is None:\n",
481+
" try:\n",
482+
" repo_path = find_git_root()\n",
483+
" except ValueError as e:\n",
484+
" raise ValueError(f\"Cannot cleanup branches: {str(e)}\")\n",
485+
" \n",
486+
" # Initialize git repo object\n",
487+
" repo = git.Repo(repo_path)\n",
488+
" current_branch = repo.active_branch.name\n",
489+
" \n",
490+
" # Get all branches matching the prefix\n",
491+
" matching_branches = []\n",
492+
" for branch in repo.branches:\n",
493+
" if branch.name.startswith(prefix):\n",
494+
" matching_branches.append(branch.name)\n",
495+
" \n",
496+
" if not matching_branches:\n",
497+
" print(f\"No branches found with prefix '{prefix}'\")\n",
498+
" return []\n",
499+
" \n",
500+
" # Remove current branch from the list if present\n",
501+
" if current_branch in matching_branches:\n",
502+
" print(f\"Note: Current branch '{current_branch}' will be excluded from deletion\")\n",
503+
" matching_branches.remove(current_branch)\n",
504+
" \n",
505+
" if not matching_branches:\n",
506+
" print(\"No branches available for deletion after excluding current branch\")\n",
507+
" return []\n",
508+
" \n",
509+
" # Show branches to the user\n",
510+
" print(f\"Found {len(matching_branches)} branches with prefix '{prefix}':\")\n",
511+
" for branch_name in matching_branches:\n",
512+
" print(f\"- {branch_name}\")\n",
513+
" \n",
514+
" # Handle confirmation in interactive mode\n",
515+
" proceed = True\n",
516+
" if interactive and not dry_run:\n",
517+
" confirm = input(f\"\\nDelete these {len(matching_branches)} branches? (y/n): \").strip().lower()\n",
518+
" proceed = (confirm == 'y')\n",
519+
" \n",
520+
" if not proceed:\n",
521+
" print(\"Operation cancelled\")\n",
522+
" return []\n",
523+
" \n",
524+
" # Perform deletion\n",
525+
" deleted_branches = []\n",
526+
" for branch_name in matching_branches:\n",
527+
" if dry_run:\n",
528+
" print(f\"Would delete branch: {branch_name}\")\n",
529+
" deleted_branches.append(branch_name)\n",
530+
" else:\n",
531+
" try:\n",
532+
" # Delete the branch\n",
533+
" repo.git.branch('-D', branch_name)\n",
534+
" print(f\"Deleted branch: {branch_name}\")\n",
535+
" deleted_branches.append(branch_name)\n",
536+
" except git.GitCommandError as e:\n",
537+
" print(f\"Error deleting branch '{branch_name}': {str(e)}\")\n",
538+
" \n",
539+
" if dry_run:\n",
540+
" print(f\"\\nDry run complete. {len(deleted_branches)} branches would be deleted.\")\n",
541+
" else:\n",
542+
" print(f\"\\nCleanup complete. {len(deleted_branches)} branches deleted.\")\n",
543+
" \n",
544+
" return deleted_branches"
545+
]
546+
},
547+
{
548+
"cell_type": "code",
549+
"execution_count": null,
550+
"metadata": {},
551+
"outputs": [
552+
{
553+
"name": "stdout",
554+
"output_type": "stream",
555+
"text": [
556+
"No branches found with prefix 'ragas/'\n"
557+
]
558+
},
559+
{
560+
"data": {
561+
"text/plain": [
562+
"[]"
563+
]
564+
},
565+
"execution_count": null,
566+
"metadata": {},
567+
"output_type": "execute_result"
568+
}
569+
],
570+
"source": [
571+
"cleanup_experiment_branches(dry_run=True)"
572+
]
573+
},
302574
{
303575
"cell_type": "markdown",
304576
"metadata": {},
@@ -348,7 +620,7 @@
348620
"# | export\n",
349621
"@patch\n",
350622
"def experiment(\n",
351-
" self: Project, experiment_model, name_prefix: str = \"\"\n",
623+
" self: Project, experiment_model, name_prefix: str = \"\", save_to_git: bool = True, stage_all: bool = False\n",
352624
"):\n",
353625
" \"\"\"Decorator for creating experiment functions without Langfuse integration.\n",
354626
"\n",
@@ -367,7 +639,7 @@
367639
" return await func(*args, **kwargs)\n",
368640
"\n",
369641
" # Add run method to the wrapped function\n",
370-
" async def run_async(dataset: Dataset, name: t.Optional[str] = None):\n",
642+
" async def run_async(dataset: Dataset, name: t.Optional[str] = None, save_to_git: bool = save_to_git, stage_all: bool = stage_all):\n",
371643
" # if name is not provided, generate a memorable name\n",
372644
" if name is None:\n",
373645
" name = memorable_names.generate_unique_name()\n",
@@ -404,7 +676,6 @@
404676
" progress_bar.update(1) # Update for append operation\n",
405677
" \n",
406678
" progress_bar.close()\n",
407-
" return experiment_view\n",
408679
" \n",
409680
" except Exception as e:\n",
410681
" # Clean up the experiment if there was an error and it was created\n",
@@ -419,6 +690,13 @@
419690
" # Re-raise the original exception\n",
420691
" raise e\n",
421692
"\n",
693+
" # save to git if requested\n",
694+
" if save_to_git:\n",
695+
" repo_path = find_git_root()\n",
696+
" version_experiment(experiment_name=name, repo_path=repo_path, stage_all=stage_all)\n",
697+
"\n",
698+
" return experiment_view\n",
699+
"\n",
422700
" wrapped_experiment.__setattr__(\"run_async\", run_async)\n",
423701
" return t.cast(ExperimentProtocol, wrapped_experiment)\n",
424702
"\n",
@@ -451,7 +729,7 @@
451729
" is_correct: t.Literal[\"yes\", \"no\"]\n",
452730
"\n",
453731
"# create a test experiment function\n",
454-
"@p.experiment(TextExperimentModel)\n",
732+
"@p.experiment(TextExperimentModel, save_to_git=False, stage_all=True)\n",
455733
"async def test_experiment(item: TestModel):\n",
456734
" return TextExperimentModel(**item.model_dump(), response=\"test response\", is_correct=\"yes\")\n"
457735
]
@@ -465,13 +743,13 @@
465743
"name": "stderr",
466744
"output_type": "stream",
467745
"text": [
468-
"Running experiment: 100%|██████████| 6/6 [00:01<00:00, 3.23it/s]\n"
746+
"Running experiment: 100%|██████████| 6/6 [00:01<00:00, 3.05it/s]\n"
469747
]
470748
},
471749
{
472750
"data": {
473751
"text/plain": [
474-
"Experiment(name=dazzling_knuth, model=TextExperimentModel)"
752+
"Experiment(name=xenodochial_dorsey, model=TextExperimentModel)"
475753
]
476754
},
477755
"execution_count": null,

ragas_experimental/_modidx.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,8 +611,14 @@
611611
'ragas_experimental/project/experiments.py'),
612612
'ragas_experimental.project.experiments.Project.langfuse_experiment': ( 'project/experiments.html#project.langfuse_experiment',
613613
'ragas_experimental/project/experiments.py'),
614+
'ragas_experimental.project.experiments.cleanup_experiment_branches': ( 'project/experiments.html#cleanup_experiment_branches',
615+
'ragas_experimental/project/experiments.py'),
614616
'ragas_experimental.project.experiments.create_experiment_columns': ( 'project/experiments.html#create_experiment_columns',
615-
'ragas_experimental/project/experiments.py')},
617+
'ragas_experimental/project/experiments.py'),
618+
'ragas_experimental.project.experiments.find_git_root': ( 'project/experiments.html#find_git_root',
619+
'ragas_experimental/project/experiments.py'),
620+
'ragas_experimental.project.experiments.version_experiment': ( 'project/experiments.html#version_experiment',
621+
'ragas_experimental/project/experiments.py')},
616622
'ragas_experimental.project.naming': { 'ragas_experimental.project.naming.MemorableNames': ( 'project/naming.html#memorablenames',
617623
'ragas_experimental/project/naming.py'),
618624
'ragas_experimental.project.naming.MemorableNames.__init__': ( 'project/naming.html#memorablenames.__init__',

0 commit comments

Comments
 (0)