Skip to content

Commit f37508a

Browse files
committed
kms: Auto update certs on start
1 parent 4d1234d commit f37508a

File tree

5 files changed

+60
-13
lines changed

5 files changed

+60
-13
lines changed

dstack-util/src/system_setup.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ impl HostShared {
163163
if src_size > max_size {
164164
bail!("Source file {src} is too large, max size is {max_size} bytes");
165165
}
166-
std::fs::copy(src_path, dst_path)?;
166+
fs_err::copy(src_path, dst_path)?;
167167
Ok(())
168168
};
169169
cmd! {

kms/src/config.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const ROOT_CA_CERT: &str = "root-ca.crt";
1414
const ROOT_CA_KEY: &str = "root-ca.key";
1515
const RPC_CERT: &str = "rpc.crt";
1616
const RPC_KEY: &str = "rpc.key";
17+
const RPC_DOMAIN: &str = "rpc-domain";
1718
const K256_KEY: &str = "root-k256.key";
1819
const BOOTSTRAP_INFO: &str = "bootstrap-info.json";
1920

@@ -60,6 +61,10 @@ impl KmsConfig {
6061
self.cert_dir.join(RPC_KEY)
6162
}
6263

64+
pub fn rpc_domain(&self) -> PathBuf {
65+
self.cert_dir.join(RPC_DOMAIN)
66+
}
67+
6368
pub fn k256_key(&self) -> PathBuf {
6469
self.cert_dir.join(K256_KEY)
6570
}

kms/src/main.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rocket::{
99
response::content::RawHtml,
1010
Shutdown,
1111
};
12-
use tracing::info;
12+
use tracing::{info, warn};
1313

1414
mod config;
1515
// mod ct_log;
@@ -93,6 +93,11 @@ async fn main() -> Result<()> {
9393
}
9494
}
9595

96+
info!("Updating certs");
97+
if let Err(err) = onboard_service::update_certs(&config).await {
98+
warn!("Failed to update certs: {err}");
99+
};
100+
96101
info!("Starting KMS");
97102
info!("Supported methods:");
98103
for method in main_service::rpc_methods() {

kms/src/onboard_service.rs

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use dstack_kms_rpc::{
77
onboard_server::{OnboardRpc, OnboardServer},
88
BootstrapRequest, BootstrapResponse, OnboardRequest, OnboardResponse,
99
};
10+
use fs_err as fs;
1011
use http_client::prpc::PrpcClient;
1112
use k256::ecdsa::SigningKey;
1213
use ra_rpc::{client::RaClient, CallContext, RpcCall};
@@ -102,6 +103,7 @@ struct Keys {
102103
ca_cert: Certificate,
103104
rpc_key: KeyPair,
104105
rpc_cert: Certificate,
106+
rpc_domain: String,
105107
}
106108

107109
impl Keys {
@@ -169,6 +171,7 @@ impl Keys {
169171
ca_cert,
170172
rpc_key,
171173
rpc_cert,
174+
rpc_domain: domain.to_string(),
172175
})
173176
}
174177

@@ -216,25 +219,59 @@ impl Keys {
216219
}
217220

218221
fn store(&self, cfg: &KmsConfig) -> Result<()> {
219-
// Store the temporary CA cert and key
220-
safe_write(cfg.tmp_ca_cert(), self.tmp_ca_cert.pem())?;
221-
safe_write(cfg.tmp_ca_key(), self.tmp_ca_key.serialize_pem())?;
222+
self.store_keys(cfg)?;
223+
self.store_certs(cfg)?;
224+
safe_write(cfg.rpc_domain(), self.rpc_domain.as_bytes())?;
225+
Ok(())
226+
}
222227

223-
// Store the root CA cert and key
224-
safe_write(cfg.root_ca_cert(), self.ca_cert.pem())?;
228+
fn store_keys(&self, cfg: &KmsConfig) -> Result<()> {
229+
safe_write(cfg.tmp_ca_key(), self.tmp_ca_key.serialize_pem())?;
225230
safe_write(cfg.root_ca_key(), self.ca_key.serialize_pem())?;
226-
227-
// Store the RPC cert and key
228-
safe_write(cfg.rpc_cert(), self.rpc_cert.pem())?;
229231
safe_write(cfg.rpc_key(), self.rpc_key.serialize_pem())?;
230-
231-
// Store the ECDSA root key
232232
safe_write(cfg.k256_key(), self.k256_key.to_bytes())?;
233+
Ok(())
234+
}
233235

236+
fn store_certs(&self, cfg: &KmsConfig) -> Result<()> {
237+
safe_write(cfg.tmp_ca_cert(), self.tmp_ca_cert.pem())?;
238+
safe_write(cfg.root_ca_cert(), self.ca_cert.pem())?;
239+
safe_write(cfg.rpc_cert(), self.rpc_cert.pem())?;
234240
Ok(())
235241
}
236242
}
237243

244+
pub(crate) async fn update_certs(cfg: &KmsConfig) -> Result<()> {
245+
// Read existing keys
246+
let tmp_ca_key = KeyPair::from_pem(&fs::read_to_string(cfg.tmp_ca_key())?)?;
247+
let ca_key = KeyPair::from_pem(&fs::read_to_string(cfg.root_ca_key())?)?;
248+
let rpc_key = KeyPair::from_pem(&fs::read_to_string(cfg.rpc_key())?)?;
249+
250+
// Read k256 key
251+
let k256_key_bytes = fs::read(cfg.k256_key())?;
252+
let k256_key = SigningKey::from_slice(&k256_key_bytes)?;
253+
254+
let domain = fs::read_to_string(cfg.rpc_domain())?;
255+
let domain = domain.trim();
256+
257+
// Regenerate certificates using existing keys
258+
let keys = Keys::from_keys(
259+
tmp_ca_key,
260+
ca_key,
261+
rpc_key,
262+
k256_key,
263+
domain,
264+
cfg.onboard.quote_enabled,
265+
)
266+
.await
267+
.context("Failed to regenerate certificates")?;
268+
269+
// Write the new certificates to files
270+
keys.store_certs(cfg)?;
271+
272+
Ok(())
273+
}
274+
238275
pub(crate) async fn bootstrap_keys(cfg: &KmsConfig) -> Result<()> {
239276
let keys = Keys::generate(
240277
&cfg.onboard.auto_bootstrap_domain,

supervisor/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async fn async_main(args: Args) -> Result<()> {
9898
if let Some(uds) = args.uds {
9999
mk_parents(&uds)?;
100100
if args.remove_existing_uds {
101-
std::fs::remove_file(&uds).ok();
101+
fs_err::remove_file(&uds).ok();
102102
}
103103
figment = figment.join(("address", format!("unix:{uds}")));
104104
} else if let Some(address) = args.address {

0 commit comments

Comments
 (0)