Skip to content

Commit dceba08

Browse files
author
Peter Zijlstra
committed
bpf,x86: Simplify computing label offsets
Take an idea from the 32bit JIT, which uses the multi-pass nature of the JIT to compute the instruction offsets on a prior pass in order to compute the relative jump offsets on a later pass. Application to the x86_64 JIT is slightly more involved because the offsets depend on program variables (such as callee_regs_used and stack_depth) and hence the computed offsets need to be kept in the context of the JIT. This removes, IMO quite fragile, code that hard-codes the offsets and tries to compute the length of variable parts of it. Convert both emit_bpf_tail_call_*() functions which have an out: label at the end. Additionally emit_bpf_tail_call_direct() also has a poke table entry, for which it computes the offset from the end (and thus already relies on the previous pass to have computed addrs[i]), also convert this to be a forward based offset. Signed-off-by: Peter Zijlstra (Intel) <[email protected]> Reviewed-by: Borislav Petkov <[email protected]> Acked-by: Alexei Starovoitov <[email protected]> Acked-by: Josh Poimboeuf <[email protected]> Tested-by: Alexei Starovoitov <[email protected]> Link: https://lore.kernel.org/r/[email protected]
1 parent f8a66d6 commit dceba08

File tree

1 file changed

+42
-81
lines changed

1 file changed

+42
-81
lines changed

arch/x86/net/bpf_jit_comp.c

Lines changed: 42 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ static void jit_fill_hole(void *area, unsigned int size)
224224

225225
struct jit_context {
226226
int cleanup_addr; /* Epilogue code offset */
227+
228+
/*
229+
* Program specific offsets of labels in the code; these rely on the
230+
* JIT doing at least 2 passes, recording the position on the first
231+
* pass, only to generate the correct offset on the second pass.
232+
*/
233+
int tail_call_direct_label;
234+
int tail_call_indirect_label;
227235
};
228236

229237
/* Maximum number of bytes emitted while JITing one eBPF insn */
@@ -379,22 +387,6 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
379387
return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true);
380388
}
381389

382-
static int get_pop_bytes(bool *callee_regs_used)
383-
{
384-
int bytes = 0;
385-
386-
if (callee_regs_used[3])
387-
bytes += 2;
388-
if (callee_regs_used[2])
389-
bytes += 2;
390-
if (callee_regs_used[1])
391-
bytes += 2;
392-
if (callee_regs_used[0])
393-
bytes += 1;
394-
395-
return bytes;
396-
}
397-
398390
/*
399391
* Generate the following code:
400392
*
@@ -410,29 +402,12 @@ static int get_pop_bytes(bool *callee_regs_used)
410402
* out:
411403
*/
412404
static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
413-
u32 stack_depth)
405+
u32 stack_depth, u8 *ip,
406+
struct jit_context *ctx)
414407
{
415408
int tcc_off = -4 - round_up(stack_depth, 8);
416-
u8 *prog = *pprog;
417-
int pop_bytes = 0;
418-
int off1 = 42;
419-
int off2 = 31;
420-
int off3 = 9;
421-
422-
/* count the additional bytes used for popping callee regs from stack
423-
* that need to be taken into account for each of the offsets that
424-
* are used for bailing out of the tail call
425-
*/
426-
pop_bytes = get_pop_bytes(callee_regs_used);
427-
off1 += pop_bytes;
428-
off2 += pop_bytes;
429-
off3 += pop_bytes;
430-
431-
if (stack_depth) {
432-
off1 += 7;
433-
off2 += 7;
434-
off3 += 7;
435-
}
409+
u8 *prog = *pprog, *start = *pprog;
410+
int offset;
436411

437412
/*
438413
* rdi - pointer to ctx
@@ -447,17 +422,19 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
447422
EMIT2(0x89, 0xD2); /* mov edx, edx */
448423
EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */
449424
offsetof(struct bpf_array, map.max_entries));
450-
#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */
451-
EMIT2(X86_JBE, OFFSET1); /* jbe out */
425+
426+
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
427+
EMIT2(X86_JBE, offset); /* jbe out */
452428

453429
/*
454430
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
455431
* goto out;
456432
*/
457433
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
458434
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
459-
#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE)
460-
EMIT2(X86_JA, OFFSET2); /* ja out */
435+
436+
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
437+
EMIT2(X86_JA, offset); /* ja out */
461438
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
462439
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
463440

@@ -470,12 +447,11 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
470447
* goto out;
471448
*/
472449
EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */
473-
#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE)
474-
EMIT2(X86_JE, OFFSET3); /* je out */
475450

476-
*pprog = prog;
477-
pop_callee_regs(pprog, callee_regs_used);
478-
prog = *pprog;
451+
offset = ctx->tail_call_indirect_label - (prog + 2 - start);
452+
EMIT2(X86_JE, offset); /* je out */
453+
454+
pop_callee_regs(&prog, callee_regs_used);
479455

480456
EMIT1(0x58); /* pop rax */
481457
if (stack_depth)
@@ -495,67 +471,49 @@ static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used,
495471
RETPOLINE_RCX_BPF_JIT();
496472

497473
/* out: */
474+
ctx->tail_call_indirect_label = prog - start;
498475
*pprog = prog;
499476
}
500477

501478
static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke,
502-
u8 **pprog, int addr, u8 *image,
503-
bool *callee_regs_used, u32 stack_depth)
479+
u8 **pprog, u8 *ip,
480+
bool *callee_regs_used, u32 stack_depth,
481+
struct jit_context *ctx)
504482
{
505483
int tcc_off = -4 - round_up(stack_depth, 8);
506-
u8 *prog = *pprog;
507-
int pop_bytes = 0;
508-
int off1 = 20;
509-
int poke_off;
510-
511-
/* count the additional bytes used for popping callee regs to stack
512-
* that need to be taken into account for jump offset that is used for
513-
* bailing out from of the tail call when limit is reached
514-
*/
515-
pop_bytes = get_pop_bytes(callee_regs_used);
516-
off1 += pop_bytes;
517-
518-
/*
519-
* total bytes for:
520-
* - nop5/ jmpq $off
521-
* - pop callee regs
522-
* - sub rsp, $val if depth > 0
523-
* - pop rax
524-
*/
525-
poke_off = X86_PATCH_SIZE + pop_bytes + 1;
526-
if (stack_depth) {
527-
poke_off += 7;
528-
off1 += 7;
529-
}
484+
u8 *prog = *pprog, *start = *pprog;
485+
int offset;
530486

531487
/*
532488
* if (tail_call_cnt > MAX_TAIL_CALL_CNT)
533489
* goto out;
534490
*/
535491
EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */
536492
EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */
537-
EMIT2(X86_JA, off1); /* ja out */
493+
494+
offset = ctx->tail_call_direct_label - (prog + 2 - start);
495+
EMIT2(X86_JA, offset); /* ja out */
538496
EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */
539497
EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */
540498

541-
poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE);
499+
poke->tailcall_bypass = ip + (prog - start);
542500
poke->adj_off = X86_TAIL_CALL_OFFSET;
543-
poke->tailcall_target = image + (addr - X86_PATCH_SIZE);
501+
poke->tailcall_target = ip + ctx->tail_call_direct_label - X86_PATCH_SIZE;
544502
poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE;
545503

546504
emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE,
547505
poke->tailcall_bypass);
548506

549-
*pprog = prog;
550-
pop_callee_regs(pprog, callee_regs_used);
551-
prog = *pprog;
507+
pop_callee_regs(&prog, callee_regs_used);
552508
EMIT1(0x58); /* pop rax */
553509
if (stack_depth)
554510
EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8));
555511

556512
memcpy(prog, x86_nops[5], X86_PATCH_SIZE);
557513
prog += X86_PATCH_SIZE;
514+
558515
/* out: */
516+
ctx->tail_call_direct_label = prog - start;
559517

560518
*pprog = prog;
561519
}
@@ -1411,13 +1369,16 @@ st: if (is_imm8(insn->off))
14111369
case BPF_JMP | BPF_TAIL_CALL:
14121370
if (imm32)
14131371
emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1],
1414-
&prog, addrs[i], image,
1372+
&prog, image + addrs[i - 1],
14151373
callee_regs_used,
1416-
bpf_prog->aux->stack_depth);
1374+
bpf_prog->aux->stack_depth,
1375+
ctx);
14171376
else
14181377
emit_bpf_tail_call_indirect(&prog,
14191378
callee_regs_used,
1420-
bpf_prog->aux->stack_depth);
1379+
bpf_prog->aux->stack_depth,
1380+
image + addrs[i - 1],
1381+
ctx);
14211382
break;
14221383

14231384
/* cond jump */

0 commit comments

Comments
 (0)