include/net/proto_memory.h (12 additions & 3 deletions)
@@ -31,13 +31,22 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
if (!sk->sk_prot->memory_pressure)
return false;

if (mem_cgroup_sk_enabled(sk) &&
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (mem_cgroup_sk_enabled(sk)) {
if (mem_cgroup_sk_under_memory_pressure(sk))
return true;

if (mem_cgroup_sk_isolated(sk))
return false;
}

return !!READ_ONCE(*sk->sk_prot->memory_pressure);
}

static inline bool sk_should_enter_memory_pressure(struct sock *sk)
{
return !mem_cgroup_sk_enabled(sk) || !mem_cgroup_sk_isolated(sk);
}

static inline long
proto_memory_allocated(const struct proto *prot)
{
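The new control flow is easiest to see in isolation: memcg pressure still short-circuits to true, but an isolated socket whose memcg is not under pressure now returns false instead of falling through to the global sk_prot->memory_pressure flag (tcp_under_memory_pressure() in include/net/tcp.h below gets the same shape). A minimal standalone sketch, with plain bools standing in for the kernel helpers:

```c
/* Standalone sketch (not kernel code) of the decision order introduced
 * by the hunk above; the bool parameters stand in for the helpers.
 */
#include <stdbool.h>
#include <stdio.h>

static bool under_pressure(bool memcg_enabled, bool memcg_pressure,
			   bool isolated, bool global_pressure)
{
	if (memcg_enabled) {
		if (memcg_pressure)
			return true;	/* memcg pressure always reported */
		if (isolated)
			return false;	/* isolated: ignore global state */
	}
	return global_pressure;		/* legacy fallback */
}

int main(void)
{
	/* isolated socket, calm memcg, global pressure set -> false */
	printf("%d\n", under_pressure(true, false, true, true));	/* 0 */
	/* same socket without the isolated flag -> true */
	printf("%d\n", under_pressure(true, false, false, true));	/* 1 */
	return 0;
}
```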
include/net/sock.h (48 additions & 0 deletions)
@@ -2596,17 +2596,51 @@ static inline gfp_t gfp_memcg_charge(void)
return in_softirq() ? GFP_ATOMIC : GFP_KERNEL;
}

#define SK_BPF_MEMCG_FLAG_MASK (SK_BPF_MEMCG_FLAG_MAX - 1)
#define SK_BPF_MEMCG_PTR_MASK ~SK_BPF_MEMCG_FLAG_MASK

#ifdef CONFIG_MEMCG
static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
{
unsigned long val = (unsigned long)sk->sk_memcg;

val |= flags;
sk->sk_memcg = (struct mem_cgroup *)val;
}

static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
{
#ifdef CONFIG_CGROUP_BPF
unsigned long val = (unsigned long)sk->sk_memcg;

return val & SK_BPF_MEMCG_FLAG_MASK;
#else
return 0;
#endif
}

static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
{
#ifdef CONFIG_CGROUP_BPF
unsigned long val = (unsigned long)sk->sk_memcg;

val &= SK_BPF_MEMCG_PTR_MASK;
return (struct mem_cgroup *)val;
#else
return sk->sk_memcg;
#endif
}

static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
{
return mem_cgroup_sockets_enabled && mem_cgroup_from_sk(sk);
}

static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
{
return mem_cgroup_sk_get_flags(sk) & SK_BPF_MEMCG_SOCK_ISOLATED;
}

static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
{
struct mem_cgroup *memcg = mem_cgroup_from_sk(sk);
@@ -2624,6 +2658,15 @@ static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
return false;
}
#else
static inline void mem_cgroup_sk_set_flags(struct sock *sk, unsigned short flags)
{
}

static inline unsigned short mem_cgroup_sk_get_flags(const struct sock *sk)
{
return 0;
}

static inline struct mem_cgroup *mem_cgroup_from_sk(const struct sock *sk)
{
return NULL;
@@ -2634,6 +2677,11 @@ static inline bool mem_cgroup_sk_enabled(const struct sock *sk)
return false;
}

static inline bool mem_cgroup_sk_isolated(const struct sock *sk)
{
return false;
}

static inline bool mem_cgroup_sk_under_memory_pressure(const struct sock *sk)
{
return false;
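The sock.h helpers above pack the flag bits into the low bits of sk->sk_memcg, relying on struct mem_cgroup pointers being aligned so the low bits are free. A standalone sketch of the same tagging scheme, with hypothetical FLAG_*/PTR_MASK names mirroring the SK_BPF_MEMCG_* macros:

```c
/* Standalone sketch (not kernel code) of pointer tagging: flag bits live
 * in the low bits of an aligned pointer and are masked off before use.
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define FLAG_ISOLATED	(1UL << 0)	/* mirrors SK_BPF_MEMCG_SOCK_ISOLATED */
#define FLAG_MAX	(1UL << 1)	/* mirrors SK_BPF_MEMCG_FLAG_MAX */
#define FLAG_MASK	(FLAG_MAX - 1)	/* mirrors SK_BPF_MEMCG_FLAG_MASK */
#define PTR_MASK	(~FLAG_MASK)	/* mirrors SK_BPF_MEMCG_PTR_MASK */

struct memcg { long counter; };

int main(void)
{
	static struct memcg cg;		/* aligned: low pointer bits are zero */
	uintptr_t val = (uintptr_t)&cg;
	struct memcg *p;

	assert((val & FLAG_MASK) == 0);	/* the tagging relies on this */

	val |= FLAG_ISOLATED;				/* ..._sk_set_flags() */
	printf("flags=%lu\n", (unsigned long)(val & FLAG_MASK));
	p = (struct memcg *)(val & PTR_MASK);		/* ..._from_sk() */
	printf("ptr restored: %d\n", p == &cg);		/* 1 */
	return 0;
}
```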
include/net/tcp.h (7 additions & 3 deletions)
@@ -275,9 +275,13 @@ extern unsigned long tcp_memory_pressure
/* optimized version of sk_under_memory_pressure() for TCP sockets */
static inline bool tcp_under_memory_pressure(const struct sock *sk)
{
if (mem_cgroup_sk_enabled(sk) &&
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (mem_cgroup_sk_enabled(sk)) {
if (mem_cgroup_sk_under_memory_pressure(sk))
return true;

if (mem_cgroup_sk_isolated(sk))
return false;
}

return READ_ONCE(tcp_memory_pressure);
}
include/uapi/linux/bpf.h (6 additions & 0 deletions)
@@ -7182,6 +7182,7 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
SK_BPF_MEMCG_FLAGS = 1010, /* Get or Set flags saved in sk->sk_memcg */
};

enum {
@@ -7204,6 +7205,11 @@ enum {
*/
};

enum {
SK_BPF_MEMCG_SOCK_ISOLATED = (1UL << 0),
SK_BPF_MEMCG_FLAG_MAX = (1UL << 1),
};

struct bpf_perf_event_value {
__u64 counter;
__u64 enabled;
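Taken together with the filter.c change below, the intended usage is presumably a BPF_CGROUP_INET_SOCK_CREATE program that sets the flag on every new socket in a cgroup. A sketch, assuming the uapi enums above are visible via vmlinux.h on a patched kernel:

```c
/* Sketch of a cgroup/sock_create program marking new sockets isolated.
 * bpf_setsockopt() is reachable here because of the
 * sock_filter_func_proto() hunk in net/core/filter.c below.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>

#ifndef SOL_SOCKET
#define SOL_SOCKET 1	/* a #define in the uapi headers, not in vmlinux.h */
#endif

SEC("cgroup/sock_create")
int isolate_memcg(struct bpf_sock *ctx)
{
	int flags = SK_BPF_MEMCG_SOCK_ISOLATED;

	/* optlen must be sizeof(int): see sol_socket_sockopt() below */
	bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_MEMCG_FLAGS,
		       &flags, sizeof(flags));

	return 1;	/* allow the socket to be created */
}

char _license[] SEC("license") = "GPL";
```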
net/core/filter.c (51 additions & 1 deletion)
@@ -5267,6 +5267,27 @@ static int sk_bpf_set_get_cb_flags(struct sock *sk, char *optval, bool getopt)
return 0;
}

static int sk_bpf_set_get_memcg_flags(struct sock *sk, int *optval, bool getopt)
{
if (!sk_has_account(sk))
return -EOPNOTSUPP;

if (getopt) {
*optval = mem_cgroup_sk_get_flags(sk);
return 0;
}

if (sock_owned_by_user_nocheck(sk) && mem_cgroup_from_sk(sk))
return -EBUSY;

if (*optval <= 0 || *optval >= SK_BPF_MEMCG_FLAG_MAX)
return -EINVAL;

mem_cgroup_sk_set_flags(sk, *optval);

return 0;
}

static int sol_socket_sockopt(struct sock *sk, int optname,
char *optval, int *optlen,
bool getopt)
@@ -5284,6 +5305,7 @@ static int sol_socket_sockopt(struct sock *sk, int optname,
case SO_BINDTOIFINDEX:
case SO_TXREHASH:
case SK_BPF_CB_FLAGS:
case SK_BPF_MEMCG_FLAGS:
if (*optlen != sizeof(int))
return -EINVAL;
break;
@@ -5293,8 +5315,12 @@
return -EINVAL;
}

if (optname == SK_BPF_CB_FLAGS)
switch (optname) {
case SK_BPF_CB_FLAGS:
return sk_bpf_set_get_cb_flags(sk, optval, getopt);
case SK_BPF_MEMCG_FLAGS:
return sk_bpf_set_get_memcg_flags(sk, (int *)optval, getopt);
}

if (getopt) {
if (optname == SO_BINDTODEVICE)
@@ -5743,6 +5769,23 @@ static const struct bpf_func_proto bpf_sock_ops_setsockopt_proto = {
.arg5_type = ARG_CONST_SIZE,
};

BPF_CALL_5(bpf_unlocked_sock_setsockopt, struct sock *, sk, int, level,
int, optname, char *, optval, int, optlen)
{
return _bpf_setsockopt(sk, level, optname, optval, optlen);
}

static const struct bpf_func_proto bpf_unlocked_sock_setsockopt_proto = {
.func = bpf_unlocked_sock_setsockopt,
.gpl_only = false,
.ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_CTX,
.arg2_type = ARG_ANYTHING,
.arg3_type = ARG_ANYTHING,
.arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY,
.arg5_type = ARG_CONST_SIZE,
};

static int bpf_sock_ops_get_syn(struct bpf_sock_ops_kern *bpf_sock,
int optname, const u8 **start)
{
@@ -8051,6 +8094,13 @@ sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
return &bpf_sk_storage_get_cg_sock_proto;
case BPF_FUNC_ktime_get_coarse_ns:
return &bpf_ktime_get_coarse_ns_proto;
case BPF_FUNC_setsockopt:
switch (prog->expected_attach_type) {
case BPF_CGROUP_INET_SOCK_CREATE:
return &bpf_unlocked_sock_setsockopt_proto;
default:
return NULL;
}
default:
return bpf_base_func_proto(func_id, prog);
}
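A minimal libbpf loader for such a program might look like this; the object, program, and cgroup names are illustrative only:

```c
/* Hedged loader sketch: opens the compiled BPF object, attaches the
 * cgroup/sock_create program to a cgroup, and keeps the link alive.
 */
#include <bpf/libbpf.h>
#include <fcntl.h>
#include <unistd.h>

int main(void)
{
	struct bpf_object *obj;
	struct bpf_program *prog;
	struct bpf_link *link;
	int cg_fd;

	obj = bpf_object__open_file("sock_isolate.bpf.o", NULL);
	if (!obj || bpf_object__load(obj))
		return 1;

	prog = bpf_object__find_program_by_name(obj, "isolate_memcg");
	cg_fd = open("/sys/fs/cgroup/test", O_RDONLY);	/* hypothetical cgroup */
	if (!prog || cg_fd < 0)
		return 1;

	link = bpf_program__attach_cgroup(prog, cg_fd);
	if (!link)
		return 1;

	pause();	/* keep the attachment alive while testing */
	return 0;
}
```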
net/core/sock.c (44 additions & 20 deletions)
@@ -1046,17 +1046,21 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
if (!charged)
return -ENOMEM;

/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);
/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
if (!mem_cgroup_sk_isolated(sk)) {
/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);

/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
if (allocated > sk_prot_mem_limits(sk, 1)) {
sk_memory_allocated_sub(sk, pages);
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
}

sk_forward_alloc_add(sk, pages << PAGE_SHIFT);

WRITE_ONCE(sk->sk_reserved_mem,
@@ -3153,8 +3157,11 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
return true;

sk_enter_memory_pressure(sk);
if (sk_should_enter_memory_pressure(sk))
sk_enter_memory_pressure(sk);

sk_stream_moderate_sndbuf(sk);

return false;
}
EXPORT_SYMBOL(sk_page_frag_refill);
@@ -3267,18 +3274,30 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
long allocated;

sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
long allocated = 0;

if (mem_cgroup_sk_enabled(sk)) {
bool isolated = mem_cgroup_sk_isolated(sk);

memcg_enabled = true;
charged = mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge());
if (!charged)

if (isolated && charged)
return 1;

if (!charged) {
if (!isolated) {
sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);
}

goto suppress_allocation;
}
}

sk_memory_allocated_add(sk, amt);
allocated = sk_memory_allocated(sk);

/* Under limit. */
if (allocated <= sk_prot_mem_limits(sk, 0)) {
sk_leave_memory_pressure(sk);
@@ -3357,7 +3376,8 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)

trace_sock_exceed_buf_limit(sk, prot, allocated, kind);

sk_memory_allocated_sub(sk, amt);
if (allocated)
sk_memory_allocated_sub(sk, amt);

if (charged)
mem_cgroup_sk_uncharge(sk, amt);
@@ -3396,11 +3416,15 @@ EXPORT_SYMBOL(__sk_mem_schedule);
*/
void __sk_mem_reduce_allocated(struct sock *sk, int amount)
{
sk_memory_allocated_sub(sk, amount);

if (mem_cgroup_sk_enabled(sk))
if (mem_cgroup_sk_enabled(sk)) {
mem_cgroup_sk_uncharge(sk, amount);

if (mem_cgroup_sk_isolated(sk))
return;
}

sk_memory_allocated_sub(sk, amount);

if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
sk_leave_memory_pressure(sk);
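The reordered accounting in __sk_mem_raise_allocated() is subtle: the memcg charge now happens first, an isolated socket that charges successfully skips the global counters entirely, and allocated doubles as a did-we-touch-the-global-counter flag for the rollback in the suppress path. A standalone sketch of just that skeleton (limit and suppress logic elided):

```c
/* Standalone sketch (not kernel code) of the new charge ordering. */
#include <stdbool.h>
#include <stdio.h>

static long global_pages;

static int raise_allocated(bool memcg_enabled, bool isolated,
			   bool charge_ok, long amt)
{
	long allocated = 0;

	if (memcg_enabled) {
		if (isolated && charge_ok)
			return 1;	/* memcg-only accounting */
		if (!charge_ok) {
			if (!isolated)
				allocated = (global_pages += amt);
			goto suppress;
		}
	}
	allocated = (global_pages += amt);
	return 1;			/* limit checks elided */

suppress:
	if (allocated)
		global_pages -= amt;	/* roll back only what was added */
	return 0;
}

int main(void)
{
	int ok;

	ok = raise_allocated(true, true, true, 8);	/* isolated + charged */
	printf("%d global=%ld\n", ok, global_pages);	/* 1 global=0 */

	ok = raise_allocated(true, false, false, 8);	/* charge failed */
	printf("%d global=%ld\n", ok, global_pages);	/* 0 global=0 */
	return 0;
}
```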
net/ipv4/af_inet.c (37 additions & 0 deletions)
@@ -95,6 +95,7 @@
#include <net/checksum.h>
#include <net/ip.h>
#include <net/protocol.h>
#include <net/proto_memory.h>
#include <net/arp.h>
#include <net/route.h>
#include <net/ip_fib.h>
@@ -753,6 +754,42 @@ EXPORT_SYMBOL(inet_stream_connect);

void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *newsk)
{
/* TODO: use sk_clone_lock() in SCTP and remove protocol checks */
if (mem_cgroup_sockets_enabled &&
(!IS_ENABLED(CONFIG_IP_SCTP) ||
sk_is_tcp(newsk) || sk_is_mptcp(newsk))) {
gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
unsigned short flags;

flags = mem_cgroup_sk_get_flags(newsk);
mem_cgroup_sk_alloc(newsk);

if (mem_cgroup_from_sk(newsk)) {
int amt;

mem_cgroup_sk_set_flags(newsk, flags);

/* The socket has not been accepted yet, no need
* to look at newsk->sk_wmem_queued.
*/
amt = sk_mem_pages(newsk->sk_forward_alloc +
atomic_read(&newsk->sk_rmem_alloc));
if (amt) {
/* This amt is already charged globally to
* sk_prot->memory_allocated due to lack of
* sk_memcg until accept(), thus we need to
* reclaim it here if newsk is isolated.
*/
if (mem_cgroup_sk_isolated(newsk))
sk_memory_allocated_sub(newsk, amt);

mem_cgroup_sk_charge(newsk, amt, gfp);
}
}

kmem_cache_charge(newsk, gfp);
}

sock_rps_record_flow(newsk);
WARN_ON(!((1 << newsk->sk_state) &
(TCPF_ESTABLISHED | TCPF_SYN_RECV |
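The amt computed in __inet_accept() is the page-rounded total of the child's forward alloc and receive queue; for an isolated socket those pages move from the global counters to the memcg. A quick standalone check of the arithmetic, assuming sk_mem_pages() rounds bytes up to whole pages:

```c
/* Sketch (not kernel code) of the accept-time fixup arithmetic above. */
#include <stdio.h>

#define PAGE_SHIFT 12
#define PAGE_SIZE  (1L << PAGE_SHIFT)

static long sk_mem_pages(long bytes)
{
	return (bytes + PAGE_SIZE - 1) >> PAGE_SHIFT;	/* round up */
}

int main(void)
{
	long fwd_alloc = 4096, rmem_alloc = 6000;
	long amt = sk_mem_pages(fwd_alloc + rmem_alloc);

	/* 10096 bytes -> 3 pages charged to the memcg; an isolated socket
	 * also subtracts those 3 pages from the global counters.
	 */
	printf("amt = %ld pages\n", amt);
	return 0;
}
```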