Skip to content

Commit f7294d1

Browse files
Improved remote repo support (#3279)
1 parent a758bd3 commit f7294d1

File tree

5 files changed

+82
-84
lines changed

5 files changed

+82
-84
lines changed

runner/internal/repo/manager.go

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,26 @@ type Manager struct {
1616
ctx context.Context
1717
localPath string
1818
clo git.CloneOptions
19+
branch string
1920
hash string
2021
}
2122

2223
func NewManager(ctx context.Context, url, branch, hash string, singleBranch bool) *Manager {
2324
ctx = log.AppendArgsCtx(ctx, "url", url, "branch", branch, "hash", hash)
2425
m := &Manager{
25-
ctx: ctx,
26+
ctx: ctx,
27+
branch: branch,
28+
hash: hash,
2629
clo: git.CloneOptions{
2730
URL: url,
2831
RecurseSubmodules: git.DefaultSubmoduleRecursionDepth,
29-
ReferenceName: plumbing.NewBranchReferenceName(branch),
3032
SingleBranch: singleBranch,
3133
},
32-
hash: hash,
34+
}
35+
// Only set ReferenceName if branch is non-empty
36+
// If empty, it will default to HEAD in CloneOptions.Validate()
37+
if branch != "" {
38+
m.clo.ReferenceName = plumbing.NewBranchReferenceName(branch)
3339
}
3440

3541
return m
@@ -69,23 +75,41 @@ func (m *Manager) Checkout() error {
6975
return fmt.Errorf("clone repo: %w", err)
7076
}
7177
if ref != nil {
72-
branchRef, err := ref.Reference(m.clo.ReferenceName, true)
73-
if err != nil {
74-
return fmt.Errorf("get branch reference: %w", err)
75-
}
7678
var cho git.CheckoutOptions
77-
if m.hash == "" || m.hash == branchRef.Hash().String() {
78-
cho.Branch = m.clo.ReferenceName
79+
needCheckout := false
80+
81+
if m.branch != "" {
82+
branchRef, err := ref.Reference(m.clo.ReferenceName, true)
83+
if err != nil {
84+
return fmt.Errorf("get branch reference: %w", err)
85+
}
86+
if m.hash == "" || m.hash == branchRef.Hash().String() {
87+
// Hash is empty or matches branch head: checkout branch
88+
cho.Branch = m.clo.ReferenceName
89+
needCheckout = true
90+
} else {
91+
// Hash is specified and different: checkout by hash
92+
cho.Hash = plumbing.NewHash(m.hash)
93+
needCheckout = true
94+
}
7995
} else {
80-
cho.Hash = plumbing.NewHash(m.hash)
96+
// Branch is empty: checkout by hash if specified, otherwise HEAD is already checked out
97+
if m.hash != "" {
98+
cho.Hash = plumbing.NewHash(m.hash)
99+
needCheckout = true
100+
}
101+
// If hash is also empty, HEAD is already checked out by clone, no need to checkout again
81102
}
82-
workTree, err := ref.Worktree()
83-
if err != nil {
84-
return fmt.Errorf("get worktree: %w", err)
85-
}
86-
err = workTree.Checkout(&cho)
87-
if err != nil {
88-
return fmt.Errorf("checkout: %w", err)
103+
104+
if needCheckout {
105+
workTree, err := ref.Worktree()
106+
if err != nil {
107+
return fmt.Errorf("get worktree: %w", err)
108+
}
109+
err = workTree.Checkout(&cho)
110+
if err != nil {
111+
return fmt.Errorf("checkout: %w", err)
112+
}
89113
}
90114
} else {
91115
log.Warning(m.ctx, "git clone ref==nil")

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from dstack._internal.core.services.diff import diff_models
5555
from dstack._internal.core.services.repos import (
5656
InvalidRepoCredentialsError,
57-
get_repo_creds_and_default_branch,
57+
get_repo_creds,
5858
load_repo,
5959
)
6060
from dstack._internal.utils.common import local_time
@@ -562,8 +562,7 @@ def get_repo(
562562
local_path = Path.cwd()
563563
legacy_local_path = True
564564
if url:
565-
# "master" is a dummy value, we'll fetch the actual default branch later
566-
repo = RemoteRepo.from_url(repo_url=url, repo_branch="master")
565+
repo = RemoteRepo.from_url(repo_url=url)
567566
repo_head = self.api.repos.get(repo_id=repo.repo_id, with_creds=True)
568567
elif local_path:
569568
if legacy_local_path:
@@ -618,7 +617,7 @@ def get_repo(
618617
init = True
619618

620619
try:
621-
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
620+
repo_creds = get_repo_creds(
622621
repo_url=repo.repo_url,
623622
identity_file=git_identity_file,
624623
private_key=git_private_key,
@@ -627,16 +626,11 @@ def get_repo(
627626
except InvalidRepoCredentialsError as e:
628627
raise CLIError(*e.args) from e
629628

630-
if repo_branch is None and repo_hash is None:
631-
repo_branch = default_repo_branch
632-
if repo_branch is None:
633-
raise CLIError(
634-
"Failed to automatically detect remote repo branch."
635-
" Specify branch or hash."
636-
)
637-
repo = RemoteRepo.from_url(
638-
repo_url=repo.repo_url, repo_branch=repo_branch, repo_hash=repo_hash
639-
)
629+
# repo_branch and repo_hash are taken from the repo_spec
630+
if repo_branch is not None:
631+
repo.run_repo_data.repo_branch = repo_branch
632+
if repo_hash is not None:
633+
repo.run_repo_data.repo_hash = repo_hash
640634

641635
if init:
642636
self.api.repos.init(

src/dstack/_internal/core/models/repos/remote.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,14 +135,12 @@ def from_url(
135135
136136
Args:
137137
repo_url: The URL of a remote Git repo.
138-
repo_branch: The name of the remote branch. Must be specified if `hash` is not specified.
139-
repo_hash: The hash of the revision. Must be specified if `branch` is not specified.
138+
repo_branch: The name of the remote branch.
139+
repo_hash: The hash of the revision.
140140
141141
Returns:
142142
A remote repo instance.
143143
"""
144-
if repo_branch is None and repo_hash is None:
145-
raise ValueError("Either `repo_branch` or `repo_hash` must be specified.")
146144
return RemoteRepo(
147145
repo_url=repo_url,
148146
repo_branch=repo_branch,

src/dstack/_internal/core/services/repos.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -26,86 +26,74 @@ class InvalidRepoCredentialsError(DstackError):
2626
pass
2727

2828

29-
def get_repo_creds_and_default_branch(
29+
def get_repo_creds(
3030
repo_url: str,
3131
identity_file: Optional[PathLike] = None,
3232
private_key: Optional[str] = None,
3333
oauth_token: Optional[str] = None,
34-
) -> tuple[RemoteRepoCreds, Optional[str]]:
34+
) -> RemoteRepoCreds:
3535
url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config)
3636

3737
# no auth
3838
with suppress(InvalidRepoCredentialsError):
39-
creds, default_branch = _get_repo_creds_and_default_branch_https(url)
40-
logger.debug(
41-
"Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
42-
)
43-
return creds, default_branch
39+
creds = _get_repo_creds_https(url)
40+
logger.debug("Git repo %s is public. Using no auth.", repo_url)
41+
return creds
4442

4543
# ssh key provided by the user or pulled from the server
4644
if identity_file is not None or private_key is not None:
4745
if identity_file is not None:
4846
private_key = _read_private_key(identity_file)
49-
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
50-
url, identity_file, private_key
51-
)
47+
creds = _get_repo_creds_ssh(url, identity_file, private_key)
5248
logger.debug(
53-
"Git repo %s is private. Using identity file: %s. Default branch: %s",
49+
"Git repo %s is private. Using identity file: %s.",
5450
repo_url,
5551
identity_file,
56-
default_branch,
5752
)
58-
return creds, default_branch
53+
return creds
5954
elif private_key is not None:
6055
with NamedTemporaryFile("w+", 0o600) as f:
6156
f.write(private_key)
6257
f.flush()
63-
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
64-
url, f.name, private_key
65-
)
58+
creds = _get_repo_creds_ssh(url, f.name, private_key)
6659
masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
6760
logger.debug(
6861
"Git repo %s is private. Using private key: %s.",
6962
repo_url,
7063
masked_key,
71-
default_branch,
7264
)
73-
return creds, default_branch
65+
return creds
7466
else:
7567
assert False, "should not reach here"
7668

7769
# oauth token provided by the user or pulled from the server
7870
if oauth_token is not None:
79-
creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
71+
creds = _get_repo_creds_https(url, oauth_token)
8072
masked_token = (
8173
len(oauth_token[:-4]) * "*" + oauth_token[-4:]
8274
if len(oauth_token) > 4
8375
else "***MASKED***"
8476
)
8577
logger.debug(
86-
"Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
78+
"Git repo %s is private. Using provided OAuth token: %s.",
8779
repo_url,
8880
masked_token,
89-
default_branch,
9081
)
91-
return creds, default_branch
82+
return creds
9283

9384
# key from ssh config
9485
identities = get_host_config(url.original_host).get("identityfile")
9586
if identities:
9687
_identity_file = identities[0]
9788
with suppress(InvalidRepoCredentialsError):
9889
_private_key = _read_private_key(_identity_file)
99-
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
100-
url, _identity_file, _private_key
101-
)
90+
creds = _get_repo_creds_ssh(url, _identity_file, _private_key)
10291
logger.debug(
103-
"Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
92+
"Git repo %s is private. Using SSH config identity file: %s.",
10493
repo_url,
10594
_identity_file,
106-
default_branch,
10795
)
108-
return creds, default_branch
96+
return creds
10997

11098
# token from gh config
11199
if os.path.exists(gh_config_path):
@@ -114,48 +102,44 @@ def get_repo_creds_and_default_branch(
114102
_oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
115103
if _oauth_token is not None:
116104
with suppress(InvalidRepoCredentialsError):
117-
creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
105+
creds = _get_repo_creds_https(url, _oauth_token)
118106
masked_token = (
119107
len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
120108
if len(_oauth_token) > 4
121109
else "***MASKED***"
122110
)
123111
logger.debug(
124-
"Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
112+
"Git repo %s is private. Using GitHub config token: %s from %s.",
125113
repo_url,
126114
masked_token,
127115
gh_config_path,
128-
default_branch,
129116
)
130-
return creds, default_branch
117+
return creds
131118

132119
# default user key
133120
if os.path.exists(default_ssh_key):
134121
with suppress(InvalidRepoCredentialsError):
135122
_private_key = _read_private_key(default_ssh_key)
136-
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
137-
url, default_ssh_key, _private_key
138-
)
123+
creds = _get_repo_creds_ssh(url, default_ssh_key, _private_key)
139124
logger.debug(
140-
"Git repo %s is private. Using default identity file: %s. Default branch: %s",
125+
"Git repo %s is private. Using default identity file: %s.",
141126
repo_url,
142127
default_ssh_key,
143-
default_branch,
144128
)
145-
return creds, default_branch
129+
return creds
146130

147131
raise InvalidRepoCredentialsError(
148132
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
149133
)
150134

151135

152-
def _get_repo_creds_and_default_branch_ssh(
136+
def _get_repo_creds_ssh(
153137
url: GitRepoURL, identity_file: PathLike, private_key: str
154-
) -> tuple[RemoteRepoCreds, Optional[str]]:
138+
) -> RemoteRepoCreds:
155139
_url = url.as_ssh()
156140
env = _make_git_env_for_creds_check(identity_file=identity_file)
157141
try:
158-
default_branch = _get_repo_default_branch(_url, env)
142+
_get_repo_default_branch(_url, env)
159143
except GitCommandError as e:
160144
message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
161145
raise InvalidRepoCredentialsError(message) from e
@@ -164,16 +148,14 @@ def _get_repo_creds_and_default_branch_ssh(
164148
private_key=private_key,
165149
oauth_token=None,
166150
)
167-
return creds, default_branch
151+
return creds
168152

169153

170-
def _get_repo_creds_and_default_branch_https(
171-
url: GitRepoURL, oauth_token: Optional[str] = None
172-
) -> tuple[RemoteRepoCreds, Optional[str]]:
154+
def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) -> RemoteRepoCreds:
173155
_url = url.as_https()
174156
env = _make_git_env_for_creds_check()
175157
try:
176-
default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
158+
_get_repo_default_branch(url.as_https(oauth_token), env)
177159
except GitCommandError as e:
178160
message = f"Cannot access `{_url}`"
179161
if oauth_token is not None:
@@ -185,7 +167,7 @@ def _get_repo_creds_and_default_branch_https(
185167
private_key=None,
186168
oauth_token=oauth_token,
187169
)
188-
return creds, default_branch
170+
return creds
189171

190172

191173
def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:

src/dstack/api/_public/repos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dstack._internal.core.services.configs import ConfigManager
1515
from dstack._internal.core.services.repos import (
1616
InvalidRepoCredentialsError,
17-
get_repo_creds_and_default_branch,
17+
get_repo_creds,
1818
load_repo,
1919
)
2020
from dstack._internal.utils.logging import get_logger
@@ -76,7 +76,7 @@ def init(
7676
if creds is None and isinstance(repo, RemoteRepo):
7777
assert repo.repo_url is not None
7878
try:
79-
creds, _ = get_repo_creds_and_default_branch(
79+
creds = get_repo_creds(
8080
repo_url=repo.repo_url,
8181
identity_file=git_identity_file,
8282
oauth_token=oauth_token,

0 commit comments

Comments
 (0)