diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 2e170be647bd..d36e64975750 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -15301,21 +15301,19 @@ static void __scalar64_min_max_lsh(struct bpf_reg_state *dst_reg,
 				   u64 umin_val, u64 umax_val)
 {
 	/* Special case <<32 because it is a common compiler pattern to sign
-	 * extend subreg by doing <<32 s>>32. In this case if 32bit bounds are
-	 * positive we know this shift will also be positive so we can track
-	 * bounds correctly. Otherwise we lose all sign bit information except
-	 * what we can pick up from var_off. Perhaps we can generalize this
-	 * later to shifts of any length.
+	 * extend subreg by doing <<32 s>>32. As long as the shift does not
+	 * exceed the width of the sign extension (32 bits here), casting the
+	 * s32 bounds to s64 and shifting them left cannot overflow, so the
+	 * result stays within the shifted bounds and they can be tracked
+	 * precisely regardless of the sign of the 32-bit bounds.
 	 */
-	if (umin_val == 32 && umax_val == 32 && dst_reg->s32_max_value >= 0)
+	if (umin_val == 32 && umax_val == 32) {
 		dst_reg->smax_value = (s64)dst_reg->s32_max_value << 32;
-	else
-		dst_reg->smax_value = S64_MAX;
-
-	if (umin_val == 32 && umax_val == 32 && dst_reg->s32_min_value >= 0)
 		dst_reg->smin_value = (s64)dst_reg->s32_min_value << 32;
-	else
+	} else {
+		dst_reg->smax_value = S64_MAX;
 		dst_reg->smin_value = S64_MIN;
+	}
 
 	/* If we might shift our top bit out, then we know nothing */
 	if (dst_reg->umax_value > 1ULL << (63 - umax_val)) {
diff --git a/tools/testing/selftests/bpf/progs/verifier_subreg.c b/tools/testing/selftests/bpf/progs/verifier_subreg.c
index 8613ea160dcd..62da0b8cf591 100644
--- a/tools/testing/selftests/bpf/progs/verifier_subreg.c
+++ b/tools/testing/selftests/bpf/progs/verifier_subreg.c
@@ -531,6 +531,74 @@ __naked void arsh32_imm_zero_extend_check(void)
 	: __clobber_all);
 }
 
+SEC("socket")
+__description("arsh32 imm sign positive extend check")
+__success __success_unpriv __retval(0)
+__naked void arsh32_imm_sign_extend_positive_check(void)
+{
+	asm volatile ("					\
+	call %[bpf_get_prandom_u32];			\
+	r6 = r0;					\
+	r6 &= 0xfff;					\
+	r6 <<= 32;					\
+	r6 s>>= 32;					\
+	r0 = 0;						\
+	if w6 s>= 0 goto l0_%=;				\
+	r0 /= 0;					\
+l0_%=:	if w6 s<= 4096 goto l1_%=;			\
+	r0 /= 0;					\
+l1_%=:	exit;						\
+"	:
+	: __imm(bpf_get_prandom_u32)
+	: __clobber_all);
+}
+
+SEC("socket")
+__description("arsh32 imm sign negative extend check")
+__success __success_unpriv __retval(0)
+__naked void arsh32_imm_sign_extend_negative_check(void)
+{
+	asm volatile ("					\
+	call %[bpf_get_prandom_u32];			\
+	r6 = r0;					\
+	r6 &= 0xfff;					\
+	r6 -= 0xfff;					\
+	r6 <<= 32;					\
+	r6 s>>= 32;					\
+	r0 = 0;						\
+	if w6 s>= -4095 goto l0_%=;			\
+	r0 /= 0;					\
+l0_%=:	if w6 s<= 0 goto l1_%=;				\
+	r0 /= 0;					\
+l1_%=:	exit;						\
+"	:
+	: __imm(bpf_get_prandom_u32)
+	: __clobber_all);
+}
+
+SEC("socket")
+__description("arsh32 imm sign extend check")
+__success __success_unpriv __retval(0)
+__naked void arsh32_imm_sign_extend_check(void)
+{
+	asm volatile ("					\
+	call %[bpf_get_prandom_u32];			\
+	r6 = r0;					\
+	r6 &= 0xfff;					\
+	r6 -= 0x7ff;					\
+	r6 <<= 32;					\
+	r6 s>>= 32;					\
+	r0 = 0;						\
+	if w6 s>= -2049 goto l0_%=;			\
+	r0 /= 0;					\
+l0_%=:	if w6 s<= 2048 goto l1_%=;			\
+	r0 /= 0;					\
+l1_%=:	exit;						\
+"	:
+	: __imm(bpf_get_prandom_u32)
+	: __clobber_all);
+}
+
 SEC("socket")
 __description("end16 (to_le) reg zero extend check")
 __success __success_unpriv __retval(0)
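
For reference, a minimal userspace sketch (not part of the patch; all names are
illustrative) of the bounds argument the updated comment makes: for any subreg
value v within the 32-bit signed bounds, (s64)v << 32 cannot overflow, stays
within the shifted bounds, and s>>32 recovers v exactly, which is what the
patched __scalar64_min_max_lsh() now assumes for negative bounds as well.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
	/* Bounds mirroring the "arsh32 imm sign extend check" selftest:
	 * r6 &= 0xfff; r6 -= 0x7ff; => v in [-2047, 2048].
	 */
	int32_t s32_min = -2047, s32_max = 2048;
	/* Same computation as the patched code; relies on the usual
	 * arithmetic behaviour of shifts on signed values (gcc/clang),
	 * as the kernel code itself does.
	 */
	int64_t smin = (int64_t)s32_min << 32;
	int64_t smax = (int64_t)s32_max << 32;

	for (int32_t v = s32_min; v <= s32_max; v++) {
		int64_t shifted = (int64_t)v << 32;
		int64_t sext = shifted >> 32;	/* the s>>32 half of the pattern */

		assert(shifted >= smin && shifted <= smax);
		assert(sext == v);
	}
	printf("64-bit bounds hold: [%lld, %lld]\n",
	       (long long)smin, (long long)smax);
	return 0;
}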