@@ -224,6 +224,14 @@ static void jit_fill_hole(void *area, unsigned int size)
224
224
225
225
/*
 * Offsets the JIT records while emitting code for one program: the epilogue
 * offset and the positions of the "out:" labels of the two tail-call
 * sequences.
 */
struct jit_context {
226
226
int cleanup_addr ; /* Epilogue code offset */
227
+
228
+ /*
229
+ * Program specific offsets of labels in the code; these rely on the
230
+ * JIT doing at least 2 passes, recording the position on the first
231
+ * pass, only to generate the correct offset on the second pass.
232
+ */
233
+ int tail_call_direct_label ; /* "out:" offset from the first byte emitted by emit_bpf_tail_call_direct() */
234
+ int tail_call_indirect_label ; /* "out:" offset from the first byte emitted by emit_bpf_tail_call_indirect() */
227
235
};
228
236
229
237
/* Maximum number of bytes emitted while JITing one eBPF insn */
@@ -379,22 +387,6 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
379
387
return __bpf_arch_text_poke (ip , t , old_addr , new_addr , true);
380
388
}
381
389
382
/*
 * Pre-patch helper (every line of it is on the '-' side of this diff).
 * Returns the number of bytes to account for popping the callee-saved
 * registers flagged in callee_regs_used: 2 bytes for each of entries 1..3
 * and 1 byte for entry 0. The tail-call emitters add the result to their
 * hand-computed jump offsets.
 */
- static int get_pop_bytes (bool * callee_regs_used )
383
- {
384
- int bytes = 0 ;
385
-
386
- if (callee_regs_used [3 ])
387
- bytes += 2 ;
388
- if (callee_regs_used [2 ])
389
- bytes += 2 ;
390
- if (callee_regs_used [1 ])
391
- bytes += 2 ;
392
- if (callee_regs_used [0 ])
393
- bytes += 1 ;
394
-
395
- return bytes ;
396
- }
397
-
398
390
/*
399
391
* Generate the following code:
400
392
*
@@ -410,29 +402,12 @@ static int get_pop_bytes(bool *callee_regs_used)
410
402
* out:
411
403
*/
412
404
static void emit_bpf_tail_call_indirect (u8 * * pprog , bool * callee_regs_used ,
413
- u32 stack_depth )
405
+ u32 stack_depth , u8 * ip ,
406
+ struct jit_context * ctx )
414
407
{
415
408
int tcc_off = -4 - round_up (stack_depth , 8 );
416
- u8 * prog = * pprog ;
417
- int pop_bytes = 0 ;
418
- int off1 = 42 ;
419
- int off2 = 31 ;
420
- int off3 = 9 ;
421
-
422
- /* count the additional bytes used for popping callee regs from stack
423
- * that need to be taken into account for each of the offsets that
424
- * are used for bailing out of the tail call
425
- */
426
- pop_bytes = get_pop_bytes (callee_regs_used );
427
- off1 += pop_bytes ;
428
- off2 += pop_bytes ;
429
- off3 += pop_bytes ;
430
-
431
- if (stack_depth ) {
432
- off1 += 7 ;
433
- off2 += 7 ;
434
- off3 += 7 ;
435
- }
409
+ u8 * prog = * pprog , * start = * pprog ;
410
+ int offset ;
436
411
437
412
/*
438
413
* rdi - pointer to ctx
@@ -447,17 +422,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
447
422
EMIT2 (0x89 , 0xD2 ); /* mov edx, edx */
448
423
EMIT3 (0x39 , 0x56 , /* cmp dword ptr [rsi + 16], edx */
449
424
offsetof(struct bpf_array , map .max_entries ));
450
- #define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
451
- EMIT2 (X86_JBE , OFFSET1 ); /* jbe out */
425
+
426
+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
427
+ EMIT2 (X86_JBE , offset ); /* jbe out */
452
428
453
429
/*
454
430
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
455
431
* goto out;
456
432
*/
457
433
EMIT2_off32 (0x8B , 0x85 , tcc_off ); /* mov eax, dword ptr [rbp - tcc_off] */
458
434
EMIT3 (0x83 , 0xF8 , MAX_TAIL_CALL_CNT ); /* cmp eax, MAX_TAIL_CALL_CNT */
459
- #define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
460
- EMIT2 (X86_JA , OFFSET2 ); /* ja out */
435
+
436
+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
437
+ EMIT2 (X86_JA , offset ); /* ja out */
461
438
EMIT3 (0x83 , 0xC0 , 0x01 ); /* add eax, 1 */
462
439
EMIT2_off32 (0x89 , 0x85 , tcc_off ); /* mov dword ptr [rbp - tcc_off], eax */
463
440
@@ -470,12 +447,11 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
470
447
* goto out;
471
448
*/
472
449
EMIT3 (0x48 , 0x85 , 0xC9 ); /* test rcx,rcx */
473
- #define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
474
- EMIT2 (X86_JE , OFFSET3 ); /* je out */
475
450
476
- * pprog = prog ;
477
- pop_callee_regs (pprog , callee_regs_used );
478
- prog = * pprog ;
451
+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
452
+ EMIT2 (X86_JE , offset ); /* je out */
453
+
454
+ pop_callee_regs (& prog , callee_regs_used );
479
455
480
456
EMIT1 (0x58 ); /* pop rax */
481
457
if (stack_depth )
@@ -495,67 +471,49 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
495
471
RETPOLINE_RCX_BPF_JIT ();
496
472
497
473
/* out: */
474
+ ctx -> tail_call_indirect_label = prog - start ;
498
475
* pprog = prog ;
499
476
}
500
477
501
478
/*
 * Emit the code for a direct tail call through the poke descriptor @poke.
 * Parameters are documented as they appear on the '+' side of this patch.
 *
 * Emitted in order: the tail_call_cnt limit check and increment, an
 * emit_jump() call given poke->tailcall_target + X86_PATCH_SIZE and
 * poke->tailcall_bypass, the callee-saved register pops, "pop rax", an rsp
 * adjustment by round_up(stack_depth, 8) when @stack_depth is non-zero, and
 * an X86_PATCH_SIZE-byte nop that ends right at the "out:" label.
 *
 * @poke: descriptor whose tailcall_bypass, adj_off, tailcall_target and
 *        bypass_addr fields are filled in here
 * @pprog: in/out emit cursor; set to the position after the last emitted byte
 * @ip: address the first emitted byte is treated as having when the
 *      addresses stored in @poke are computed
 * @callee_regs_used: forwarded to pop_callee_regs()
 * @stack_depth: rounded up to 8; determines tcc_off and the rsp adjustment
 * @ctx: supplies tail_call_direct_label for the "ja out" displacement and
 *       receives the offset of "out:" measured during this call
 */
static void emit_bpf_tail_call_direct (struct bpf_jit_poke_descriptor * poke ,
502
- u8 * * pprog , int addr , u8 * image ,
503
- bool * callee_regs_used , u32 stack_depth )
479
+ u8 * * pprog , u8 * ip ,
480
+ bool * callee_regs_used , u32 stack_depth ,
481
+ struct jit_context * ctx )
504
482
{
505
483
int tcc_off = -4 - round_up (stack_depth , 8 );
506
- u8 * prog = * pprog ;
507
- int pop_bytes = 0 ;
508
- int off1 = 20 ;
509
- int poke_off ;
510
-
511
- /* count the additional bytes used for popping callee regs from stack
512
- * that need to be taken into account for jump offset that is used for
513
- * bailing out of the tail call when limit is reached
514
- */
515
- pop_bytes = get_pop_bytes (callee_regs_used );
516
- off1 += pop_bytes ;
517
-
518
- /*
519
- * total bytes for:
520
- * - nop5/ jmpq $off
521
- * - pop callee regs
522
- * - sub rsp, $val if depth > 0
523
- * - pop rax
524
- */
525
- poke_off = X86_PATCH_SIZE + pop_bytes + 1 ;
526
- if (stack_depth ) {
527
- poke_off += 7 ;
528
- off1 += 7 ;
529
- }
484
+ u8 * prog = * pprog , * start = * pprog ; /* start marks the first byte of this sequence */
485
+ int offset ;
530
486
531
487
/*
532
488
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
533
489
* goto out;
534
490
*/
535
491
EMIT2_off32 (0x8B , 0x85 , tcc_off ); /* mov eax, dword ptr [rbp - tcc_off] */
536
492
EMIT3 (0x83 , 0xF8 , MAX_TAIL_CALL_CNT ); /* cmp eax, MAX_TAIL_CALL_CNT */
537
- EMIT2 (X86_JA , off1 ); /* ja out */
493
+
494
+ offset = ctx -> tail_call_direct_label - (prog + 2 - start ); /* distance to out:, counted from the end of the 2-byte ja emitted next */
495
+ EMIT2 (X86_JA , offset ); /* ja out */
538
496
EMIT3 (0x83 , 0xC0 , 0x01 ); /* add eax, 1 */
539
497
EMIT2_off32 (0x89 , 0x85 , tcc_off ); /* mov dword ptr [rbp - tcc_off], eax */
540
498
541
- poke -> tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE );
499
+ poke -> tailcall_bypass = ip + (prog - start ); /* ip-relative address of the current emit position, where emit_jump() below writes */
542
500
poke -> adj_off = X86_TAIL_CALL_OFFSET ;
543
- poke -> tailcall_target = image + ( addr - X86_PATCH_SIZE ) ;
501
+ poke -> tailcall_target = ip + ctx -> tail_call_direct_label - X86_PATCH_SIZE ; /* X86_PATCH_SIZE bytes before out:, i.e. the nop copied in below */
544
502
poke -> bypass_addr = (u8 * )poke -> tailcall_target + X86_PATCH_SIZE ;
545
503
546
504
emit_jump (& prog , (u8 * )poke -> tailcall_target + X86_PATCH_SIZE ,
547
505
poke -> tailcall_bypass );
548
506
549
- * pprog = prog ;
550
- pop_callee_regs (pprog , callee_regs_used );
551
- prog = * pprog ;
507
+ pop_callee_regs (& prog , callee_regs_used );
552
508
EMIT1 (0x58 ); /* pop rax */
553
509
if (stack_depth )
554
510
EMIT3_off32 (0x48 , 0x81 , 0xC4 , round_up (stack_depth , 8 ));
555
511
556
512
memcpy (prog , x86_nops [5 ], X86_PATCH_SIZE );
557
513
prog += X86_PATCH_SIZE ;
514
+
558
515
/* out: */
516
+ ctx -> tail_call_direct_label = prog - start ; /* record where out: landed so a later pass can use it */
559
517
560
518
* pprog = prog ;
561
519
}
@@ -1411,13 +1369,16 @@ st: if (is_imm8(insn->off))
1411
1369
case BPF_JMP | BPF_TAIL_CALL :
1412
1370
if (imm32 )
1413
1371
emit_bpf_tail_call_direct (& bpf_prog -> aux -> poke_tab [imm32 - 1 ],
1414
- & prog , addrs [i ], image ,
1372
+ & prog , image + addrs [i - 1 ] ,
1415
1373
callee_regs_used ,
1416
- bpf_prog -> aux -> stack_depth );
1374
+ bpf_prog -> aux -> stack_depth ,
1375
+ ctx );
1417
1376
else
1418
1377
emit_bpf_tail_call_indirect (& prog ,
1419
1378
callee_regs_used ,
1420
- bpf_prog -> aux -> stack_depth );
1379
+ bpf_prog -> aux -> stack_depth ,
1380
+ image + addrs [i - 1 ],
1381
+ ctx );
1421
1382
break ;
1422
1383
1423
1384
/* cond jump */
0 commit comments