@@ -177,7 +177,7 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
177
177
return ctx ;
178
178
}
179
179
180
- static int verify_and_dec_payload (struct snp_guest_dev * snp_dev , void * payload , u32 sz )
180
+ static int verify_and_dec_payload (struct snp_guest_dev * snp_dev , struct snp_guest_req * req )
181
181
{
182
182
struct snp_guest_msg * resp_msg = & snp_dev -> secret_response ;
183
183
struct snp_guest_msg * req_msg = & snp_dev -> secret_request ;
@@ -206,20 +206,19 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
206
206
* If the message size is greater than our buffer length then return
207
207
* an error.
208
208
*/
209
- if (unlikely ((resp_msg_hdr -> msg_sz + ctx -> authsize ) > sz ))
209
+ if (unlikely ((resp_msg_hdr -> msg_sz + ctx -> authsize ) > req -> resp_sz ))
210
210
return - EBADMSG ;
211
211
212
212
/* Decrypt the payload */
213
213
memcpy (iv , & resp_msg_hdr -> msg_seqno , min (sizeof (iv ), sizeof (resp_msg_hdr -> msg_seqno )));
214
- if (!aesgcm_decrypt (ctx , payload , resp_msg -> payload , resp_msg_hdr -> msg_sz ,
214
+ if (!aesgcm_decrypt (ctx , req -> resp_buf , resp_msg -> payload , resp_msg_hdr -> msg_sz ,
215
215
& resp_msg_hdr -> algo , AAD_LEN , iv , resp_msg_hdr -> authtag ))
216
216
return - EBADMSG ;
217
217
218
218
return 0 ;
219
219
}
220
220
221
- static int enc_payload (struct snp_guest_dev * snp_dev , u64 seqno , int version , u8 type ,
222
- void * payload , size_t sz )
221
+ static int enc_payload (struct snp_guest_dev * snp_dev , u64 seqno , struct snp_guest_req * req )
223
222
{
224
223
struct snp_guest_msg * msg = & snp_dev -> secret_request ;
225
224
struct snp_guest_msg_hdr * hdr = & msg -> hdr ;
@@ -231,11 +230,11 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
231
230
hdr -> algo = SNP_AEAD_AES_256_GCM ;
232
231
hdr -> hdr_version = MSG_HDR_VER ;
233
232
hdr -> hdr_sz = sizeof (* hdr );
234
- hdr -> msg_type = type ;
235
- hdr -> msg_version = version ;
233
+ hdr -> msg_type = req -> msg_type ;
234
+ hdr -> msg_version = req -> msg_version ;
236
235
hdr -> msg_seqno = seqno ;
237
- hdr -> msg_vmpck = vmpck_id ;
238
- hdr -> msg_sz = sz ;
236
+ hdr -> msg_vmpck = req -> vmpck_id ;
237
+ hdr -> msg_sz = req -> req_sz ;
239
238
240
239
/* Verify the sequence number is non-zero */
241
240
if (!hdr -> msg_seqno )
@@ -244,17 +243,17 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
244
243
pr_debug ("request [seqno %lld type %d version %d sz %d]\n" ,
245
244
hdr -> msg_seqno , hdr -> msg_type , hdr -> msg_version , hdr -> msg_sz );
246
245
247
- if (WARN_ON ((sz + ctx -> authsize ) > sizeof (msg -> payload )))
246
+ if (WARN_ON ((req -> req_sz + ctx -> authsize ) > sizeof (msg -> payload )))
248
247
return - EBADMSG ;
249
248
250
249
memcpy (iv , & hdr -> msg_seqno , min (sizeof (iv ), sizeof (hdr -> msg_seqno )));
251
- aesgcm_encrypt (ctx , msg -> payload , payload , sz , & hdr -> algo , AAD_LEN ,
252
- iv , hdr -> authtag );
250
+ aesgcm_encrypt (ctx , msg -> payload , req -> req_buf , req -> req_sz , & hdr -> algo ,
251
+ AAD_LEN , iv , hdr -> authtag );
253
252
254
253
return 0 ;
255
254
}
256
255
257
- static int __handle_guest_request (struct snp_guest_dev * snp_dev , u64 exit_code ,
256
+ static int __handle_guest_request (struct snp_guest_dev * snp_dev , struct snp_guest_req * req ,
258
257
struct snp_guest_request_ioctl * rio )
259
258
{
260
259
unsigned long req_start = jiffies ;
@@ -269,7 +268,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
269
268
* sequence number must be incremented or the VMPCK must be deleted to
270
269
* prevent reuse of the IV.
271
270
*/
272
- rc = snp_issue_guest_request (exit_code , & snp_dev -> input , rio );
271
+ rc = snp_issue_guest_request (req , & snp_dev -> input , rio );
273
272
switch (rc ) {
274
273
case - ENOSPC :
275
274
/*
@@ -280,7 +279,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
280
279
* IV reuse.
281
280
*/
282
281
override_npages = snp_dev -> input .data_npages ;
283
- exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
282
+ req -> exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
284
283
285
284
/*
286
285
* Override the error to inform callers the given extended
@@ -340,10 +339,8 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
340
339
return rc ;
341
340
}
342
341
343
- static int handle_guest_request (struct snp_guest_dev * snp_dev , u64 exit_code ,
344
- struct snp_guest_request_ioctl * rio , u8 type ,
345
- void * req_buf , size_t req_sz , void * resp_buf ,
346
- u32 resp_sz )
342
+ static int snp_send_guest_request (struct snp_guest_dev * snp_dev , struct snp_guest_req * req ,
343
+ struct snp_guest_request_ioctl * rio )
347
344
{
348
345
u64 seqno ;
349
346
int rc ;
@@ -357,7 +354,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
357
354
memset (snp_dev -> response , 0 , sizeof (struct snp_guest_msg ));
358
355
359
356
/* Encrypt the userspace provided payload in snp_dev->secret_request. */
360
- rc = enc_payload (snp_dev , seqno , rio -> msg_version , type , req_buf , req_sz );
357
+ rc = enc_payload (snp_dev , seqno , req );
361
358
if (rc )
362
359
return rc ;
363
360
@@ -368,7 +365,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
368
365
memcpy (snp_dev -> request , & snp_dev -> secret_request ,
369
366
sizeof (snp_dev -> secret_request ));
370
367
371
- rc = __handle_guest_request (snp_dev , exit_code , rio );
368
+ rc = __handle_guest_request (snp_dev , req , rio );
372
369
if (rc ) {
373
370
if (rc == - EIO &&
374
371
rio -> exitinfo2 == SNP_GUEST_VMM_ERR (SNP_GUEST_VMM_ERR_INVALID_LEN ))
@@ -382,7 +379,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
382
379
return rc ;
383
380
}
384
381
385
- rc = verify_and_dec_payload (snp_dev , resp_buf , resp_sz );
382
+ rc = verify_and_dec_payload (snp_dev , req );
386
383
if (rc ) {
387
384
dev_alert (snp_dev -> dev , "Detected unexpected decode failure from ASP. rc: %d\n" , rc );
388
385
snp_disable_vmpck (snp_dev );
@@ -401,6 +398,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
401
398
{
402
399
struct snp_report_req * report_req = & snp_dev -> req .report ;
403
400
struct snp_report_resp * report_resp ;
401
+ struct snp_guest_req req = {};
404
402
int rc , resp_len ;
405
403
406
404
lockdep_assert_held (& snp_cmd_mutex );
@@ -421,8 +419,16 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
421
419
if (!report_resp )
422
420
return - ENOMEM ;
423
421
424
- rc = handle_guest_request (snp_dev , SVM_VMGEXIT_GUEST_REQUEST , arg , SNP_MSG_REPORT_REQ ,
425
- report_req , sizeof (* report_req ), report_resp -> data , resp_len );
422
+ req .msg_version = arg -> msg_version ;
423
+ req .msg_type = SNP_MSG_REPORT_REQ ;
424
+ req .vmpck_id = vmpck_id ;
425
+ req .req_buf = report_req ;
426
+ req .req_sz = sizeof (* report_req );
427
+ req .resp_buf = report_resp -> data ;
428
+ req .resp_sz = resp_len ;
429
+ req .exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
430
+
431
+ rc = snp_send_guest_request (snp_dev , & req , arg );
426
432
if (rc )
427
433
goto e_free ;
428
434
@@ -438,6 +444,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
438
444
{
439
445
struct snp_derived_key_req * derived_key_req = & snp_dev -> req .derived_key ;
440
446
struct snp_derived_key_resp derived_key_resp = {0 };
447
+ struct snp_guest_req req = {};
441
448
int rc , resp_len ;
442
449
/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
443
450
u8 buf [64 + 16 ];
@@ -460,8 +467,16 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
460
467
sizeof (* derived_key_req )))
461
468
return - EFAULT ;
462
469
463
- rc = handle_guest_request (snp_dev , SVM_VMGEXIT_GUEST_REQUEST , arg , SNP_MSG_KEY_REQ ,
464
- derived_key_req , sizeof (* derived_key_req ), buf , resp_len );
470
+ req .msg_version = arg -> msg_version ;
471
+ req .msg_type = SNP_MSG_KEY_REQ ;
472
+ req .vmpck_id = vmpck_id ;
473
+ req .req_buf = derived_key_req ;
474
+ req .req_sz = sizeof (* derived_key_req );
475
+ req .resp_buf = buf ;
476
+ req .resp_sz = resp_len ;
477
+ req .exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
478
+
479
+ rc = snp_send_guest_request (snp_dev , & req , arg );
465
480
if (rc )
466
481
return rc ;
467
482
@@ -482,6 +497,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
482
497
{
483
498
struct snp_ext_report_req * report_req = & snp_dev -> req .ext_report ;
484
499
struct snp_report_resp * report_resp ;
500
+ struct snp_guest_req req = {};
485
501
int ret , npages = 0 , resp_len ;
486
502
sockptr_t certs_address ;
487
503
@@ -529,9 +545,17 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
529
545
return - ENOMEM ;
530
546
531
547
snp_dev -> input .data_npages = npages ;
532
- ret = handle_guest_request (snp_dev , SVM_VMGEXIT_EXT_GUEST_REQUEST , arg , SNP_MSG_REPORT_REQ ,
533
- & report_req -> data , sizeof (report_req -> data ),
534
- report_resp -> data , resp_len );
548
+
549
+ req .msg_version = arg -> msg_version ;
550
+ req .msg_type = SNP_MSG_REPORT_REQ ;
551
+ req .vmpck_id = vmpck_id ;
552
+ req .req_buf = & report_req -> data ;
553
+ req .req_sz = sizeof (report_req -> data );
554
+ req .resp_buf = report_resp -> data ;
555
+ req .resp_sz = resp_len ;
556
+ req .exit_code = SVM_VMGEXIT_EXT_GUEST_REQUEST ;
557
+
558
+ ret = snp_send_guest_request (snp_dev , & req , arg );
535
559
536
560
/* If certs length is invalid then copy the returned length */
537
561
if (arg -> vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN ) {
@@ -1057,7 +1081,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
1057
1081
misc -> name = DEVICE_NAME ;
1058
1082
misc -> fops = & snp_guest_fops ;
1059
1083
1060
- /* initial the input address for guest request */
1084
+ /* Initialize the input addresses for guest request */
1061
1085
snp_dev -> input .req_gpa = __pa (snp_dev -> request );
1062
1086
snp_dev -> input .resp_gpa = __pa (snp_dev -> response );
1063
1087
snp_dev -> input .data_gpa = __pa (snp_dev -> certs_data );
0 commit comments