diff --git a/include/linux/tnum.h b/include/linux/tnum.h index c52b862dad45..ed18ee1148b6 100644 --- a/include/linux/tnum.h +++ b/include/linux/tnum.h @@ -63,6 +63,9 @@ struct tnum tnum_union(struct tnum t1, struct tnum t2); /* Return @a with all but the lowest @size bytes cleared */ struct tnum tnum_cast(struct tnum a, u8 size); +/* Return @a sign-extended from @size bytes */ +struct tnum tnum_scast(struct tnum a, u8 size); + /* Returns true if @a is a known constant */ static inline bool tnum_is_const(struct tnum a) { diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c index f8e70e9c3998..eabcec2ebc26 100644 --- a/kernel/bpf/tnum.c +++ b/kernel/bpf/tnum.c @@ -199,6 +199,19 @@ struct tnum tnum_cast(struct tnum a, u8 size) return a; } +struct tnum tnum_scast(struct tnum a, u8 size) +{ + u8 s = 64 - size * 8; + u64 value, mask; + + if (size >= 8) + return a; + + value = ((s64)a.value << s) >> s; + mask = ((s64)a.mask << s) >> s; + return TNUM(value, mask); +} + bool tnum_is_aligned(struct tnum a, u64 size) { if (!size) diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 766695491bc5..c9a6bf85b4ad 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -6876,147 +6876,57 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size) reg_bounds_sync(reg); } -static void set_sext64_default_val(struct bpf_reg_state *reg, int size) -{ - if (size == 1) { - reg->smin_value = reg->s32_min_value = S8_MIN; - reg->smax_value = reg->s32_max_value = S8_MAX; - } else if (size == 2) { - reg->smin_value = reg->s32_min_value = S16_MIN; - reg->smax_value = reg->s32_max_value = S16_MAX; - } else { - /* size == 4 */ - reg->smin_value = reg->s32_min_value = S32_MIN; - reg->smax_value = reg->s32_max_value = S32_MAX; - } - reg->umin_value = reg->u32_min_value = 0; - reg->umax_value = U64_MAX; - reg->u32_max_value = U32_MAX; - reg->var_off = tnum_unknown; -} - static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size) { - s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval; - u64 top_smax_value, top_smin_value; - u64 num_bits = size * 8; + s64 smin_value, smax_value; - if (tnum_is_const(reg->var_off)) { - u64_cval = reg->var_off.value; - if (size == 1) - reg->var_off = tnum_const((s8)u64_cval); - else if (size == 2) - reg->var_off = tnum_const((s16)u64_cval); - else - /* size == 4 */ - reg->var_off = tnum_const((s32)u64_cval); - - u64_cval = reg->var_off.value; - reg->smax_value = reg->smin_value = u64_cval; - reg->umax_value = reg->umin_value = u64_cval; - reg->s32_max_value = reg->s32_min_value = u64_cval; - reg->u32_max_value = reg->u32_min_value = u64_cval; + if (size >= 8) return; - } - top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits; - top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits; + reg->var_off = tnum_scast(reg->var_off, size); - if (top_smax_value != top_smin_value) - goto out; + smin_value = -(1LL << (size * 8 - 1)); + smax_value = (1LL << (size * 8 - 1)) - 1; - /* find the s64_min and s64_min after sign extension */ - if (size == 1) { - init_s64_max = (s8)reg->smax_value; - init_s64_min = (s8)reg->smin_value; - } else if (size == 2) { - init_s64_max = (s16)reg->smax_value; - init_s64_min = (s16)reg->smin_value; - } else { - init_s64_max = (s32)reg->smax_value; - init_s64_min = (s32)reg->smin_value; - } - - s64_max = max(init_s64_max, init_s64_min); - s64_min = min(init_s64_max, init_s64_min); + reg->smin_value = smin_value; + reg->smax_value = smax_value; - /* both of s64_max/s64_min positive or negative */ - if ((s64_max >= 0) == (s64_min >= 0)) { - reg->s32_min_value = reg->smin_value = s64_min; - reg->s32_max_value = reg->smax_value = s64_max; - reg->u32_min_value = reg->umin_value = s64_min; - reg->u32_max_value = reg->umax_value = s64_max; - reg->var_off = tnum_range(s64_min, s64_max); - return; - } + reg->s32_min_value = (s32)smin_value; + reg->s32_max_value = (s32)smax_value; -out: - set_sext64_default_val(reg, size); -} - -static void set_sext32_default_val(struct bpf_reg_state *reg, int size) -{ - if (size == 1) { - reg->s32_min_value = S8_MIN; - reg->s32_max_value = S8_MAX; - } else { - /* size == 2 */ - reg->s32_min_value = S16_MIN; - reg->s32_max_value = S16_MAX; - } + reg->umin_value = 0; + reg->umax_value = U64_MAX; reg->u32_min_value = 0; reg->u32_max_value = U32_MAX; - reg->var_off = tnum_subreg(tnum_unknown); + + __update_reg_bounds(reg); } static void coerce_subreg_to_size_sx(struct bpf_reg_state *reg, int size) { - s32 init_s32_max, init_s32_min, s32_max, s32_min, u32_val; - u32 top_smax_value, top_smin_value; - u32 num_bits = size * 8; - - if (tnum_is_const(reg->var_off)) { - u32_val = reg->var_off.value; - if (size == 1) - reg->var_off = tnum_const((s8)u32_val); - else - reg->var_off = tnum_const((s16)u32_val); + s32 smin_value, smax_value; - u32_val = reg->var_off.value; - reg->s32_min_value = reg->s32_max_value = u32_val; - reg->u32_min_value = reg->u32_max_value = u32_val; + if (size >= 4) return; - } - top_smax_value = ((u32)reg->s32_max_value >> num_bits) << num_bits; - top_smin_value = ((u32)reg->s32_min_value >> num_bits) << num_bits; + reg->var_off = tnum_subreg(tnum_scast(reg->var_off, size)); - if (top_smax_value != top_smin_value) - goto out; + smin_value = -(1 << (size * 8 - 1)); + smax_value = (1 << (size * 8 - 1)) - 1; - /* find the s32_min and s32_min after sign extension */ - if (size == 1) { - init_s32_max = (s8)reg->s32_max_value; - init_s32_min = (s8)reg->s32_min_value; - } else { - /* size == 2 */ - init_s32_max = (s16)reg->s32_max_value; - init_s32_min = (s16)reg->s32_min_value; - } - s32_max = max(init_s32_max, init_s32_min); - s32_min = min(init_s32_max, init_s32_min); - - if ((s32_min >= 0) == (s32_max >= 0)) { - reg->s32_min_value = s32_min; - reg->s32_max_value = s32_max; - reg->u32_min_value = (u32)s32_min; - reg->u32_max_value = (u32)s32_max; - reg->var_off = tnum_subreg(tnum_range(s32_min, s32_max)); - return; - } + reg->s32_min_value = smin_value; + reg->s32_max_value = smax_value; -out: - set_sext32_default_val(reg, size); + reg->u32_min_value = 0; + reg->u32_max_value = U32_MAX; + + __update_reg32_bounds(reg); + + reg->umin_value = reg->u32_min_value; + reg->umax_value = reg->u32_max_value; + + reg->smin_value = reg->umin_value; + reg->smax_value = reg->umax_value; } static bool bpf_map_is_rdonly(const struct bpf_map *map) diff --git a/tools/testing/selftests/bpf/progs/verifier_movsx.c b/tools/testing/selftests/bpf/progs/verifier_movsx.c index a4d8814eb5ed..df7ad41af172 100644 --- a/tools/testing/selftests/bpf/progs/verifier_movsx.c +++ b/tools/testing/selftests/bpf/progs/verifier_movsx.c @@ -339,6 +339,25 @@ label_%=: \ : __clobber_all); } +SEC("socket") +__description("MOV64SX, S8, upper bits truncation") +__log_level(2) +__msg("R1={{P?}}0") +__success __success_unpriv __retval(0) +__naked void mov64sx_s8_truncated_range(void) +{ + asm volatile (" \ + call %[bpf_get_prandom_u32]; \ + r1 = r0; \ + r1 &= 0x100; \ + r1 = (s8)r1; \ + r0 = 0; \ + exit; \ +" : + : __imm(bpf_get_prandom_u32) + : __clobber_all); +} + #else SEC("socket")