@@ -484,9 +484,11 @@ def merge(
484484 """
485485 Executes a merge operation. If conflicts result, the merge is aborted, as an interactive merge does not really
486486 make sense in a scripting environment, or at least we have not figured out how to model it in a way that does.
487- :param branch:
488- :param message:
489- :param squash:
487+ :param branch: name of the branch to merge into the current branch
488+ :param message: message to be used for the merge commit only in the case of an automatic
489+ merge. In case of automatic merge without a message provided, the commit message will be
490+ "Merge branch '<branch>' into '<current_branch>'"
491+ :param squash: squash the commits from the merged branch into a single commit
490492 :return:
491493 """
492494 current_branch , branches = self ._get_branches ()
@@ -503,16 +505,19 @@ def merge(
503505
504506 if squash :
505507 args .append ("--squash" )
506-
508+ if message :
509+ args .extend (["--message" , message ])
507510 args .append (branch )
508511 output = self .execute (args , ** kwargs ).split ("\n " )
509- merge_conflict_pos = 2
510512
511- if len (output ) == 3 and "Fast-forward" in output [1 ]:
513+ # TODO: this was and remains a hack, we need to parse the output properly
514+ if len (output ) > 1 and "Fast-forward" in output [0 ]:
512515 logger .info (f"Completed fast-forward merge of { branch } into { current_branch .name } " )
513516 return
514517
515- if len (output ) == 5 and output [merge_conflict_pos ].startswith ("CONFLICT" ):
518+ # TODO: this was and remains a hack, we need to parse the output properly
519+ merge_conflict_pos = 8
520+ if len (output ) > 1 and output [merge_conflict_pos ].startswith ("CONFLICT" ):
516521 logger .warning (
517522 f"""
518523 The following merge conflict occurred merging { branch } to { current_branch .name } :
@@ -527,12 +532,6 @@ def merge(
527532 if message is None :
528533 message = f"Merged { current_branch .name } into { branch } "
529534 logger .info (message )
530- status = self .status ()
531-
532- for table in list (status .added_tables .keys ()) + list (status .modified_tables .keys ()):
533- self .add (table )
534-
535- self .commit (message )
536535
537536 def sql (
538537 self ,
@@ -729,18 +728,35 @@ def branch(
729728 delete : bool = False ,
730729 copy : bool = False ,
731730 move : bool = False ,
731+ remote : bool = False ,
732+ all : bool = False ,
732733 ** kwargs ,
733734 ):
734735 """
735- Checkout, create, delete, move, or copy, a branch. Only
736- :param branch_name:
737- :param start_point:
738- :param new_branch:
739- :param force:
740- :param delete:
741- :param copy:
742- :param move:
743- :return:
736+ List, create, or delete branches.
737+
738+ If 'branch_name' is None, existing branches are listed, including remotely tracked branches
739+ if 'remote' or 'all' are set. If 'branch_name' is provided, a new branch is created, checked
740+ our, deleted, moved or copied.
741+
742+ :param branch_name: Name of branch to Checkout, create, delete, move, or copy.
743+ :param start_point: A commit that a new branch should point at.
744+ :param new_branch: Name of branch to copy to or rename to if 'copy' or 'move' is set.
745+ :param force: Reset 'branch_name' to 'start_point', even if 'branch_name' exists already.
746+ Without 'force', dolt branch refuses to change an existing branch. In combination with
747+ 'delete', allow deleting the branch irrespective of its merged status. In
748+ combination with 'move', allow renaming the branch even if the new branch name
749+ already exists, the same applies for 'copy'.
750+ :param delete: Delete a branch. The branch must be fully merged in its upstream branch.
751+ :param copy: Create a copy of a branch.
752+ :param move: Move/rename a branch. If 'new_branch' does not exist, 'branch_name' will be
753+ renamed to 'new_branch'. If 'new_branch' exists, 'force' must be used to force the
754+ rename to happen.
755+ :param remote: When in list mode, show only remote tracked branches, unless 'all' is true.
756+ When with -d, delete a remote tracking branch.
757+ :param all: When in list mode, shows both local and remote tracked branches
758+
759+ :return: active_branch, branches
744760 """
745761 switch_count = [el for el in [delete , copy , move ] if el ]
746762 if len (switch_count ) > 1 :
@@ -751,7 +767,7 @@ def branch(
751767 raise ValueError (
752768 "force is not valid without providing a new branch name, or copy, move, or delete being true"
753769 )
754- return self ._get_branches ()
770+ return self ._get_branches (remote = remote , all = all )
755771
756772 args = ["branch" ]
757773 if force :
@@ -780,6 +796,8 @@ def execute_wrapper(command_args: List[str]):
780796 if not branch_name :
781797 raise ValueError ("must provide branch_name when deleting" )
782798 args .extend (["--delete" , branch_name ])
799+ if remote :
800+ args .append ("--remote" )
783801 return execute_wrapper (args )
784802
785803 if move :
@@ -797,25 +815,42 @@ def execute_wrapper(command_args: List[str]):
797815 args .append (start_point )
798816 return execute_wrapper (args )
799817
800- return self ._get_branches ()
818+ return self ._get_branches (remote = remote , all = all )
801819
802- def _get_branches (self ) -> Tuple [Branch , List [Branch ]]:
803- dicts = read_rows_sql (self , sql = "select * from dolt_branches" )
804- branches = [Branch (** d ) for d in dicts ]
820+ def _get_branches (self , remote : bool = False , all : bool = False ) -> Tuple [Branch , List [Branch ]]:
821+ """
822+ Gets the branches for this repository, optionally including remote branches, and optionally
823+ including all.
824+
825+ :param remote: include remotely tracked branches. If all is false and remote is true, only
826+ remotely track branches are returned. If all is true both local and remote are included.
827+ Default is False
828+ :param all: include both local and remotely tracked branches. Default is False
829+ :return: active_branch, branches
830+ """
831+ local_dicts = read_rows_sql (self , sql = "select * from dolt_branches" )
832+ dicts = []
833+ if all :
834+ dicts = local_dicts + read_rows_sql (self , sql = "select * from dolt_remote_branches" )
835+ elif remote :
836+ dicts = read_rows_sql (self , sql = "select * from dolt_remote_branches" )
837+ else :
838+ dicts = local_dicts
839+
840+ # find active_branch
805841 ab_dicts = read_rows_sql (
806842 self , "select * from dolt_branches where name = (select active_branch())"
807843 )
808-
809844 if len (ab_dicts ) != 1 :
810845 raise ValueError (
811846 "Ensure you have the latest version of Dolt installed, this is fixed as of 0.24.2"
812847 )
813-
814848 active_branch = Branch (** ab_dicts [0 ])
815-
816849 if not active_branch :
817850 raise DoltException ("Failed to set active branch" )
818851
852+ branches = [Branch (** d ) for d in dicts ]
853+
819854 return active_branch , branches
820855
821856 def checkout (
@@ -824,6 +859,7 @@ def checkout(
824859 tables : Optional [Union [str , List [str ]]] = None ,
825860 checkout_branch : bool = False ,
826861 start_point : Optional [str ] = None ,
862+ track : Optional [str ] = None ,
827863 ** kwargs ,
828864 ):
829865 """
@@ -833,6 +869,7 @@ def checkout(
833869 :param tables: table or tables to checkout
834870 :param checkout_branch: branch to checkout
835871 :param start_point: tip of new branch
872+ :param track: the upstream branch to track
836873 :return:
837874 """
838875 if tables and branch :
@@ -849,6 +886,10 @@ def checkout(
849886 if tables :
850887 args .append (" " .join (to_list (tables )))
851888
889+ if track is not None :
890+ args .append ("--track" )
891+ args .append (track )
892+
852893 self .execute (args , ** kwargs )
853894
854895 def remote (
@@ -929,13 +970,18 @@ def push(
929970 # just print the output
930971 self .execute (args , ** kwargs )
931972
932- def pull (self , remote : str = "origin" , ** kwargs ):
973+ def pull (self , remote : str = "origin" , branch : Optional [ str ] = None , ** kwargs ):
933974 """
934975 Pull the latest changes from the specified remote.
935- :param remote:
976+ :param remote: The remote to pull the changes from
977+ :param branch: The branch on the remote to pull the changes from
936978 :return:
937979 """
938- self .execute (["pull" , remote ], ** kwargs )
980+ args = ["pull" , remote ]
981+ if branch is not None :
982+ args .append (branch )
983+
984+ self .execute (args , ** kwargs )
939985
940986 def fetch (
941987 self ,
@@ -1227,7 +1273,6 @@ def _config_helper(
12271273 get : bool = False ,
12281274 unset : bool = False ,
12291275 ) -> Dict [str , str ]:
1230-
12311276 switch_count = [el for el in [add , list , get , unset ] if el ]
12321277 if len (switch_count ) != 1 :
12331278 raise ValueError ("Exactly one of add, list, get, unset must be True" )
0 commit comments