Skip to content

Commit 157569a

Browse files
authored
feat: add model name to signatures and attestation report (#50)
* feat: add model name to cryptographic signatures and attestation report Include model_name in the signed text format (model:req_hash:resp_hash) and in the AttestationReport struct so clients can independently verify which model the proxy is serving. * fix: bundle attestation params into struct to satisfy clippy too_many_arguments * chore: fix formatting
1 parent 977d76d commit 157569a

File tree

9 files changed

+71
-37
lines changed

9 files changed

+71
-37
lines changed

src/attestation.rs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -251,27 +251,38 @@ nU+jXBG7tgClr/DntUBJx+xfNWpxLKE=
251251
);
252252
}
253253
}
254+
/// Parameters for generating an attestation report.
255+
pub struct AttestationParams<'a> {
256+
pub model_name: &'a str,
257+
pub signing_address: &'a str,
258+
pub signing_algo: &'a str,
259+
pub signing_public_key: &'a str,
260+
pub signing_address_bytes: &'a [u8],
261+
pub nonce: Option<&'a str>,
262+
pub gpu_no_hw_mode: bool,
263+
pub tls_cert_fingerprint: Option<&'a str>,
264+
}
265+
254266
/// Generate a complete attestation report.
255267
pub async fn generate_attestation(
256-
signing_address: &str,
257-
signing_algo: &str,
258-
signing_public_key: &str,
259-
signing_address_bytes: &[u8],
260-
nonce: Option<&str>,
261-
gpu_no_hw_mode: bool,
262-
tls_cert_fingerprint: Option<&str>,
268+
params: AttestationParams<'_>,
263269
) -> Result<AttestationReport, AttestationError> {
264-
let nonce_bytes = parse_nonce(nonce)?;
270+
let nonce_bytes = parse_nonce(params.nonce)?;
265271
let nonce_hex = hex::encode(nonce_bytes);
266272

267273
// Build TDX report data (binds cert fingerprint when present)
268-
let fp_bytes = tls_cert_fingerprint
274+
let fp_bytes = params
275+
.tls_cert_fingerprint
269276
.map(hex::decode)
270277
.transpose()
271278
.map_err(|e| {
272279
AttestationError::Internal(anyhow::anyhow!("bad cert fingerprint hex: {e}"))
273280
})?;
274-
let report_data = build_report_data(signing_address_bytes, &nonce_bytes, fp_bytes.as_deref());
281+
let report_data = build_report_data(
282+
params.signing_address_bytes,
283+
&nonce_bytes,
284+
fp_bytes.as_deref(),
285+
);
275286

276287
// Get TDX quote from dstack
277288
let client = dstack_sdk::dstack_client::DstackClient::new(None);
@@ -280,23 +291,24 @@ pub async fn generate_attestation(
280291
serde_json::from_str(&quote_result.event_log).map_err(anyhow::Error::from)?;
281292

282293
// Collect GPU evidence
283-
let gpu_evidence = collect_gpu_evidence(&nonce_hex, gpu_no_hw_mode).await?;
294+
let gpu_evidence = collect_gpu_evidence(&nonce_hex, params.gpu_no_hw_mode).await?;
284295
let nvidia_payload = build_nvidia_payload(&nonce_hex, &gpu_evidence);
285296

286297
// Get system info
287298
let info = client.info().await?;
288299
let info_value = serde_json::to_value(&info).map_err(anyhow::Error::from)?;
289300

290301
Ok(AttestationReport {
291-
signing_address: signing_address.to_string(),
292-
signing_algo: signing_algo.to_string(),
293-
signing_public_key: signing_public_key.to_string(),
302+
model_name: params.model_name.to_string(),
303+
signing_address: params.signing_address.to_string(),
304+
signing_algo: params.signing_algo.to_string(),
305+
signing_public_key: params.signing_public_key.to_string(),
294306
request_nonce: nonce_hex,
295307
intel_quote: quote_result.quote,
296308
nvidia_payload,
297309
event_log,
298310
info: info_value,
299-
tls_cert_fingerprint: tls_cert_fingerprint.map(|s| s.to_string()),
311+
tls_cert_fingerprint: params.tls_cert_fingerprint.map(|s| s.to_string()),
300312
})
301313
}
302314

src/proxy.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ pub struct ProxyOpts {
183183
pub cache: Arc<ChatCache>,
184184
/// Prefix for auto-generated IDs (e.g., "chatcmpl", "img", "emb").
185185
pub id_prefix: String,
186+
/// Model name included in the signed text.
187+
pub model_name: String,
186188
/// If set, report usage to the cloud API after a successful response.
187189
pub usage_reporter: Option<UsageReporter>,
188190
/// What kind of usage to extract from the response.
@@ -251,7 +253,7 @@ pub async fn proxy_json_request(
251253
let response_sha256 = hex::encode(Sha256::digest(response_body.as_bytes()));
252254

253255
// Sign and cache
254-
let text = format!("{request_sha256}:{response_sha256}");
256+
let text = format!("{}:{request_sha256}:{response_sha256}", opts.model_name);
255257
let signed = opts.signing.sign_chat(&text).map_err(|e| {
256258
error!(error = %e, "Signing failed");
257259
AppError::Internal(e)
@@ -304,6 +306,7 @@ pub async fn proxy_streaming_request(
304306
let signing = opts.signing.clone();
305307
let cache = opts.cache.clone();
306308
let usage_reporter = opts.usage_reporter.clone();
309+
let model_name = opts.model_name.clone();
307310

308311
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(64);
309312

@@ -355,7 +358,7 @@ pub async fn proxy_streaming_request(
355358
if !upstream_error && !downstream_closed && parser.seen_done {
356359
let response_sha256 = hex::encode(hasher.finalize());
357360
if let Some(ref id) = parser.chat_id {
358-
let text = format!("{request_sha256}:{response_sha256}");
361+
let text = format!("{model_name}:{request_sha256}:{response_sha256}");
359362
match signing.sign_chat(&text) {
360363
Ok(signed) => {
361364
if let Ok(signed_json) = serde_json::to_string(&signed) {
@@ -457,7 +460,7 @@ pub async fn proxy_multipart_request(
457460
serde_json::to_string(&response_data).map_err(|e| AppError::Internal(e.into()))?;
458461
let response_sha256 = hex::encode(Sha256::digest(response_body.as_bytes()));
459462

460-
let text = format!("{request_sha256}:{response_sha256}");
463+
let text = format!("{}:{request_sha256}:{response_sha256}", opts.model_name);
461464
let signed = opts.signing.sign_chat(&text).map_err(|e| {
462465
error!(error = %e, "Signing failed");
463466
AppError::Internal(e)
@@ -566,7 +569,7 @@ pub async fn sign_and_cache_json_response(
566569
serde_json::to_string(&response_data).map_err(|e| AppError::Internal(e.into()))?;
567570
let response_sha256 = hex::encode(Sha256::digest(response_body.as_bytes()));
568571

569-
let text = format!("{request_sha256}:{response_sha256}");
572+
let text = format!("{}:{request_sha256}:{response_sha256}", opts.model_name);
570573
let signed = opts.signing.sign_chat(&text).map_err(|e| {
571574
error!(error = %e, "Signing failed");
572575
AppError::Internal(e)
@@ -595,6 +598,7 @@ pub async fn proxy_streaming_response(
595598
let signing = opts.signing.clone();
596599
let cache = opts.cache.clone();
597600
let usage_reporter = opts.usage_reporter.clone();
601+
let model_name = opts.model_name.clone();
598602
let request_sha256 = request_sha256.to_string();
599603

600604
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, std::io::Error>>(64);
@@ -643,7 +647,7 @@ pub async fn proxy_streaming_response(
643647
if !upstream_error && !downstream_closed && parser.seen_done {
644648
let response_sha256 = hex::encode(hasher.finalize());
645649
if let Some(ref id) = parser.chat_id {
646-
let text = format!("{request_sha256}:{response_sha256}");
650+
let text = format!("{model_name}:{request_sha256}:{response_sha256}");
647651
match signing.sign_chat(&text) {
648652
Ok(signed) => {
649653
if let Ok(signed_json) = serde_json::to_string(&signed) {
@@ -842,6 +846,7 @@ mod tests {
842846
signing,
843847
cache,
844848
id_prefix: "test".to_string(),
849+
model_name: "test-model".to_string(),
845850
usage_reporter: None,
846851
usage_type: UsageType::default(),
847852
}

src/routes/attestation.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,20 @@ pub async fn attestation_report(
5353
}
5454
}
5555

56-
let report = crate::attestation::generate_attestation(
57-
&signing_address,
56+
let report = crate::attestation::generate_attestation(crate::attestation::AttestationParams {
57+
model_name: &state.config.model_name,
58+
signing_address: &signing_address,
5859
signing_algo,
59-
&signing_public_key,
60-
&signing_address_bytes,
61-
query.nonce.as_deref(),
62-
state.config.gpu_no_hw_mode,
63-
if query.include_tls_fingerprint.unwrap_or(false) {
60+
signing_public_key: &signing_public_key,
61+
signing_address_bytes: &signing_address_bytes,
62+
nonce: query.nonce.as_deref(),
63+
gpu_no_hw_mode: state.config.gpu_no_hw_mode,
64+
tls_cert_fingerprint: if query.include_tls_fingerprint.unwrap_or(false) {
6465
state.tls_cert_fingerprint.as_deref()
6566
} else {
6667
None
6768
},
68-
)
69+
})
6970
.await
7071
.map_err(|e| match e {
7172
crate::attestation::AttestationError::InvalidNonce(msg) => AppError::BadRequest(msg),

src/routes/catch_all.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ pub async fn catch_all(
200200
signing: state.signing.clone(),
201201
cache: state.cache.clone(),
202202
id_prefix: "pt".to_string(),
203+
model_name: state.config.model_name.clone(),
203204
usage_reporter: reporter,
204205
usage_type: UsageType::ChatCompletion,
205206
};
@@ -221,6 +222,7 @@ pub async fn catch_all(
221222
signing: state.signing.clone(),
222223
cache: state.cache.clone(),
223224
id_prefix: "pt".to_string(),
225+
model_name: state.config.model_name.clone(),
224226
usage_reporter: reporter,
225227
usage_type: UsageType::ChatCompletion,
226228
};

src/routes/chat.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pub async fn chat_completions(
4646
signing: state.signing.clone(),
4747
cache: state.cache.clone(),
4848
id_prefix: "chatcmpl".to_string(),
49+
model_name: state.config.model_name.clone(),
4950
usage_reporter: make_usage_reporter(auth.cloud_api_key.as_ref(), &state),
5051
usage_type: UsageType::ChatCompletion,
5152
};

src/routes/completions.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub async fn completions(
4444
signing: state.signing.clone(),
4545
cache: state.cache.clone(),
4646
id_prefix: "cmpl".to_string(),
47+
model_name: state.config.model_name.clone(),
4748
usage_reporter: make_usage_reporter(auth.cloud_api_key.as_ref(), &state),
4849
usage_type: UsageType::ChatCompletion,
4950
};

src/routes/passthrough.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ pub async fn images_edits(
139139
signing: state.signing.clone(),
140140
cache: state.cache.clone(),
141141
id_prefix: "img".to_string(),
142+
model_name: state.config.model_name.clone(),
142143
usage_reporter: make_usage_reporter(auth.cloud_api_key.as_ref(), &state),
143144
usage_type: UsageType::ImageGeneration,
144145
};
@@ -185,6 +186,7 @@ pub async fn audio_transcriptions(
185186
signing: state.signing.clone(),
186187
cache: state.cache.clone(),
187188
id_prefix: "trans".to_string(),
189+
model_name: state.config.model_name.clone(),
188190
usage_reporter: make_usage_reporter(auth.cloud_api_key.as_ref(), &state),
189191
usage_type: UsageType::ChatCompletion,
190192
};
@@ -220,6 +222,7 @@ async fn json_passthrough(
220222
signing: state.signing.clone(),
221223
cache: state.cache.clone(),
222224
id_prefix: id_prefix.to_string(),
225+
model_name: state.config.model_name.clone(),
223226
usage_reporter,
224227
usage_type,
225228
};

src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub struct SignedChat {
1313
/// Attestation report returned by GET /v1/attestation/report.
1414
#[derive(Debug, Clone, Serialize, Deserialize)]
1515
pub struct AttestationReport {
16+
pub model_name: String,
1617
pub signing_address: String,
1718
pub signing_algo: String,
1819
pub signing_public_key: String,

tests/integration.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -744,8 +744,8 @@ async fn test_signature_binds_actual_request_body() {
744744
body["text"]
745745
.as_str()
746746
.unwrap()
747-
.starts_with(&format!("{expected_hash}:")),
748-
"Signed text must start with SHA-256 of actual request body"
747+
.starts_with(&format!("test-model:{expected_hash}:")),
748+
"Signed text must start with model_name:SHA-256 of actual request body"
749749
);
750750
}
751751

@@ -804,7 +804,7 @@ async fn test_x_request_hash_header_is_ignored() {
804804
body["text"]
805805
.as_str()
806806
.unwrap()
807-
.starts_with(&format!("{expected_hash}:")),
807+
.starts_with(&format!("test-model:{expected_hash}:")),
808808
"Signed text must use actual body hash, not X-Request-Hash header"
809809
);
810810
}
@@ -996,7 +996,10 @@ async fn test_ecdsa_signature_cryptographic_verification() {
996996
// Verify text = sha256(request):sha256(response)
997997
let expected_req_hash = hex::encode(Sha256::digest(&request_bytes));
998998
let expected_resp_hash = hex::encode(Sha256::digest(&response_body_bytes));
999-
assert_eq!(text, format!("{expected_req_hash}:{expected_resp_hash}"));
999+
assert_eq!(
1000+
text,
1001+
format!("test-model:{expected_req_hash}:{expected_resp_hash}")
1002+
);
10001003

10011004
// Verify ECDSA EIP-191 signature via key recovery
10021005
let sig_bytes = hex::decode(&signature_hex[2..]).unwrap();
@@ -1092,7 +1095,10 @@ async fn test_ed25519_signature_cryptographic_verification() {
10921095
// Verify text = sha256(request):sha256(response)
10931096
let expected_req_hash = hex::encode(Sha256::digest(&request_bytes));
10941097
let expected_resp_hash = hex::encode(Sha256::digest(&response_body_bytes));
1095-
assert_eq!(text, format!("{expected_req_hash}:{expected_resp_hash}"));
1098+
assert_eq!(
1099+
text,
1100+
format!("test-model:{expected_req_hash}:{expected_resp_hash}")
1101+
);
10961102

10971103
// Verify Ed25519 signature
10981104
let sig_bytes = hex::decode(signature_hex).unwrap();
@@ -1181,17 +1187,19 @@ async fn test_streaming_signature_cached_and_verifiable() {
11811187
let parts: Vec<&str> = text.split(':').collect();
11821188
assert_eq!(
11831189
parts.len(),
1184-
2,
1185-
"Signed text should be request_hash:response_hash"
1190+
3,
1191+
"Signed text should be model_name:request_hash:response_hash"
11861192
);
11871193

1194+
assert_eq!(parts[0], "test-model");
1195+
11881196
// Request hash should match SHA256 of request body
11891197
let expected_req_hash = hex::encode(Sha256::digest(&request_bytes));
1190-
assert_eq!(parts[0], expected_req_hash);
1198+
assert_eq!(parts[1], expected_req_hash);
11911199

11921200
// Response hash should match SHA256 of the streamed bytes
11931201
let expected_resp_hash = hex::encode(Sha256::digest(&stream_body_bytes));
1194-
assert_eq!(parts[1], expected_resp_hash);
1202+
assert_eq!(parts[2], expected_resp_hash);
11951203

11961204
// Signature should be valid format
11971205
let sig_hex = sig_body["signature"].as_str().unwrap();

0 commit comments

Comments
 (0)