1515logger = logging .getLogger (__name__ )
1616
1717
18- def is_git_repo (folder : Union [str , Path ]):
18+ def is_git_repo (folder : Union [str , Path ]) -> bool :
1919 """
2020 Check if the folder is the root of a git repository
2121 """
@@ -26,7 +26,7 @@ def is_git_repo(folder: Union[str, Path]):
2626 return folder_exists and git_branch .returncode == 0
2727
2828
29- def is_local_clone (folder : Union [str , Path ], remote_url : str ):
29+ def is_local_clone (folder : Union [str , Path ], remote_url : str ) -> bool :
3030 """
3131 Check if the folder is the a local clone of the remote_url
3232 """
@@ -48,6 +48,65 @@ def is_local_clone(folder: Union[str, Path], remote_url: str):
4848 return remote_url in remotes
4949
5050
def is_tracked_with_lfs(filename: Union[str, Path]) -> bool:
    """
    Check if the file passed is tracked with git-lfs.

    Runs `git check-attr -a` on the file name from the file's parent folder and reports whether the
    `diff`, `merge` and `filter` attributes are all set to `lfs`.

    Args:
        filename: Path to the file to check.

    Returns:
        `True` if the three attributes are set to `lfs`; `False` otherwise, including when git reports
        that the parent folder is not a git repository.

    Raises:
        OSError: If `git check-attr` fails for any other reason. The message is git's stderr.
    """
    path = Path(filename)

    try:
        process = subprocess.run(
            ["git", "check-attr", "-a", path.name],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=path.parent,
        )
    except subprocess.CalledProcessError as exc:
        if "not a git repository" in exc.stderr:
            return False
        raise OSError(exc.stderr) from exc

    lfs_attributes = set()
    for line in process.stdout.splitlines():
        # Each line reads `<path>: <attribute>: <value>`. Split from the right so that the path, which
        # may itself contain "lfs", "diff" or ": ", never takes part in the comparison.
        fields = line.rsplit(": ", 2)
        if len(fields) == 3 and fields[2].strip() == "lfs":
            lfs_attributes.add(fields[1])

    return {"diff", "merge", "filter"} <= lfs_attributes
85+
86+
def is_git_ignored(filename: Union[str, Path]) -> bool:
    """
    Check if file is git-ignored. Supports nested .gitignore files.

    Runs `git check-ignore` on the file name from the file's parent folder and interprets the exit code.

    Args:
        filename: Path to the file to check.

    Returns:
        `True` if git reports the file as ignored, `False` if it is not ignored or if git reports that
        the parent folder is not a git repository.

    Raises:
        OSError: If `git check-ignore` fails for any other reason. The message is git's stderr.
    """
    path = Path(filename)

    # `check=True` is deliberately not used: exit code 1 is the regular "not ignored" answer,
    # not a failure.
    process = subprocess.run(
        ["git", "check-ignore", path.name],
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        cwd=path.parent,
    )

    if process.returncode == 0:
        return True

    # Outside of a repository nothing can be ignored; this mirrors `is_tracked_with_lfs`.
    if process.returncode == 1 or "not a git repository" in process.stderr:
        return False

    # Any other exit code is a genuine git error and must not be mistaken for "not ignored".
    raise OSError(process.stderr)
108+
109+
51110class Repository :
52111 """
53112 Helper class to wrap the git and git-lfs commands.
@@ -333,6 +392,47 @@ def git_head_commit_url(self) -> str:
333392 url = url [:- 1 ]
334393 return f"{ url } /commit/{ sha } "
335394
def list_deleted_files(self) -> List[str]:
    """
    Returns a list of the files that are deleted in the working directory or index.

    Runs `git status --no-renames -s -z` from `self.local_dir` and keeps every entry whose status code
    contains `D`.

    Returns:
        The paths of the deleted files, exactly as git reports them (no quoting). The list is empty when
        nothing is deleted.

    Raises:
        EnvironmentError: If `git status` fails. The message is git's stderr.
    """
    try:
        git_status = subprocess.run(
            # `-z` separates entries with NUL and leaves paths unquoted, so names containing spaces or
            # non-ASCII characters survive intact.
            # NOTE(review): with `-z` git reports paths relative to the repository root; this assumes
            # `local_dir` is that root - confirm.
            ["git", "status", "--no-renames", "-s", "-z"],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=self.local_dir,
        ).stdout
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr) from exc

    # Every entry is `XY <path>`: a two-character status code, one space, then the path, e.g.
    #  D .gitignore
    # D  new_file.json
    # AD new_file1.json
    # ?? new_file2.json
    # The output must not be stripped: a leading space is part of the status code.
    return [
        entry[3:]
        for entry in git_status.split("\0")
        if entry[3:] and "D" in entry[:2]
    ]
435+
336436 def lfs_track (self , patterns : Union [str , List [str ]]):
337437 """
338438 Tell git-lfs to track those files.
@@ -352,6 +452,25 @@ def lfs_track(self, patterns: Union[str, List[str]]):
352452 except subprocess .CalledProcessError as exc :
353453 raise EnvironmentError (exc .stderr )
354454
def lfs_untrack(self, patterns: Union[str, List[str]]):
    """
    Tell git-lfs to untrack those files.

    `patterns` is either a single pattern or a list of patterns. `git lfs untrack` is run once per
    pattern from `self.local_dir`; the first failing command is re-raised as an `EnvironmentError`
    carrying git's stderr.
    """
    pattern_list = [patterns] if isinstance(patterns, str) else patterns

    for pattern in pattern_list:
        command = ["git", "lfs", "untrack", pattern]
        try:
            subprocess.run(
                command,
                stderr=subprocess.PIPE,
                stdout=subprocess.PIPE,
                check=True,
                encoding="utf-8",
                cwd=self.local_dir,
            )
        except subprocess.CalledProcessError as exc:
            raise EnvironmentError(exc.stderr)
473+
355474 def lfs_enable_largefiles (self ):
356475 """
357476 HF-specific. This enables upload support of files >5GB.
@@ -376,6 +495,42 @@ def lfs_enable_largefiles(self):
376495 except subprocess .CalledProcessError as exc :
377496 raise EnvironmentError (exc .stderr )
378497
def auto_track_large_files(self, pattern="."):
    """
    Automatically track large files with git-lfs

    Lists the modified and untracked files matching `pattern` with `git ls-files -mo`, and calls
    `lfs_track` on every file of at least 10MB that is neither already tracked with git-lfs nor
    git-ignored. Deleted files are skipped and are passed to `lfs_untrack` afterwards.

    Args:
        pattern: Path pattern handed to `git ls-files`. Defaults to `"."`.

    Raises:
        EnvironmentError: If `git ls-files` fails. The message is git's stderr.
    """
    try:
        listing = subprocess.run(
            # `-z` separates entries with NUL and leaves paths unquoted, so the names can be used
            # directly as filesystem paths.
            ["git", "ls-files", "-mo", "-z", pattern],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=self.local_dir,
        ).stdout
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr) from exc

    # Drop the empty entry left by the trailing NUL (and by an empty listing).
    files_to_be_staged = [name for name in listing.split("\0") if name]

    deleted_files = self.list_deleted_files()
    deleted_lookup = set(deleted_files)

    for filename in files_to_be_staged:
        if filename in deleted_lookup:
            continue

        path_to_file = os.path.join(os.getcwd(), self.local_dir, filename)

        # `ls-files -m` also lists files removed from the working tree; there is nothing to measure
        # for a path that is no longer on disk.
        if not os.path.exists(path_to_file):
            continue

        size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)

        if (
            size_in_mb >= 10
            and not is_tracked_with_lfs(path_to_file)
            and not is_git_ignored(path_to_file)
        ):
            self.lfs_track(filename)

    # Cleanup the .gitattributes if files were deleted
    self.lfs_untrack(deleted_files)
533+
379534 def git_pull (self , rebase : Optional [bool ] = False ):
380535 """
381536 git pull
@@ -395,10 +550,16 @@ def git_pull(self, rebase: Optional[bool] = False):
395550 except subprocess .CalledProcessError as exc :
396551 raise EnvironmentError (exc .stderr )
397552
398- def git_add (self , pattern = "." ):
553+ def git_add (self , pattern = "." , auto_lfs_track = False ):
399554 """
400555 git add
556+
557+ Setting the `auto_lfs_track` parameter to `True` will automatically track files that are larger
558+ than 10MB with `git-lfs`.
401559 """
560+ if auto_lfs_track :
561+ self .auto_track_large_files (pattern )
562+
402563 try :
403564 subprocess .run (
404565 ["git" , "add" , pattern ],
@@ -462,12 +623,10 @@ def push_to_hub(self, commit_message="commit files to HF hub") -> str:
462623 return self .git_push ()
463624
464625 @contextmanager
465- def commit (
466- self ,
467- commit_message : str ,
468- ):
626+ def commit (self , commit_message : str , track_large_files : bool = True ):
469627 """
470- Context manager utility to handle committing to a repository.
628+ Context manager utility to handle committing to a repository. This automatically tracks large files (>10Mb)
629+ with git-lfs. Set the `track_large_files` argument to `False` if you wish to ignore that behavior.
471630
472631 Examples:
473632
@@ -490,7 +649,7 @@ def commit(
490649 try :
491650 yield self
492651 finally :
493- self .git_add ()
652+ self .git_add (auto_lfs_track = True )
494653
495654 try :
496655 self .git_commit (commit_message )
0 commit comments