@@ -224,6 +224,14 @@ static void jit_fill_hole(void *area, unsigned int size)
224224
/* JIT state handed by pointer (ctx) to the emit helpers in this file. */
225225struct jit_context {
226226 int cleanup_addr ; /* Epilogue code offset */
227+
228+ /*
229+ * Program specific offsets of labels in the code; these rely on the
230+ * JIT doing at least 2 passes, recording the position on the first
231+ * pass, only to generate the correct offset on the second pass.
232+ */
233+ int tail_call_direct_label ; /* offset of out: from the start of the code emitted by emit_bpf_tail_call_direct() */
234+ int tail_call_indirect_label ; /* offset of out: from the start of the code emitted by emit_bpf_tail_call_indirect() */
227235};
228236
229237/* Maximum number of bytes emitted while JITing one eBPF insn */
@@ -379,22 +387,6 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
379387 return __bpf_arch_text_poke (ip , t , old_addr , new_addr , true);
380388}
381389
/*
 * Removed by this change (every line carries the diff minus marker).
 *
 * Returns a byte count derived from callee_regs_used[0..3]: 1 if entry 0
 * is set, plus 2 for each of entries 1-3 that is set. The old tail-call
 * emitters added this value to their hand-computed jump offsets to
 * account for the code that pops the callee regs.
 */
382- static int get_pop_bytes (bool * callee_regs_used )
383- {
384- int bytes = 0 ;
385-
386- if (callee_regs_used [3 ])
387- bytes += 2 ;
388- if (callee_regs_used [2 ])
389- bytes += 2 ;
390- if (callee_regs_used [1 ])
391- bytes += 2 ;
392- if (callee_regs_used [0 ])
393- bytes += 1 ;
394-
395- return bytes ;
396- }
397-
398390/*
399391 * Generate the following code:
400392 *
@@ -410,29 +402,12 @@ static int get_pop_bytes(bool *callee_regs_used)
410402 * out:
411403 */
/*
 * Parameters, as used in the lines visible in this diff (parts of the
 * body fall outside the shown hunks):
 *   pprog            - in/out emit cursor; read on entry and written back
 *                      on return so the caller continues after this code.
 *   callee_regs_used - forwarded unchanged to pop_callee_regs().
 *   stack_depth      - rounded up to 8 to derive tcc_off; also checked
 *                      after the pop rax emit (body outside this hunk).
 *   ip               - new argument; the caller passes image + addrs[i - 1].
 *                      Not referenced in the lines shown here.
 *   ctx              - holds tail_call_indirect_label, the offset of out:
 *                      from the start of this sequence.
 */
412404static void emit_bpf_tail_call_indirect (u8 * * pprog , bool * callee_regs_used ,
413- u32 stack_depth )
405+ u32 stack_depth , u8 * ip ,
406+ struct jit_context * ctx )
414407{
415408 int tcc_off = -4 - round_up (stack_depth , 8 );
416- u8 * prog = * pprog ;
417- int pop_bytes = 0 ;
418- int off1 = 42 ;
419- int off2 = 31 ;
420- int off3 = 9 ;
421-
422- /* count the additional bytes used for popping callee regs from stack
423- * that need to be taken into account for each of the offsets that
424- * are used for bailing out of the tail call
425- */
426- pop_bytes = get_pop_bytes (callee_regs_used );
427- off1 += pop_bytes ;
428- off2 += pop_bytes ;
429- off3 += pop_bytes ;
430-
431- if (stack_depth ) {
432- off1 += 7 ;
433- off2 += 7 ;
434- off3 += 7 ;
435- }
409+ u8 * prog = * pprog , * start = * pprog ;
410+ int offset ;
436411
437412 /*
438413 * rdi - pointer to ctx
@@ -447,17 +422,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
447422 EMIT2 (0x89 , 0xD2 ); /* mov edx, edx */
448423 EMIT3 (0x39 , 0x56 , /* cmp dword ptr [rsi + 16], edx */
449424 offsetof(struct bpf_array , map .max_entries ));
450- #define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
451- EMIT2 (X86_JBE , OFFSET1 ); /* jbe out */
425+
/*
 * Distance from the end of the jump emitted next, which the code counts
 * as 2 bytes (prog + 2), to the out: label. The label offset was stored
 * in ctx by an earlier JIT pass, so the value is only final from the
 * second pass onward. The same computation is repeated before each of
 * the conditional jumps to out: below.
 * NOTE(review): EMIT2 presumably encodes offset as a single byte, so the
 * distance has to fit in rel8 - confirm against the EMIT2 definition.
 */
426+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
427+ EMIT2 (X86_JBE , offset ); /* jbe out */
452428
453429 /*
454430 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
455431 * goto out;
456432 */
457433 EMIT2_off32 (0x8B , 0x85 , tcc_off ); /* mov eax, dword ptr [rbp - tcc_off] */
458434 EMIT3 (0x83 , 0xF8 , MAX_TAIL_CALL_CNT ); /* cmp eax, MAX_TAIL_CALL_CNT */
459- #define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
460- EMIT2 (X86_JA , OFFSET2 ); /* ja out */
435+
436+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
437+ EMIT2 (X86_JA , offset ); /* ja out */
461438 EMIT3 (0x83 , 0xC0 , 0x01 ); /* add eax, 1 */
462439 EMIT2_off32 (0x89 , 0x85 , tcc_off ); /* mov dword ptr [rbp - tcc_off], eax */
463440
@@ -470,12 +447,11 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
470447 * goto out;
471448 */
472449 EMIT3 (0x48 , 0x85 , 0xC9 ); /* test rcx,rcx */
473- #define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
474- EMIT2 (X86_JE , OFFSET3 ); /* je out */
475450
476- * pprog = prog ;
477- pop_callee_regs (pprog , callee_regs_used );
478- prog = * pprog ;
451+ offset = ctx -> tail_call_indirect_label - (prog + 2 - start );
452+ EMIT2 (X86_JE , offset ); /* je out */
453+
454+ pop_callee_regs (& prog , callee_regs_used );
479455
480456 EMIT1 (0x58 ); /* pop rax */
481457 if (stack_depth )
@@ -495,67 +471,49 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
495471 RETPOLINE_RCX_BPF_JIT ();
496472
497473 /* out: */
/* Record where out: lands relative to start, for the jumps of the next pass. */
474+ ctx -> tail_call_indirect_label = prog - start ;
498475 * pprog = prog ;
499476}
500477
/*
 * Emit the tail-call sequence that goes through a poke descriptor and
 * fill in that descriptor's addresses.
 *
 *   poke             - descriptor whose tailcall_bypass, adj_off,
 *                      tailcall_target and bypass_addr fields are set here.
 *   pprog            - in/out emit cursor; read on entry and written back
 *                      on return so the caller continues after this code.
 *   ip               - address corresponding to the start of this
 *                      sequence (the caller passes image + addrs[i - 1]).
 *   callee_regs_used - forwarded unchanged to pop_callee_regs().
 *   stack_depth      - rounded up to 8 to derive tcc_off; when nonzero an
 *                      extra EMIT3_off32 sequence follows the pop rax.
 *   ctx              - holds tail_call_direct_label, the offset of out:
 *                      from the start of this sequence.
 */
501478static void emit_bpf_tail_call_direct (struct bpf_jit_poke_descriptor * poke ,
502- u8 * * pprog , int addr , u8 * image ,
503- bool * callee_regs_used , u32 stack_depth )
479+ u8 * * pprog , u8 * ip ,
480+ bool * callee_regs_used , u32 stack_depth ,
481+ struct jit_context * ctx )
504482{
505483 int tcc_off = -4 - round_up (stack_depth , 8 );
506- u8 * prog = * pprog ;
507- int pop_bytes = 0 ;
508- int off1 = 20 ;
509- int poke_off ;
510-
511- /* count the additional bytes used for popping callee regs to stack
512- * that need to be taken into account for jump offset that is used for
513- * bailing out from of the tail call when limit is reached
514- */
515- pop_bytes = get_pop_bytes (callee_regs_used );
516- off1 += pop_bytes ;
517-
518- /*
519- * total bytes for:
520- * - nop5/ jmpq $off
521- * - pop callee regs
522- * - sub rsp, $val if depth > 0
523- * - pop rax
524- */
525- poke_off = X86_PATCH_SIZE + pop_bytes + 1 ;
526- if (stack_depth ) {
527- poke_off += 7 ;
528- off1 += 7 ;
529- }
484+ u8 * prog = * pprog , * start = * pprog ;
485+ int offset ;
530486
531487 /*
532488 * if (tail_call_cnt > MAX_TAIL_CALL_CNT)
533489 * goto out;
534490 */
535491 EMIT2_off32 (0x8B , 0x85 , tcc_off ); /* mov eax, dword ptr [rbp - tcc_off] */
536492 EMIT3 (0x83 , 0xF8 , MAX_TAIL_CALL_CNT ); /* cmp eax, MAX_TAIL_CALL_CNT */
537- EMIT2 (X86_JA , off1 ); /* ja out */
493+
/*
 * Distance from the end of the jump emitted next, which the code counts
 * as 2 bytes (prog + 2), to the out: label. The label offset was stored
 * in ctx by an earlier JIT pass, so the value is only final from the
 * second pass onward.
 * NOTE(review): EMIT2 presumably encodes offset as a single byte, so the
 * distance has to fit in rel8 - confirm against the EMIT2 definition.
 */
494+ offset = ctx -> tail_call_direct_label - (prog + 2 - start );
495+ EMIT2 (X86_JA , offset ); /* ja out */
538496 EMIT3 (0x83 , 0xC0 , 0x01 ); /* add eax, 1 */
539497 EMIT2_off32 (0x89 , 0x85 , tcc_off ); /* mov dword ptr [rbp - tcc_off], eax */
540498
/*
 * tailcall_bypass: address of the current emit position, i.e. of the
 *   jump that emit_jump() writes just below.
 * tailcall_target: start of the X86_PATCH_SIZE-byte nop slot that ends
 *   at out: (the slot is copied in right before the label is recorded).
 * bypass_addr: tailcall_target + X86_PATCH_SIZE, i.e. the address of out:.
 */
541- poke -> tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE );
499+ poke -> tailcall_bypass = ip + (prog - start );
542500 poke -> adj_off = X86_TAIL_CALL_OFFSET ;
543- poke -> tailcall_target = image + ( addr - X86_PATCH_SIZE ) ;
501+ poke -> tailcall_target = ip + ctx -> tail_call_direct_label - X86_PATCH_SIZE ;
544502 poke -> bypass_addr = (u8 * )poke -> tailcall_target + X86_PATCH_SIZE ;
545503
546504 emit_jump (& prog , (u8 * )poke -> tailcall_target + X86_PATCH_SIZE ,
547505 poke -> tailcall_bypass );
548506
549- * pprog = prog ;
550- pop_callee_regs (pprog , callee_regs_used );
551- prog = * pprog ;
507+ pop_callee_regs (& prog , callee_regs_used );
552508 EMIT1 (0x58 ); /* pop rax */
553509 if (stack_depth )
554510 EMIT3_off32 (0x48 , 0x81 , 0xC4 , round_up (stack_depth , 8 ));
555511
/* Reserve the patchable slot that poke->tailcall_target points at. */
556512 memcpy (prog , x86_nops [5 ], X86_PATCH_SIZE );
557513 prog += X86_PATCH_SIZE ;
514+
558515 /* out: */
/* Record where out: lands relative to start, for use by the next pass. */
516+ ctx -> tail_call_direct_label = prog - start ;
559517
560518 * pprog = prog ;
561519}
@@ -1411,13 +1369,16 @@ st: if (is_imm8(insn->off))
14111369 case BPF_JMP | BPF_TAIL_CALL :
14121370 if (imm32 )
14131371 emit_bpf_tail_call_direct (& bpf_prog -> aux -> poke_tab [imm32 - 1 ],
1414- & prog , addrs [i ], image ,
1372+ & prog , image + addrs [i - 1 ] ,
14151373 callee_regs_used ,
1416- bpf_prog -> aux -> stack_depth );
1374+ bpf_prog -> aux -> stack_depth ,
1375+ ctx );
14171376 else
14181377 emit_bpf_tail_call_indirect (& prog ,
14191378 callee_regs_used ,
1420- bpf_prog -> aux -> stack_depth );
1379+ bpf_prog -> aux -> stack_depth ,
1380+ image + addrs [i - 1 ],
1381+ ctx );
14211382 break ;
14221383
14231384 /* cond jump */
0 commit comments