diff --git a/arch/riscv/net/bpf_jit_comp64.c b/arch/riscv/net/bpf_jit_comp64.c index 549c3063c7f11..11ca56320a3f8 100644 --- a/arch/riscv/net/bpf_jit_comp64.c +++ b/arch/riscv/net/bpf_jit_comp64.c @@ -954,6 +954,33 @@ static int invoke_bpf_prog(struct bpf_tramp_link *l, int args_off, int retval_of return ret; } +/* + * Sign-extend the register if necessary + */ +static int sign_extend(struct rv_jit_context *ctx, int r, u8 size) +{ + switch (size) { + case 1: + emit_slli(r, r, 56, ctx); + emit_srai(r, r, 56, ctx); + break; + case 2: + emit_slli(r, r, 48, ctx); + emit_srai(r, r, 48, ctx); + break; + case 4: + emit_addiw(r, r, 0, ctx); + break; + case 8: + break; + default: + pr_err("bpf-jit: invalid size %d for sign_extend\n", size); + return -EINVAL; + } + + return 0; +} + static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, const struct btf_func_model *m, struct bpf_tramp_links *tlinks, @@ -1177,6 +1204,12 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, if (save_ret) { emit_ld(RV_REG_A0, -retval_off, RV_REG_FP, ctx); emit_ld(regmap[BPF_REG_0], -(retval_off - 8), RV_REG_FP, ctx); + if (is_struct_ops) { + emit_mv(RV_REG_A0, regmap[BPF_REG_0], ctx); + ret = sign_extend(ctx, RV_REG_A0, m->ret_size); + if (ret) + goto out; + } } emit_ld(RV_REG_S1, -sreg_off, RV_REG_FP, ctx);