Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/linux/tnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ struct tnum tnum_union(struct tnum t1, struct tnum t2);
/* Return @a with all but the lowest @size bytes cleared */
struct tnum tnum_cast(struct tnum a, u8 size);

/* Return @a sign-extended from @size bytes */
struct tnum tnum_scast(struct tnum a, u8 size);

/* Returns true if @a is a known constant */
static inline bool tnum_is_const(struct tnum a)
{
Expand Down
13 changes: 13 additions & 0 deletions kernel/bpf/tnum.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,19 @@ struct tnum tnum_cast(struct tnum a, u8 size)
return a;
}

struct tnum tnum_scast(struct tnum a, u8 size)
{
u8 s = 64 - size * 8;
u64 value, mask;

if (size >= 8)
return a;

value = ((s64)a.value << s) >> s;
mask = ((s64)a.mask << s) >> s;
return TNUM(value, mask);
}

bool tnum_is_aligned(struct tnum a, u64 size)
{
if (!size)
Expand Down
150 changes: 30 additions & 120 deletions kernel/bpf/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -6876,147 +6876,57 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
reg_bounds_sync(reg);
}

static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
{
if (size == 1) {
reg->smin_value = reg->s32_min_value = S8_MIN;
reg->smax_value = reg->s32_max_value = S8_MAX;
} else if (size == 2) {
reg->smin_value = reg->s32_min_value = S16_MIN;
reg->smax_value = reg->s32_max_value = S16_MAX;
} else {
/* size == 4 */
reg->smin_value = reg->s32_min_value = S32_MIN;
reg->smax_value = reg->s32_max_value = S32_MAX;
}
reg->umin_value = reg->u32_min_value = 0;
reg->umax_value = U64_MAX;
reg->u32_max_value = U32_MAX;
reg->var_off = tnum_unknown;
}

static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
{
s64 init_s64_max, init_s64_min, s64_max, s64_min, u64_cval;
u64 top_smax_value, top_smin_value;
u64 num_bits = size * 8;
s64 smin_value, smax_value;

if (tnum_is_const(reg->var_off)) {
u64_cval = reg->var_off.value;
if (size == 1)
reg->var_off = tnum_const((s8)u64_cval);
else if (size == 2)
reg->var_off = tnum_const((s16)u64_cval);
else
/* size == 4 */
reg->var_off = tnum_const((s32)u64_cval);

u64_cval = reg->var_off.value;
reg->smax_value = reg->smin_value = u64_cval;
reg->umax_value = reg->umin_value = u64_cval;
reg->s32_max_value = reg->s32_min_value = u64_cval;
reg->u32_max_value = reg->u32_min_value = u64_cval;
if (size >= 8)
return;
}

top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
reg->var_off = tnum_scast(reg->var_off, size);

if (top_smax_value != top_smin_value)
goto out;
smin_value = -(1LL << (size * 8 - 1));
smax_value = (1LL << (size * 8 - 1)) - 1;

/* find the s64_min and s64_min after sign extension */
if (size == 1) {
init_s64_max = (s8)reg->smax_value;
init_s64_min = (s8)reg->smin_value;
} else if (size == 2) {
init_s64_max = (s16)reg->smax_value;
init_s64_min = (s16)reg->smin_value;
} else {
init_s64_max = (s32)reg->smax_value;
init_s64_min = (s32)reg->smin_value;
}

s64_max = max(init_s64_max, init_s64_min);
s64_min = min(init_s64_max, init_s64_min);
reg->smin_value = smin_value;
reg->smax_value = smax_value;

/* both of s64_max/s64_min positive or negative */
if ((s64_max >= 0) == (s64_min >= 0)) {
reg->s32_min_value = reg->smin_value = s64_min;
reg->s32_max_value = reg->smax_value = s64_max;
reg->u32_min_value = reg->umin_value = s64_min;
reg->u32_max_value = reg->umax_value = s64_max;
reg->var_off = tnum_range(s64_min, s64_max);
return;
}
reg->s32_min_value = (s32)smin_value;
reg->s32_max_value = (s32)smax_value;

out:
set_sext64_default_val(reg, size);
}

static void set_sext32_default_val(struct bpf_reg_state *reg, int size)
{
if (size == 1) {
reg->s32_min_value = S8_MIN;
reg->s32_max_value = S8_MAX;
} else {
/* size == 2 */
reg->s32_min_value = S16_MIN;
reg->s32_max_value = S16_MAX;
}
reg->umin_value = 0;
reg->umax_value = U64_MAX;
reg->u32_min_value = 0;
reg->u32_max_value = U32_MAX;
reg->var_off = tnum_subreg(tnum_unknown);

__update_reg_bounds(reg);
}

static void coerce_subreg_to_size_sx(struct bpf_reg_state *reg, int size)
{
s32 init_s32_max, init_s32_min, s32_max, s32_min, u32_val;
u32 top_smax_value, top_smin_value;
u32 num_bits = size * 8;

if (tnum_is_const(reg->var_off)) {
u32_val = reg->var_off.value;
if (size == 1)
reg->var_off = tnum_const((s8)u32_val);
else
reg->var_off = tnum_const((s16)u32_val);
s32 smin_value, smax_value;

u32_val = reg->var_off.value;
reg->s32_min_value = reg->s32_max_value = u32_val;
reg->u32_min_value = reg->u32_max_value = u32_val;
if (size >= 4)
return;
}

top_smax_value = ((u32)reg->s32_max_value >> num_bits) << num_bits;
top_smin_value = ((u32)reg->s32_min_value >> num_bits) << num_bits;
reg->var_off = tnum_subreg(tnum_scast(reg->var_off, size));

if (top_smax_value != top_smin_value)
goto out;
smin_value = -(1 << (size * 8 - 1));
smax_value = (1 << (size * 8 - 1)) - 1;

/* find the s32_min and s32_min after sign extension */
if (size == 1) {
init_s32_max = (s8)reg->s32_max_value;
init_s32_min = (s8)reg->s32_min_value;
} else {
/* size == 2 */
init_s32_max = (s16)reg->s32_max_value;
init_s32_min = (s16)reg->s32_min_value;
}
s32_max = max(init_s32_max, init_s32_min);
s32_min = min(init_s32_max, init_s32_min);

if ((s32_min >= 0) == (s32_max >= 0)) {
reg->s32_min_value = s32_min;
reg->s32_max_value = s32_max;
reg->u32_min_value = (u32)s32_min;
reg->u32_max_value = (u32)s32_max;
reg->var_off = tnum_subreg(tnum_range(s32_min, s32_max));
return;
}
reg->s32_min_value = smin_value;
reg->s32_max_value = smax_value;

out:
set_sext32_default_val(reg, size);
reg->u32_min_value = 0;
reg->u32_max_value = U32_MAX;

__update_reg32_bounds(reg);

reg->umin_value = reg->u32_min_value;
reg->umax_value = reg->u32_max_value;

reg->smin_value = reg->umin_value;
reg->smax_value = reg->umax_value;
}

static bool bpf_map_is_rdonly(const struct bpf_map *map)
Expand Down
19 changes: 19 additions & 0 deletions tools/testing/selftests/bpf/progs/verifier_movsx.c
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,25 @@ label_%=: \
: __clobber_all);
}

SEC("socket")
__description("MOV64SX, S8, upper bits truncation")
__log_level(2)
__msg("R1={{P?}}0")
__success __success_unpriv __retval(0)
__naked void mov64sx_s8_truncated_range(void)
{
asm volatile (" \
call %[bpf_get_prandom_u32]; \
r1 = r0; \
r1 &= 0x100; \
r1 = (s8)r1; \
r0 = 0; \
exit; \
" :
: __imm(bpf_get_prandom_u32)
: __clobber_all);
}

#else

SEC("socket")
Expand Down
Loading