Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions kernel/bpf/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -17579,6 +17579,113 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
&linked_regs);
}

/* For JEQ/JNE with a known constant, fork the not-equal branch
* into dst > const and dst < const for tighter range tracking.
*/
if ((opcode == BPF_JEQ || opcode == BPF_JNE) &&
dst_reg->type == SCALAR_VALUE) {
struct bpf_verifier_state *neq_branch2, *neq_state;
struct bpf_reg_state *neq_regs, *neq2_regs;
struct bpf_reg_state fake_const;
bool can_fork = false;
s32 neq_target;
u64 val;

/* Determine which src is the constant */
if (BPF_SRC(insn->code) == BPF_K) {
val = (u64)(s64)insn->imm;
can_fork = true;
} else if (is_reg_const(src_reg, is_jmp32)) {
val = reg_const_value(src_reg, is_jmp32);
can_fork = true;
} else if (is_reg_const(dst_reg, is_jmp32)) {
/* dst is const — JNE/JEQ is symmetric for forking,
* but we'd fork on src's range. Skip for now.
*/
can_fork = false;
}

if (can_fork) {
/* Identify the not-equal branch and check feasibility */
if (opcode == BPF_JEQ) {
neq_regs = regs; /* this_branch = fallthrough = != */
neq_target = *insn_idx + 1;
neq_state = this_branch;
} else {
neq_regs = other_branch_regs; /* other_branch = jump = != */
neq_target = *insn_idx + insn->off + 1;
neq_state = other_branch;
}

/* Check that the range spans across the constant.
* Check both unsigned and signed ranges, because
* reg_bounds_sync() may cross-propagate signed
* bounds into unsigned after JGT/JLT refinement,
* creating infeasible states if the signed range
* doesn't also span the constant.
*/
if (is_jmp32) {
can_fork = neq_regs[insn->dst_reg].u32_min_value < (u32)val &&
neq_regs[insn->dst_reg].u32_max_value > (u32)val &&
neq_regs[insn->dst_reg].s32_min_value < (s32)val &&
neq_regs[insn->dst_reg].s32_max_value > (s32)val;
} else {
can_fork = neq_regs[insn->dst_reg].umin_value < val &&
neq_regs[insn->dst_reg].umax_value > val &&
neq_regs[insn->dst_reg].smin_value < (s64)val &&
neq_regs[insn->dst_reg].smax_value > (s64)val;
}
}

if (can_fork) {
/* Create a fake const register for regs_refine_cond_op */
memset(&fake_const, 0, sizeof(fake_const));
fake_const.type = SCALAR_VALUE;
__mark_reg_known(&fake_const, val);

/* Push second fork for the not-equal branch.
* push_stack copies env->cur_state (this_branch).
* For JNE this_branch is the equal branch, so we
* must overwrite dst_reg with the not-equal copy.
*/
neq_branch2 = push_stack(env, neq_target, *insn_idx, false);
if (IS_ERR(neq_branch2))
return PTR_ERR(neq_branch2);
neq2_regs = neq_branch2->frame[neq_branch2->curframe]->regs;
neq2_regs[insn->dst_reg] = neq_regs[insn->dst_reg];

/* Fork 1 (existing not-equal branch): dst > const */
regs_refine_cond_op(&neq_regs[insn->dst_reg],
&fake_const, BPF_JGT, is_jmp32);
reg_bounds_sync(&neq_regs[insn->dst_reg]);

/* Fork 2 (new state): dst < const */
memset(&fake_const, 0, sizeof(fake_const));
fake_const.type = SCALAR_VALUE;
__mark_reg_known(&fake_const, val);
regs_refine_cond_op(&neq2_regs[insn->dst_reg],
&fake_const, BPF_JLT, is_jmp32);
reg_bounds_sync(&neq2_regs[insn->dst_reg]);

/* Propagate JGT/JLT bounds to linked registers */
if (neq_regs[insn->dst_reg].id) {
sync_linked_regs(env, neq_state,
&neq_regs[insn->dst_reg],
&linked_regs);
sync_linked_regs(env, neq_branch2,
&neq2_regs[insn->dst_reg],
&linked_regs);
}

err = reg_bounds_sanity_check(env,
&neq_regs[insn->dst_reg], "neq_gt");
err = err ?: reg_bounds_sanity_check(env,
&neq2_regs[insn->dst_reg], "neq_lt");
if (err)
return err;
}
}

/* if one pointer register is compared to another pointer
* register check if PTR_MAYBE_NULL could be lifted.
* E.g. register A - maybe null
Expand Down
2 changes: 2 additions & 0 deletions tools/testing/selftests/bpf/prog_tests/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "verifier_int_ptr.skel.h"
#include "verifier_iterating_callbacks.skel.h"
#include "verifier_jeq_infer_not_null.skel.h"
#include "verifier_jeq_jne_fork.skel.h"
#include "verifier_jit_convergence.skel.h"
#include "verifier_ld_ind.skel.h"
#include "verifier_ldsx.skel.h"
Expand Down Expand Up @@ -191,6 +192,7 @@ void test_verifier_helper_value_access(void) { RUN(verifier_helper_value_access
void test_verifier_int_ptr(void) { RUN(verifier_int_ptr); }
void test_verifier_iterating_callbacks(void) { RUN(verifier_iterating_callbacks); }
void test_verifier_jeq_infer_not_null(void) { RUN(verifier_jeq_infer_not_null); }
void test_verifier_jeq_jne_fork(void) { RUN(verifier_jeq_jne_fork); }
void test_verifier_jit_convergence(void) { RUN(verifier_jit_convergence); }
void test_verifier_load_acquire(void) { RUN(verifier_load_acquire); }
void test_verifier_ld_ind(void) { RUN(verifier_ld_ind); }
Expand Down
204 changes: 204 additions & 0 deletions tools/testing/selftests/bpf/progs/verifier_jeq_jne_fork.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
// SPDX-License-Identifier: GPL-2.0
/* Tests for JEQ/JNE not-equal branch forking.
*
* When the verifier processes JEQ/JNE with a known constant, it forks
* the not-equal branch into two sub-states:
* fork1: dst > const (unsigned, via JGT refinement)
* fork2: dst < const (unsigned, via JLT refinement)
* This gives tighter range tracking than the original JNE edge-trim,
* which only adjusts bounds when the constant is at the range boundary.
*/

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>
#include "bpf_misc.h"

/* JNE with BPF_K: both forks land on the jump target and are popped
* from the verifier stack, so both appear in the "from X to Y:" log.
*
* r0 in [0, 7], JNE r0, 3:
* fork2 (popped first, JLT): r0 in [0, 2]
* fork1 (popped second, JGT): r0 in [4, 7]
*/
SEC("socket")
__description("jne_k: neq branch forked into r0 > 3 and r0 < 3")
__success __log_level(2)
__msg("R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2")
__msg("R0=scalar(smin=umin=smin32=umin32=4,smax=umax=smax32=umax32=7")
__retval(0)
__naked void jne_k_neq_fork(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 7;"
"if r0 != 3 goto l_neq_%=;"
"r0 = 0;"
"exit;"
"l_neq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* JEQ with BPF_K: fork1 is the continuation (not popped) and fork2
* is popped. We verify fork2's bounds in the log.
*
* r0 in [0, 7], JEQ r0, 3:
* fork1 (continuation, JGT): r0 in [4, 7]
* fork2 (popped, JLT): r0 in [0, 2]
*/
SEC("socket")
__description("jeq_k: neq branch forked, fork2 has r0 < 3")
__success __log_level(2)
__msg("R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2")
__retval(0)
__naked void jeq_k_neq_fork(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 7;"
"if r0 == 3 goto l_eq_%=;"
"r0 = 0;"
"exit;"
"l_eq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* JEQ with BPF_X: register source containing a known constant.
*
* r0 in [0, 7], r1 = 3, JEQ r0, r1:
* fork2 (popped, JLT): r0 in [0, 2]
*/
SEC("socket")
__description("jeq_x: neq branch forked with register source")
__success __log_level(2)
__msg("R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2")
__retval(0)
__naked void jeq_x_neq_fork(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 7;"
"r1 = 3;"
"if r0 == r1 goto l_eq_%=;"
"r0 = 0;"
"exit;"
"l_eq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* JMP32 JEQ with BPF_K: 32-bit comparison variant.
* Fork uses u32_min/u32_max for the feasibility check.
*
* w0 in [0, 7], JEQ32 w0, 3:
* fork2 (popped, JLT32): w0 in [0, 2]
*/
SEC("socket")
__description("jeq32_k: 32-bit neq branch forked")
__success __log_level(2)
__msg("R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=2")
__retval(0)
__naked void jeq32_k_neq_fork(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"w0 &= 7;"
"if w0 == 3 goto l_eq_%=;"
"r0 = 0;"
"exit;"
"l_eq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* JNE with larger range and different constant.
* r0 in [0, 255], JNE r0, 100:
* fork2 (JLT): r0 in [0, 99]
* fork1 (JGT): r0 in [101, 255]
*/
SEC("socket")
__description("jne_k: neq branch forked with wider range")
__success __log_level(2)
__msg("R0=scalar(smin=smin32=0,smax=umax=smax32=umax32=99")
__msg("R0=scalar(smin=umin=smin32=umin32=101,smax=umax=smax32=umax32=255")
__retval(0)
__naked void jne_k_neq_fork_wide(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 0xff;"
"if r0 != 100 goto l_neq_%=;"
"r0 = 0;"
"exit;"
"l_neq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* Const at umin boundary — no fork expected.
* r0 in [0, 7], JEQ r0, 0:
* Fork condition: umin(0) < 0 is false for unsigned → no fork.
* Edge-trim gives not-equal branch r0 in [1, 7].
*/
SEC("socket")
__description("jeq_k: const at umin, no fork needed")
__success
__retval(0)
__naked void jeq_k_no_fork_umin(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 7;"
"if r0 == 0 goto l_eq_%=;"
"r0 = 0;"
"exit;"
"l_eq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

/* Const at umax boundary — no fork expected.
* r0 in [0, 7], JEQ r0, 7:
* Fork condition: umax(7) > 7 is false → no fork.
* Edge-trim gives not-equal branch r0 in [0, 6].
*/
SEC("socket")
__description("jeq_k: const at umax, no fork needed")
__success
__retval(0)
__naked void jeq_k_no_fork_umax(void)
{
asm volatile (
"call %[bpf_ktime_get_ns];"
"r0 &= 7;"
"if r0 == 7 goto l_eq_%=;"
"r0 = 0;"
"exit;"
"l_eq_%=:"
"r0 = 0;"
"exit;"
:
: __imm(bpf_ktime_get_ns)
: __clobber_all);
}

char _license[] SEC("license") = "GPL";
Loading