diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 098dd7f21c89..97faacdc0f8b 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -15470,6 +15470,33 @@ static bool is_safe_to_compute_dst_reg_range(struct bpf_insn *insn, } } +static int maybe_fork_scalars(struct bpf_verifier_env *env, struct bpf_insn *insn, + struct bpf_reg_state *dst_reg) +{ + struct bpf_verifier_state *branch; + struct bpf_reg_state *regs; + bool alu32; + + if (dst_reg->smin_value == -1 && dst_reg->smax_value == 0) + alu32 = false; + else if (dst_reg->s32_min_value == -1 && dst_reg->s32_max_value == 0) + alu32 = true; + else + return 0; + + branch = push_stack(env, env->insn_idx + 1, env->insn_idx, false); + if (IS_ERR(branch)) + return PTR_ERR(branch); + + regs = branch->frame[branch->curframe]->regs; + __mark_reg_known(®s[insn->dst_reg], 0); + if (alu32) + __mark_reg32_known(dst_reg, -1ull); + else + __mark_reg_known(dst_reg, -1ull); + return 0; +} + /* WARNING: This function does calculations on 64-bit values, but the actual * execution may occur on 32-bit values. Therefore, things like bitshifts * need extra checks in the 32-bit case. @@ -15563,6 +15590,9 @@ static int adjust_scalar_min_max_vals(struct bpf_verifier_env *env, scalar32_min_max_arsh(dst_reg, &src_reg); else scalar_min_max_arsh(dst_reg, &src_reg); + ret = maybe_fork_scalars(env, insn, dst_reg); + if (ret) + return ret; break; default: break; diff --git a/tools/testing/selftests/bpf/progs/verifier_subreg.c b/tools/testing/selftests/bpf/progs/verifier_subreg.c index 8613ea160dcd..b18b75c532bc 100644 --- a/tools/testing/selftests/bpf/progs/verifier_subreg.c +++ b/tools/testing/selftests/bpf/progs/verifier_subreg.c @@ -670,4 +670,45 @@ __naked void ldx_w_zero_extend_check(void) : __clobber_all); } +SEC("socket") +__description("s>>=31") +__success __success_unpriv __retval(0) +__naked void arsh_31(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + w2 = w0; \ + w2 s>>= 31; \ + w2 &= -134; \ + if w2 s> -1 goto +2; \ + if w2 != 0xffffff78 goto +1; \ + w0 /= 0; \ + w0 = 0; \ + exit; \ +" : + : __imm(bpf_get_prandom_u32) + : __clobber_all); +} + +SEC("socket") +__description("s>>=63") +__success __success_unpriv __retval(0) +__naked void arsh_63(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + r2 = r0; \ + r2 <<= 32; \ + r2 s>>= 63; \ + r2 &= -134; \ + if r2 s> -1 goto +2; \ + if r2 != 0xffffff78 goto +1; \ + r0 /= 0; \ + r0 = 0; \ + exit; \ +" : + : __imm(bpf_get_prandom_u32) + : __clobber_all); +} + char _license[] SEC("license") = "GPL";