16
16
import urllib .request
17
17
from urllib .parse import urlparse
18
18
import re
19
+ import base64
20
+ from dataclasses import dataclass
21
+
22
+
23
@dataclass
class Endpoint:
    """A git-lfs API endpoint together with the HTTP headers to send to it.

    Header names are kept in ``str.capitalize()`` form (e.g. ``Content-type``),
    both for the headers given at construction and for the ones merged in later
    with `update_headers`, so that each header is stored only once whatever
    casing was used to spell it.
    """
    # URL of the endpoint; request paths are appended to it
    href: str
    # HTTP headers, keyed by capitalized header name
    headers: dict[str, str]

    def __post_init__(self) -> None:
        # Normalize the initial names the same way update_headers does; without
        # this, updating "Content-Type" would add a second "Content-type" entry
        # next to the initial one instead of replacing it. A new dict is built
        # so the mapping passed in by the caller is left untouched.
        self.headers = {k.capitalize(): v for k, v in self.headers.items()}

    def update_headers(self, d: dict[str, str]) -> None:
        """Merge the headers in `d` into `self.headers`, overriding existing values.

        Header names are capitalized before being stored.
        """
        self.headers.update((k.capitalize(), v) for k, v in d.items())
30
+
19
31
20
32
# Absolute paths of all the files named on the command line.
sources = [path.resolve() for path in map(pathlib.Path, sys.argv[1:])]
# Start from the deepest directory that contains every source file...
source_dir = pathlib.Path(os.path.commonpath(path.parent for path in sources))
# ...and replace it with the top level of the git checkout it belongs to,
# as printed by git (a plain string from here on).
source_dir = subprocess.check_output(
    ["git", "rev-parse", "--show-toplevel"],
    cwd=source_dir,
    text=True,
).strip()
23
35
24
36
37
def get_env(s: str, sep: str = "=") -> dict[str, str]:
    """Parse the ``key<sep>value`` lines of `s` into a dictionary.

    Each line is split at the first occurrence of `sep`: what precedes it is
    the key, the rest of the line is the value. Neither is stripped. Lines that
    do not contain `sep` are ignored, and if a key appears more than once the
    first value wins.

    Args:
        s: the text to parse.
        sep: the literal separator between key and value.

    Returns:
        A dict mapping each key to its value, in order of first appearance.
    """
    ret = {}
    # `sep` is escaped so that it is always matched literally, even if it
    # contains characters that are special in regular expressions (".", "|", ...).
    # `.` never matches a newline, so every match stays within a single line.
    pattern = fr'(.*?){re.escape(sep)}(.*)'
    for m in re.finditer(pattern, s):
        # setdefault keeps the first value seen for a repeated key
        ret.setdefault(*m.groups())
    return ret
42
+
43
+
44
def git(*args, **kwargs):
    """Run ``git`` with the given arguments in `source_dir` and return its stripped standard output.

    Any keyword argument is forwarded to `subprocess.run` (for example ``check=True``
    or ``input=...``). Unless ``check=True`` is passed, a failing command is not
    reported: whatever it printed to standard output is returned.
    """
    completed = subprocess.run(
        ["git", *args],
        stdout=subprocess.PIPE,
        text=True,
        cwd=source_dir,
        **kwargs,
    )
    return completed.stdout.strip()
46
+
47
+
25
48
def get_endpoint():
    """Find the git-lfs endpoint of the repository and the headers needed to talk to it.

    The endpoint is read from ``git lfs env``. If that also reports an SSH
    endpoint, ``git-lfs-authenticate`` is run over ssh and the href and headers
    it returns take precedence. An ``Authorization`` header is then looked for,
    in increasing order of priority, in the ``http.<url>/.extraheader`` git
    configuration and in the ``GITHUB_TOKEN`` environment variable; if neither
    provides one, ``git credential fill`` is used to build a basic
    authentication header.

    Returns:
        An `Endpoint` with the resolved href and the request headers.

    Raises:
        StopIteration: if ``git lfs env`` reports no ``Endpoint`` entry.
        AssertionError: if an SSH endpoint is reported but no ssh command is found.
        subprocess.CalledProcessError: if ``git lfs env``, the ssh command or
            ``git credential fill`` fails.
        KeyError: if ``git credential fill`` does not output both a username and a password.
    """
    lfs_env = get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir))
    # take the first `Endpoint...` entry, dropping anything after the first space in its value
    href = next(v for k, v in lfs_env.items() if k.startswith('Endpoint'))
    href, _, _ = href.partition(' ')
    # the SSH entry is matched ignoring whitespace around its key, so that it is
    # found however that line happens to be indented
    ssh_endpoint = next((v for k, v in lfs_env.items() if k.strip() == "SSH"), None)
    endpoint = Endpoint(href, {
        "Content-Type": "application/vnd.git-lfs+json",
        "Accept": "application/vnd.git-lfs+json",
    })
    if ssh_endpoint:
        # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
        server, _, path = ssh_endpoint.partition(":")
        ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
        assert ssh_command, "no ssh command found"
        resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
        # keep the current href (a string, not the Endpoint object itself) when
        # the response does not provide one
        endpoint.href = resp.get("href", endpoint.href)
        endpoint.update_headers(resp.get("header", {}))
    url = urlparse(endpoint.href)
    # this is how actions/checkout persist credentials
    # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
    extra_headers = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
    endpoint.update_headers(get_env(extra_headers, sep=": "))
    if "GITHUB_TOKEN" in os.environ:
        # an explicit token overrides any Authorization header found so far
        endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
    if "Authorization" not in endpoint.headers:
        # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
        # see https://git-scm.com/docs/git-credential
        credentials = get_env(git("credential", "fill", check=True,
                                  # drop leading / from url.path
                                  input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
        basic = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
        endpoint.headers["Authorization"] = f"Basic {basic}"
    return endpoint
77
81
78
82
79
83
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
80
84
def get_locations (objects ):
81
- href , headers = get_endpoint ()
85
+ endpoint = get_endpoint ()
82
86
indexes = [i for i , o in enumerate (objects ) if o ]
83
87
ret = ["local" for _ in objects ]
84
88
req = urllib .request .Request (
85
- f"{ href } /objects/batch" ,
86
- headers = headers ,
89
+ f"{ endpoint . href } /objects/batch" ,
90
+ headers = endpoint . headers ,
87
91
data = json .dumps ({
88
92
"operation" : "download" ,
89
93
"transfers" : ["basic" ],
@@ -93,7 +97,7 @@ def get_locations(objects):
93
97
)
94
98
with urllib .request .urlopen (req ) as resp :
95
99
data = json .load (resp )
96
- assert len (data ["objects" ]) == len (indexes ), data
100
+ assert len (data ["objects" ]) == len (indexes ), f"received { len ( data ) } objects, expected { len ( indexes ) } "
97
101
for i , resp in zip (indexes , data ["objects" ]):
98
102
ret [i ] = f'{ resp ["oid" ]} { resp ["actions" ]["download" ]["href" ]} '
99
103
return ret
@@ -106,14 +110,10 @@ def get_lfs_object(path):
106
110
sha256 = size = None
107
111
if lfs_header != actual_header :
108
112
return None
109
- for line in fileobj :
110
- line = line .decode ('ascii' ).strip ()
111
- if line .startswith ("oid sha256:" ):
112
- sha256 = line [len ("oid sha256:" ):]
113
- elif line .startswith ("size " ):
114
- size = int (line [len ("size " ):])
115
- if not (sha256 and line ):
116
- raise Exception ("malformed pointer file" )
113
+ data = get_env (fileobj .read ().decode ('ascii' ), sep = ' ' )
114
+ assert data ['oid' ].startswith ('sha256:' ), f"unknown oid type: { data ['oid' ]} "
115
+ _ , _ , sha256 = data ['oid' ].partition (':' )
116
+ size = int (data ['size' ])
117
117
return {"oid" : sha256 , "size" : size }
118
118
119
119
0 commit comments