diff --git a/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/batch_jobs.py b/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/batch_jobs.py index f8c02b9..cd27fbd 100644 --- a/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/batch_jobs.py +++ b/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/batch_jobs.py @@ -14,9 +14,9 @@ from {{ cookiecutter.project_name }}.configs import RunConfig from {{ cookiecutter.project_name }}.constants import PROJECT_SHORT, WANDB_ENTITY, WANDB_PROJECT -from {{ cookiecutter.project_name }}.utils.git import git_latest_commit, validate_git_repo +from {{ cookiecutter.project_name }}.utils.git import get_repo_root, git_latest_commit, validate_git_repo -JOB_TEMPLATE_PATH = Path(__file__).parent.parent.parent / "k8s" / "batch_job.yaml" +JOB_TEMPLATE_PATH = get_repo_root() / "k8s" / "batch_job.yaml" with JOB_TEMPLATE_PATH.open() as f: JOB_TEMPLATE = f.read() diff --git a/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/git.py b/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/git.py index dba2848..8a66f9c 100644 --- a/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/git.py +++ b/{{cookiecutter.project_slug}}/src/{{cookiecutter.package_name}}/utils/git.py @@ -6,22 +6,28 @@ from {{ cookiecutter.project_name }}.utils.utils import ask_for_confirmation -JOB_TEMPLATE_PATH = Path(__file__).parent.parent.parent / "k8s" / "batch_job.yaml" -with JOB_TEMPLATE_PATH.open() as f: - JOB_TEMPLATE = f.read() - @functools.cache def git_latest_commit() -> str: """Gets the latest commit hash.""" - repo = Repo(".") + repo = Repo(".", search_parent_directories=True) commit_hash = str(repo.head.object.hexsha) return commit_hash +@functools.cache +def get_repo_root() -> Path: + """Get the root directory of the git repository.""" + repo = git.Repo(".", search_parent_directories=True) + working_dir = repo.working_tree_dir + if working_dir is None: + raise RuntimeError("Could not find git repository root") + return Path(working_dir) + + def validate_git_repo() -> None: """Validates the git repo before running a batch job.""" - repo = Repo(".") + repo = Repo(".", search_parent_directories=True) # Push to git as we want to run the code with the current commit. repo.remote("origin").push(repo.active_branch.name).raise_if_error()