Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions MAINTAINERS
Original file line number Diff line number Diff line change
Expand Up @@ -6470,7 +6470,11 @@ F: mm/memcontrol-v1.c
F: mm/memcontrol-v1.h
F: mm/page_counter.c
F: mm/swap_cgroup.c
F: samples/bpf/memcg.bpf.c
F: samples/bpf/memcg.c
F: samples/cgroup/*
F: tools/testing/selftests/bpf/prog_tests/memcg_ops.c
F: tools/testing/selftests/bpf/progs/memcg_ops.c
F: tools/testing/selftests/cgroup/memcg_protection.m
F: tools/testing/selftests/cgroup/test_hugetlb_memcg.c
F: tools/testing/selftests/cgroup/test_kmem.c
Expand Down
8 changes: 8 additions & 0 deletions include/linux/bpf.h
Original file line number Diff line number Diff line change
Expand Up @@ -1895,6 +1895,14 @@ struct bpf_raw_tp_link {
u64 cookie;
};

struct bpf_struct_ops_link {
	struct bpf_link link;
	/* struct_ops map backing this link; RCU-managed so readers can
	 * observe it while an update swaps the map under update_mutex. */
	struct bpf_map __rcu *map;
	/* NOTE(review): presumably wakes pollers when the link is detached
	 * (<linux/poll.h> is included by the struct_ops code) — confirm. */
	wait_queue_head_t wait_hup;
	/* id of the cgroup this struct_ops is scoped to, taken from the
	 * link_create cgroup fd at attach time; 0 when no cgroup given. */
	u64 cgroup_id;
	/* copied verbatim from attr->link_create.flags at link creation */
	u32 flags;
};

struct bpf_link_primer {
struct bpf_link *link;
struct file *file;
Expand Down
122 changes: 118 additions & 4 deletions include/linux/memcontrol.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <linux/writeback.h>
#include <linux/page-flags.h>
#include <linux/shrinker.h>
#include <linux/srcu.h>

struct mem_cgroup;
struct obj_cgroup;
Expand Down Expand Up @@ -181,6 +182,37 @@ struct obj_cgroup {
};
};

#ifdef CONFIG_BPF_SYSCALL
/**
 * struct memcg_bpf_ops - BPF callbacks for memory cgroup operations
 * @handle_cgroup_online: Called when a memory cgroup comes online
 * @handle_cgroup_offline: Called when a memory cgroup goes offline
 * @below_low: Override memory.low protection check. If this callback returns
 *	true, mem_cgroup_below_low() will return true immediately without
 *	performing the standard comparison. If it returns false, the
 *	original memory.low threshold comparison will proceed normally.
 * @below_min: Override memory.min protection check. If this callback returns
 *	true, mem_cgroup_below_min() will return true immediately without
 *	performing the standard comparison. If it returns false, the
 *	original memory.min threshold comparison will proceed normally.
 * @get_high_delay_ms: Return custom throttle delay in milliseconds; the
 *	caller (bpf_memcg_get_high_delay()) converts the value to jiffies.
 *	Returning 0 means no extra delay.
 *
 * This structure defines the interface for BPF programs to customize
 * memory cgroup behavior through struct_ops programs. The callbacks are
 * invoked under memcg_bpf_srcu read-side protection (see BPF_MEMCG_CALL),
 * so an attached implementation may be torn down only after an SRCU grace
 * period.
 */
struct memcg_bpf_ops {
	void (*handle_cgroup_online)(struct mem_cgroup *memcg);

	void (*handle_cgroup_offline)(struct mem_cgroup *memcg);

	bool (*below_low)(struct mem_cgroup *memcg);

	bool (*below_min)(struct mem_cgroup *memcg);

	unsigned int (*get_high_delay_ms)(struct mem_cgroup *memcg);
};
#endif /* CONFIG_BPF_SYSCALL */

/*
* The memory controller data structure. The memory controller controls both
* page cache and RSS per cgroup. We would eventually like to provide
Expand Down Expand Up @@ -321,6 +353,11 @@ struct mem_cgroup {
spinlock_t event_list_lock;
#endif /* CONFIG_MEMCG_V1 */

#ifdef CONFIG_BPF_SYSCALL
struct memcg_bpf_ops *bpf_ops;
u32 bpf_ops_flags;
#endif

struct mem_cgroup_per_node *nodeinfo[];
};

Expand Down Expand Up @@ -554,6 +591,76 @@ static inline bool mem_cgroup_disabled(void)
return !cgroup_subsys_enabled(memory_cgrp_subsys);
}

#ifdef CONFIG_BPF_SYSCALL

/* SRCU for protecting concurrent access to memcg->bpf_ops */
extern struct srcu_struct memcg_bpf_srcu;

/**
 * BPF_MEMCG_CALL - Safely invoke a BPF memcg callback
 * @memcg: The memory cgroup
 * @op: The operation name (struct member)
 * @default_val: Default return value if no BPF program attached
 *
 * This macro safely calls a BPF callback under SRCU protection.
 *
 * The first READ_ONCE() serves as a fast-path check to avoid the overhead
 * of SRCU read lock acquisition when no BPF program is attached. This keeps
 * the common no-BPF case performance unchanged. The second READ_ONCE() under
 * SRCU protection ensures we see a consistent view of bpf_ops after acquiring
 * the lock, protecting against concurrent updates.
 *
 * Note: @memcg is evaluated more than once; do not pass an expression with
 * side effects.
 */
#define BPF_MEMCG_CALL(memcg, op, default_val) ({			\
	typeof(default_val) __ret = (default_val);			\
	struct memcg_bpf_ops *__ops;					\
	int __idx;							\
									\
	if (unlikely(READ_ONCE((memcg)->bpf_ops))) {			\
		__idx = srcu_read_lock(&memcg_bpf_srcu);		\
		/* re-read under SRCU: ops may have been swapped	\
		 * since the unlocked fast-path check above		\
		 */							\
		__ops = READ_ONCE((memcg)->bpf_ops);			\
		if (__ops && __ops->op)				\
			__ret = __ops->op((memcg));			\
		srcu_read_unlock(&memcg_bpf_srcu, __idx);		\
	}								\
	__ret;								\
})

/*
 * Ask an attached BPF program whether @memcg should be treated as below
 * its memory.low protection. Returns false when no program is attached
 * or the program declines to override the check.
 */
static inline bool bpf_memcg_below_low(struct mem_cgroup *memcg)
{
	return BPF_MEMCG_CALL(memcg, below_low, false);
}

/*
 * Ask an attached BPF program whether @memcg should be treated as below
 * its memory.min protection. Returns false when no program is attached
 * or the program declines to override the check.
 */
static inline bool bpf_memcg_below_min(struct mem_cgroup *memcg)
{
	return BPF_MEMCG_CALL(memcg, below_min, false);
}

/*
 * Ask an attached BPF program for an extra memory.high throttle delay for
 * @memcg. The callback reports milliseconds; the result is returned in
 * jiffies. Yields 0 when no program is attached or no delay is requested.
 */
static inline unsigned long bpf_memcg_get_high_delay(struct mem_cgroup *memcg)
{
	unsigned int delay_ms = BPF_MEMCG_CALL(memcg, get_high_delay_ms, 0U);

	return msecs_to_jiffies(delay_ms);
}

#undef BPF_MEMCG_CALL

extern void memcontrol_bpf_online(struct mem_cgroup *memcg);
extern void memcontrol_bpf_offline(struct mem_cgroup *memcg);

#else /* CONFIG_BPF_SYSCALL */

/* CONFIG_BPF_SYSCALL=n: no BPF memcg callbacks exist; all hooks are no-ops. */
static inline unsigned long bpf_memcg_get_high_delay(struct mem_cgroup *memcg)
{
	return 0;
}

static inline bool bpf_memcg_below_low(struct mem_cgroup *memcg)
{
	return false;
}

static inline bool bpf_memcg_below_min(struct mem_cgroup *memcg)
{
	return false;
}

static inline void memcontrol_bpf_online(struct mem_cgroup *memcg)
{
}

static inline void memcontrol_bpf_offline(struct mem_cgroup *memcg)
{
}

#endif /* CONFIG_BPF_SYSCALL */

static inline void mem_cgroup_protection(struct mem_cgroup *root,
struct mem_cgroup *memcg,
unsigned long *min,
Expand Down Expand Up @@ -625,6 +732,9 @@ static inline bool mem_cgroup_below_low(struct mem_cgroup *target,
if (mem_cgroup_unprotected(target, memcg))
return false;

if (bpf_memcg_below_low(memcg))
return true;

return READ_ONCE(memcg->memory.elow) >=
page_counter_read(&memcg->memory);
}
Expand All @@ -635,6 +745,9 @@ static inline bool mem_cgroup_below_min(struct mem_cgroup *target,
if (mem_cgroup_unprotected(target, memcg))
return false;

if (bpf_memcg_below_min(memcg))
return true;

return READ_ONCE(memcg->memory.emin) >=
page_counter_read(&memcg->memory);
}
Expand Down Expand Up @@ -833,9 +946,9 @@ static inline unsigned long mem_cgroup_ino(struct mem_cgroup *memcg)
{
return memcg ? cgroup_ino(memcg->css.cgroup) : 0;
}
#endif

struct mem_cgroup *mem_cgroup_get_from_ino(unsigned long ino);
#endif

static inline struct mem_cgroup *mem_cgroup_from_seq(struct seq_file *m)
{
Expand Down Expand Up @@ -909,12 +1022,13 @@ unsigned long mem_cgroup_get_zone_lru_size(struct lruvec *lruvec,
return READ_ONCE(mz->lru_zone_size[zone_idx][lru]);
}

void __mem_cgroup_handle_over_high(gfp_t gfp_mask);
void __mem_cgroup_handle_over_high(gfp_t gfp_mask,
unsigned long bpf_high_delay);

/*
 * Fast-path check for memory.high excess: only call into the slow path
 * when the current task has accumulated pages over the high limit.
 * The second argument (0) is the BPF-provided extra delay in jiffies;
 * this path requests none.
 */
static inline void mem_cgroup_handle_over_high(gfp_t gfp_mask)
{
	if (unlikely(current->memcg_nr_pages_over_high))
		__mem_cgroup_handle_over_high(gfp_mask, 0);
}

unsigned long mem_cgroup_get_max(struct mem_cgroup *memcg);
Expand Down Expand Up @@ -1298,12 +1412,12 @@ static inline unsigned long mem_cgroup_ino(struct mem_cgroup *memcg)
{
return 0;
}
#endif

/* Stub: ino-based memcg lookup unavailable in this configuration
 * (the enclosing config gate is outside this hunk — confirm it is
 * the !CONFIG_MEMCG / no-inode-tracking branch).
 */
static inline struct mem_cgroup *mem_cgroup_get_from_ino(unsigned long ino)
{
	return NULL;
}
#endif

static inline struct mem_cgroup *mem_cgroup_from_seq(struct seq_file *m)
{
Expand Down
22 changes: 16 additions & 6 deletions kernel/bpf/bpf_struct_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <linux/btf_ids.h>
#include <linux/rcupdate_wait.h>
#include <linux/poll.h>
#include <linux/cgroup.h>

struct bpf_struct_ops_value {
struct bpf_struct_ops_common_value common;
Expand Down Expand Up @@ -55,12 +56,6 @@ struct bpf_struct_ops_map {
struct bpf_struct_ops_value kvalue;
};

struct bpf_struct_ops_link {
struct bpf_link link;
struct bpf_map __rcu *map;
wait_queue_head_t wait_hup;
};

static DEFINE_MUTEX(update_mutex);

#define VALUE_PREFIX "bpf_struct_ops_"
Expand Down Expand Up @@ -1383,6 +1378,21 @@ int bpf_struct_ops_link_create(union bpf_attr *attr)
}
bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_map_lops, NULL,
attr->link_create.attach_type);
#ifdef CONFIG_CGROUPS
if (attr->link_create.cgroup.relative_fd) {
struct cgroup *cgrp;

cgrp = cgroup_get_from_fd(attr->link_create.cgroup.relative_fd);
if (IS_ERR(cgrp)) {
err = PTR_ERR(cgrp);
goto err_out;
}

link->cgroup_id = cgroup_id(cgrp);
cgroup_put(cgrp);
}
#endif /* CONFIG_CGROUPS */
link->flags = attr->link_create.flags;

err = bpf_link_prime(&link->link, &link_primer);
if (err)
Expand Down
5 changes: 5 additions & 0 deletions kernel/bpf/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -7254,6 +7254,10 @@ BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct vm_area_struct) {
struct file *vm_file;
};

/* Mark oom_control::memcg as trusted-or-null for the verifier: the field
 * may legitimately be NULL, so BPF programs must NULL-check it before
 * dereferencing (emitted via BTF_TYPE_EMIT in type_is_trusted_or_null()).
 */
BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct oom_control) {
	struct mem_cgroup *memcg;
};

static bool type_is_rcu(struct bpf_verifier_env *env,
struct bpf_reg_state *reg,
const char *field_name, u32 btf_id)
Expand Down Expand Up @@ -7296,6 +7300,7 @@ static bool type_is_trusted_or_null(struct bpf_verifier_env *env,
BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct socket));
BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct dentry));
BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct vm_area_struct));
BTF_TYPE_EMIT(BTF_TYPE_SAFE_TRUSTED_OR_NULL(struct oom_control));

return btf_nested_type_is_trusted(&env->log, reg, field_name, btf_id,
"__safe_trusted_or_null");
Expand Down
Loading
Loading