|
10 | 10 | import http.client |
11 | 11 | import urllib.parse |
12 | 12 | import ssl |
| 13 | +import base64 |
13 | 14 |
|
14 | 15 | from typing import Optional, Dict, List, Tuple, Union, BinaryIO, Any |
15 | 16 |
|
@@ -133,9 +134,11 @@ def connect(self): |
133 | 134 | class VmmClient: |
134 | 135 | """A unified HTTP client that supports both regular HTTP and Unix Domain Sockets.""" |
135 | 136 |
|
136 | | - def __init__(self, base_url: str): |
| 137 | + def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): |
137 | 138 | self.base_url = base_url.rstrip('/') |
138 | 139 | self.use_uds = self.base_url.startswith('unix:') |
| 140 | + self.auth_user = auth_user |
| 141 | + self.auth_password = auth_password |
139 | 142 |
|
140 | 143 | if self.use_uds: |
141 | 144 | self.uds_path = self.base_url[5:] # Remove 'unix:' prefix |
@@ -163,6 +166,13 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, |
163 | 166 | if headers is None: |
164 | 167 | headers = {} |
165 | 168 |
|
| 169 | + # Add Basic Authentication header if credentials are provided |
| 170 | + if self.auth_user and self.auth_password: |
| 171 | + credentials = f"{self.auth_user}:{self.auth_password}" |
| 172 | + encoded_credentials = base64.b64encode( |
| 173 | + credentials.encode('utf-8')).decode('ascii') |
| 174 | + headers['Authorization'] = f'Basic {encoded_credentials}' |
| 175 | + |
166 | 176 | # Prepare the body |
167 | 177 | if isinstance(body, dict): |
168 | 178 | body = json.dumps(body).encode('utf-8') |
@@ -214,12 +224,12 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, |
214 | 224 |
|
215 | 225 |
|
216 | 226 | class VmmCLI: |
217 | | - def __init__(self, base_url: str): |
| 227 | + def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): |
218 | 228 | self.base_url = base_url.rstrip('/') |
219 | 229 | self.headers = { |
220 | 230 | 'Content-Type': 'application/json' |
221 | 231 | } |
222 | | - self.client = VmmClient(base_url) |
| 232 | + self.client = VmmClient(base_url, auth_user, auth_password) |
223 | 233 |
|
224 | 234 | def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: |
225 | 235 | """Make an RPC call to the dstack-vmm API""" |
@@ -796,6 +806,14 @@ def main(): |
796 | 806 | parser.add_argument( |
797 | 807 | '--url', default=default_url, help='dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var)') |
798 | 808 |
|
| 809 | + # Basic authentication arguments |
| 810 | + parser.add_argument( |
| 811 | + '--auth-user', default=os.environ.get('DSTACK_VMM_AUTH_USER'), |
| 812 | + help='Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var)') |
| 813 | + parser.add_argument( |
| 814 | + '--auth-password', default=os.environ.get('DSTACK_VMM_AUTH_PASSWORD'), |
| 815 | + help='Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var)') |
| 816 | + |
799 | 817 | subparsers = parser.add_subparsers(dest='command', help='Commands') |
800 | 818 |
|
801 | 819 | # List command |
@@ -935,7 +953,7 @@ def main(): |
935 | 953 |
|
936 | 954 | args = parser.parse_args() |
937 | 955 |
|
938 | | - cli = VmmCLI(args.url) |
| 956 | + cli = VmmCLI(args.url, args.auth_user, args.auth_password) |
939 | 957 |
|
940 | 958 | if args.command == 'lsvm': |
941 | 959 | cli.list_vms(args.verbose) |
|
0 commit comments