Skip to content

Commit f21a989

Browse files
Improved remote repo support (#3279) (#3285)
Revert detecting and passing `default_repo_branch` for backward comnpatibility of the `datack-runner`
1 parent 54cf8bb commit f21a989

File tree

3 files changed

+60
-36
lines changed

3 files changed

+60
-36
lines changed

src/dstack/_internal/cli/services/configurators/run.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from dstack._internal.core.services.diff import diff_models
5555
from dstack._internal.core.services.repos import (
5656
InvalidRepoCredentialsError,
57-
get_repo_creds,
57+
get_repo_creds_and_default_branch,
5858
load_repo,
5959
)
6060
from dstack._internal.utils.common import local_time
@@ -617,7 +617,7 @@ def get_repo(
617617
init = True
618618

619619
try:
620-
repo_creds = get_repo_creds(
620+
repo_creds, default_repo_branch = get_repo_creds_and_default_branch(
621621
repo_url=repo.repo_url,
622622
identity_file=git_identity_file,
623623
private_key=git_private_key,
@@ -626,9 +626,15 @@ def get_repo(
626626
except InvalidRepoCredentialsError as e:
627627
raise CLIError(*e.args) from e
628628

629-
# repo_branch and repo_hash are taken from the repo_spec
630-
if repo_branch is not None:
631-
repo.run_repo_data.repo_branch = repo_branch
629+
if repo_branch is None and repo_hash is None:
630+
if default_repo_branch is None:
631+
raise CLIError(
632+
"Failed to automatically detect remote repo branch."
633+
" Specify branch or hash."
634+
)
635+
# TODO: remove in 0.20. Currently `default_repo_branch` is sent only for backward compatibility of `dstack-runner`.
636+
repo_branch = default_repo_branch
637+
repo.run_repo_data.repo_branch = repo_branch
632638
if repo_hash is not None:
633639
repo.run_repo_data.repo_hash = repo_hash
634640

src/dstack/_internal/core/services/repos.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,74 +26,86 @@ class InvalidRepoCredentialsError(DstackError):
2626
pass
2727

2828

29-
def get_repo_creds(
29+
def get_repo_creds_and_default_branch(
3030
repo_url: str,
3131
identity_file: Optional[PathLike] = None,
3232
private_key: Optional[str] = None,
3333
oauth_token: Optional[str] = None,
34-
) -> RemoteRepoCreds:
34+
) -> tuple[RemoteRepoCreds, Optional[str]]:
3535
url = GitRepoURL.parse(repo_url, get_ssh_config=get_host_config)
3636

3737
# no auth
3838
with suppress(InvalidRepoCredentialsError):
39-
creds = _get_repo_creds_https(url)
40-
logger.debug("Git repo %s is public. Using no auth.", repo_url)
41-
return creds
39+
creds, default_branch = _get_repo_creds_and_default_branch_https(url)
40+
logger.debug(
41+
"Git repo %s is public. Using no auth. Default branch: %s", repo_url, default_branch
42+
)
43+
return creds, default_branch
4244

4345
# ssh key provided by the user or pulled from the server
4446
if identity_file is not None or private_key is not None:
4547
if identity_file is not None:
4648
private_key = _read_private_key(identity_file)
47-
creds = _get_repo_creds_ssh(url, identity_file, private_key)
49+
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
50+
url, identity_file, private_key
51+
)
4852
logger.debug(
49-
"Git repo %s is private. Using identity file: %s.",
53+
"Git repo %s is private. Using identity file: %s. Default branch: %s",
5054
repo_url,
5155
identity_file,
56+
default_branch,
5257
)
53-
return creds
58+
return creds, default_branch
5459
elif private_key is not None:
5560
with NamedTemporaryFile("w+", 0o600) as f:
5661
f.write(private_key)
5762
f.flush()
58-
creds = _get_repo_creds_ssh(url, f.name, private_key)
63+
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
64+
url, f.name, private_key
65+
)
5966
masked_key = "***" + private_key[-10:] if len(private_key) > 10 else "***MASKED***"
6067
logger.debug(
6168
"Git repo %s is private. Using private key: %s. Default branch: %s",
6269
repo_url,
6370
masked_key,
71+
default_branch,
6472
)
65-
return creds
73+
return creds, default_branch
6674
else:
6775
assert False, "should not reach here"
6876

6977
# oauth token provided by the user or pulled from the server
7078
if oauth_token is not None:
71-
creds = _get_repo_creds_https(url, oauth_token)
79+
creds, default_branch = _get_repo_creds_and_default_branch_https(url, oauth_token)
7280
masked_token = (
7381
len(oauth_token[:-4]) * "*" + oauth_token[-4:]
7482
if len(oauth_token) > 4
7583
else "***MASKED***"
7684
)
7785
logger.debug(
78-
"Git repo %s is private. Using provided OAuth token: %s.",
86+
"Git repo %s is private. Using provided OAuth token: %s. Default branch: %s",
7987
repo_url,
8088
masked_token,
89+
default_branch,
8190
)
82-
return creds
91+
return creds, default_branch
8392

8493
# key from ssh config
8594
identities = get_host_config(url.original_host).get("identityfile")
8695
if identities:
8796
_identity_file = identities[0]
8897
with suppress(InvalidRepoCredentialsError):
8998
_private_key = _read_private_key(_identity_file)
90-
creds = _get_repo_creds_ssh(url, _identity_file, _private_key)
99+
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
100+
url, _identity_file, _private_key
101+
)
91102
logger.debug(
92-
"Git repo %s is private. Using SSH config identity file: %s.",
103+
"Git repo %s is private. Using SSH config identity file: %s. Default branch: %s",
93104
repo_url,
94105
_identity_file,
106+
default_branch,
95107
)
96-
return creds
108+
return creds, default_branch
97109

98110
# token from gh config
99111
if os.path.exists(gh_config_path):
@@ -102,44 +114,48 @@ def get_repo_creds(
102114
_oauth_token = gh_hosts.get(url.host, {}).get("oauth_token")
103115
if _oauth_token is not None:
104116
with suppress(InvalidRepoCredentialsError):
105-
creds = _get_repo_creds_https(url, _oauth_token)
117+
creds, default_branch = _get_repo_creds_and_default_branch_https(url, _oauth_token)
106118
masked_token = (
107119
len(_oauth_token[:-4]) * "*" + _oauth_token[-4:]
108120
if len(_oauth_token) > 4
109121
else "***MASKED***"
110122
)
111123
logger.debug(
112-
"Git repo %s is private. Using GitHub config token: %s from %s.",
124+
"Git repo %s is private. Using GitHub config token: %s from %s. Default branch: %s",
113125
repo_url,
114126
masked_token,
115127
gh_config_path,
128+
default_branch,
116129
)
117-
return creds
130+
return creds, default_branch
118131

119132
# default user key
120133
if os.path.exists(default_ssh_key):
121134
with suppress(InvalidRepoCredentialsError):
122135
_private_key = _read_private_key(default_ssh_key)
123-
creds = _get_repo_creds_ssh(url, default_ssh_key, _private_key)
136+
creds, default_branch = _get_repo_creds_and_default_branch_ssh(
137+
url, default_ssh_key, _private_key
138+
)
124139
logger.debug(
125-
"Git repo %s is private. Using default identity file: %s.",
140+
"Git repo %s is private. Using default identity file: %s. Default branch: %s",
126141
repo_url,
127142
default_ssh_key,
143+
default_branch,
128144
)
129-
return creds
145+
return creds, default_branch
130146

131147
raise InvalidRepoCredentialsError(
132148
"No valid default Git credentials found. Pass valid `--token` or `--git-identity`."
133149
)
134150

135151

136-
def _get_repo_creds_ssh(
152+
def _get_repo_creds_and_default_branch_ssh(
137153
url: GitRepoURL, identity_file: PathLike, private_key: str
138-
) -> RemoteRepoCreds:
154+
) -> tuple[RemoteRepoCreds, Optional[str]]:
139155
_url = url.as_ssh()
140156
env = _make_git_env_for_creds_check(identity_file=identity_file)
141157
try:
142-
_get_repo_default_branch(_url, env)
158+
default_branch = _get_repo_default_branch(_url, env)
143159
except GitCommandError as e:
144160
message = f"Cannot access `{_url}` using the `{identity_file}` private SSH key"
145161
raise InvalidRepoCredentialsError(message) from e
@@ -148,14 +164,16 @@ def _get_repo_creds_ssh(
148164
private_key=private_key,
149165
oauth_token=None,
150166
)
151-
return creds
167+
return creds, default_branch
152168

153169

154-
def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) -> RemoteRepoCreds:
170+
def _get_repo_creds_and_default_branch_https(
171+
url: GitRepoURL, oauth_token: Optional[str] = None
172+
) -> tuple[RemoteRepoCreds, Optional[str]]:
155173
_url = url.as_https()
156174
env = _make_git_env_for_creds_check()
157175
try:
158-
_get_repo_default_branch(url.as_https(oauth_token), env)
176+
default_branch = _get_repo_default_branch(url.as_https(oauth_token), env)
159177
except GitCommandError as e:
160178
message = f"Cannot access `{_url}`"
161179
if oauth_token is not None:
@@ -167,7 +185,7 @@ def _get_repo_creds_https(url: GitRepoURL, oauth_token: Optional[str] = None) ->
167185
private_key=None,
168186
oauth_token=oauth_token,
169187
)
170-
return creds
188+
return creds, default_branch
171189

172190

173191
def _make_git_env_for_creds_check(identity_file: Optional[PathLike] = None) -> dict[str, str]:

src/dstack/api/_public/repos.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dstack._internal.core.services.configs import ConfigManager
1515
from dstack._internal.core.services.repos import (
1616
InvalidRepoCredentialsError,
17-
get_repo_creds,
17+
get_repo_creds_and_default_branch,
1818
load_repo,
1919
)
2020
from dstack._internal.utils.logging import get_logger
@@ -76,7 +76,7 @@ def init(
7676
if creds is None and isinstance(repo, RemoteRepo):
7777
assert repo.repo_url is not None
7878
try:
79-
creds = get_repo_creds(
79+
creds, _ = get_repo_creds_and_default_branch(
8080
repo_url=repo.repo_url,
8181
identity_file=git_identity_file,
8282
oauth_token=oauth_token,

0 commit comments

Comments
 (0)