Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions include/linux/skmsg.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ struct sk_psock {
struct sk_buff_head ingress_skb;
struct list_head ingress_msg;
spinlock_t ingress_lock;
ssize_t ingress_size;
unsigned long state;
struct list_head link;
spinlock_t link_lock;
Expand Down Expand Up @@ -141,6 +142,8 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
struct sk_msg *msg, u32 bytes);
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags);
int __sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags, int *from_self_copied);
bool sk_msg_is_readable(struct sock *sk);

static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
Expand Down Expand Up @@ -319,6 +322,16 @@ static inline void sock_drop(struct sock *sk, struct sk_buff *skb)
kfree_skb(skb);
}

/* Return a snapshot of psock->ingress_size.
 *
 * The value is fetched with READ_ONCE() and no lock is taken here, so it
 * may already be stale by the time the caller looks at it.
 */
static inline ssize_t sk_psock_get_msg_size(struct sk_psock *psock)
{
	ssize_t queued = READ_ONCE(psock->ingress_size);

	return queued;
}

/* Add @diff to psock->ingress_size; pass a negative @diff to shrink it.
 *
 * This is a plain load followed by a store, not an atomic add, so
 * concurrent updaters can lose an update unless they are serialized.
 * NOTE(review): sk_psock_queue_msg() calls this under ingress_lock;
 * confirm every other caller is serialized the same way.
 *
 * NOTE(review): @diff is signed. A caller that negates an unsigned size
 * should cast it to ssize_t first, as __sk_psock_purge_ingress_msg() does;
 * sk_psock_dequeue_msg() passes -msg->sg.size without a cast - verify the
 * type of sg.size.
 */
static inline void sk_psock_inc_msg_size(struct sk_psock *psock, ssize_t diff)
{
	ssize_t updated = READ_ONCE(psock->ingress_size) + diff;

	WRITE_ONCE(psock->ingress_size, updated);
}

static inline bool sk_psock_queue_msg(struct sk_psock *psock,
struct sk_msg *msg)
{
Expand All @@ -327,6 +340,7 @@ static inline bool sk_psock_queue_msg(struct sk_psock *psock,
spin_lock_bh(&psock->ingress_lock);
if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
list_add_tail(&msg->list, &psock->ingress_msg);
sk_psock_inc_msg_size(psock, msg->sg.size);
ret = true;
} else {
sk_msg_free(psock->sk, msg);
Expand All @@ -343,18 +357,25 @@ static inline struct sk_msg *sk_psock_dequeue_msg(struct sk_psock *psock)

spin_lock_bh(&psock->ingress_lock);
msg = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
if (msg)
if (msg) {
list_del(&msg->list);
sk_psock_inc_msg_size(psock, -msg->sg.size);
}
spin_unlock_bh(&psock->ingress_lock);
return msg;
}

/* Return the first entry of psock->ingress_msg, or NULL when the list is
 * empty. The entry stays on the list.
 *
 * No locking is done here; the callers in this header take
 * psock->ingress_lock around the call.
 */
static inline struct sk_msg *__sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *head;

	head = list_first_entry_or_null(&psock->ingress_msg, struct sk_msg, list);
	return head;
}

/* Return the first entry of psock->ingress_msg, or NULL when the list is
 * empty. The entry is not removed from the list.
 *
 * psock->ingress_lock is held (BH-disabled) only for the duration of the
 * lookup and is released before returning.
 */
static inline struct sk_msg *sk_psock_peek_msg(struct sk_psock *psock)
{
	struct sk_msg *msg;

	spin_lock_bh(&psock->ingress_lock);
	/* Single lookup through the lockless helper; an open-coded
	 * list_first_entry_or_null() right before it would only be
	 * overwritten.
	 */
	msg = __sk_psock_peek_msg(psock);
	spin_unlock_bh(&psock->ingress_lock);
	return msg;
}
Expand Down Expand Up @@ -521,6 +542,39 @@ static inline bool sk_psock_strp_enabled(struct sk_psock *psock)
return !!psock->saved_data_ready;
}

/* for tcp only, sk is locked
 *
 * Report the byte count tracked in the psock's ingress_size for @sk.
 * Returns 0 when sk_psock_get() yields no psock. The psock reference
 * taken here is dropped again before returning.
 */
static inline ssize_t sk_psock_msg_inq(struct sock *sk)
{
	struct sk_psock *psock = sk_psock_get(sk);
	ssize_t pending;

	if (unlikely(!psock))
		return 0;

	pending = sk_psock_get_msg_size(psock);
	sk_psock_put(sk, psock);
	return pending;
}

/* for udp only, sk is not locked
 *
 * Return sg.size of the first message queued on the psock's ingress_msg
 * list, or 0 when there is no psock or the list is empty. The head entry
 * is inspected under psock->ingress_lock and left on the list.
 */
static inline ssize_t sk_msg_first_length(struct sock *sk)
{
	struct sk_psock *psock = sk_psock_get(sk);
	struct sk_msg *head;
	ssize_t first_len = 0;

	if (unlikely(!psock))
		return first_len;

	spin_lock_bh(&psock->ingress_lock);
	head = __sk_psock_peek_msg(psock);
	if (head)
		first_len = head->sg.size;
	spin_unlock_bh(&psock->ingress_lock);

	sk_psock_put(sk, psock);
	return first_len;
}

#if IS_ENABLED(CONFIG_NET_SOCK_MSG)

#define BPF_F_STRPARSER (1UL << 1)
Expand Down
28 changes: 25 additions & 3 deletions net/core/skmsg.c
Original file line number Diff line number Diff line change
Expand Up @@ -409,14 +409,14 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

/* Receive sk_msg from psock->ingress_msg to @msg. */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags)
int __sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
int len, int flags, int *from_self_copied)
{
struct iov_iter *iter = &msg->msg_iter;
int peek = flags & MSG_PEEK;
struct sk_msg *msg_rx;
int i, copied = 0;
bool to_self;

msg_rx = sk_psock_peek_msg(psock);
while (copied != len) {
Expand All @@ -425,6 +425,7 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
if (unlikely(!msg_rx))
break;

to_self = msg_rx->sk == sk;
i = msg_rx->sg.start;
do {
struct page *page;
Expand All @@ -443,6 +444,9 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
}

copied += copy;
if (to_self && from_self_copied)
*from_self_copied += copy;

if (likely(!peek)) {
sge->offset += copy;
sge->length -= copy;
Expand All @@ -451,6 +455,7 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
atomic_sub(copy, &sk->sk_rmem_alloc);
}
msg_rx->sg.size -= copy;
sk_psock_inc_msg_size(psock, -copy);

if (!sge->length) {
sk_msg_iter_var_next(i);
Expand Down Expand Up @@ -487,6 +492,14 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
out:
return copied;
}
EXPORT_SYMBOL_GPL(__sk_msg_recvmsg);

/* Receive sk_msg from psock->ingress_msg to @msg.
 *
 * Same as __sk_msg_recvmsg() with a NULL from_self_copied argument, i.e.
 * for callers that do not need the separate count of bytes taken from
 * messages whose msg->sk is @sk. Returns whatever __sk_msg_recvmsg()
 * returns.
 */
int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
		   int len, int flags)
{
	int copied;

	copied = __sk_msg_recvmsg(sk, psock, msg, len, flags, NULL);
	return copied;
}
EXPORT_SYMBOL_GPL(sk_msg_recvmsg);

bool sk_msg_is_readable(struct sock *sk)
Expand Down Expand Up @@ -616,6 +629,12 @@ static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb
if (unlikely(!msg))
return -EAGAIN;
skb_set_owner_r(skb, sk);

/* This is used in tcp_bpf_recvmsg_parser() to determine whether the
* data originates from the socket's own protocol stack. No need to
* refcount sk because msg's lifetime is bound to sk via the ingress_msg.
*/
msg->sk = sk;
err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg, take_ref);
if (err < 0)
kfree(msg);
Expand Down Expand Up @@ -801,9 +820,11 @@ static void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
list_del(&msg->list);
if (!msg->skb)
atomic_sub(msg->sg.size, &psock->sk->sk_rmem_alloc);
sk_psock_inc_msg_size(psock, -((ssize_t)msg->sg.size));
sk_msg_free(psock->sk, msg);
kfree(msg);
}
WARN_ON_ONCE(psock->ingress_size);
}

static void __sk_psock_zap_ingress(struct sk_psock *psock)
Expand Down Expand Up @@ -909,6 +930,7 @@ int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
sk_msg_compute_data_pointers(msg);
msg->sk = sk;
ret = bpf_prog_run_pin_on_cpu(prog, msg);
msg->sk = NULL;
ret = sk_psock_map_verd(ret, msg->sk_redir);
psock->apply_bytes = msg->apply_bytes;
if (ret == __SK_REDIRECT) {
Expand Down
26 changes: 24 additions & 2 deletions net/ipv4/tcp_bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <net/inet_common.h>
#include <net/tls.h>
#include <asm/ioctls.h>

void tcp_eat_skb(struct sock *sk, struct sk_buff *skb)
{
Expand Down Expand Up @@ -226,6 +227,7 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk,
int peek = flags & MSG_PEEK;
struct sk_psock *psock;
struct tcp_sock *tcp;
int from_self_copied = 0;
int copied = 0;
u32 seq;

Expand Down Expand Up @@ -262,7 +264,7 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk,
}

msg_bytes_ready:
copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
copied = __sk_msg_recvmsg(sk, psock, msg, len, flags, &from_self_copied);
/* The typical case for EFAULT is the socket was gracefully
* shutdown with a FIN pkt. So check here the other case is
* some error on copy_page_to_iter which would be unexpected.
Expand All @@ -277,7 +279,7 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk,
goto out;
}
}
seq += copied;
seq += from_self_copied;
if (!copied) {
long timeo;
int data;
Expand Down Expand Up @@ -331,6 +333,25 @@ static int tcp_bpf_recvmsg_parser(struct sock *sk,
return copied;
}

/* ioctl handler installed for TCP sockets using the tcp_bpf protos.
 *
 * Only SIOCINQ (FIONREAD) is answered here, from the psock ingress byte
 * count; every other command is handed to tcp_ioctl() unchanged.
 *
 * Returns 0 with *@karg filled in, -EINVAL for a listening socket, or
 * the result of tcp_ioctl() for other commands.
 */
static int tcp_bpf_ioctl(struct sock *sk, int cmd, int *karg)
{
	bool slow;

	if (cmd == SIOCINQ) {
		/* works similar as tcp_ioctl */
		if (sk->sk_state == TCP_LISTEN)
			return -EINVAL;

		/* sk_psock_msg_inq() is called with the socket locked */
		slow = lock_sock_fast(sk);
		*karg = sk_psock_msg_inq(sk);
		unlock_sock_fast(sk, slow);
		return 0;
	}

	/* we only care about FIONREAD; defer everything else */
	return tcp_ioctl(sk, cmd, karg);
}

static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int flags, int *addr_len)
{
Expand Down Expand Up @@ -609,6 +630,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
prot[TCP_BPF_BASE].close = sock_map_close;
prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;
prot[TCP_BPF_BASE].ioctl = tcp_bpf_ioctl;

prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
Expand Down
25 changes: 21 additions & 4 deletions net/ipv4/udp_bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <net/sock.h>
#include <net/udp.h>
#include <net/inet_common.h>
#include <asm/ioctls.h>

#include "udp_impl.h"

Expand Down Expand Up @@ -111,12 +112,28 @@ enum {
static DEFINE_SPINLOCK(udpv6_prot_lock);
static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];

/* ioctl handler installed for UDP sockets using the udp_bpf protos.
 *
 * Only SIOCINQ (FIONREAD) is answered here; every other command is
 * handed to udp_ioctl() unchanged. Returns 0 with *@karg filled in for
 * SIOCINQ, otherwise the result of udp_ioctl().
 */
static int udp_bpf_ioctl(struct sock *sk, int cmd, int *karg)
{
	if (cmd == SIOCINQ) {
		/* works similar as udp_ioctl.
		 * man udp(7): "FIONREAD (SIOCINQ): Returns the size of the
		 * next pending datagram in the integer in bytes, or 0 when
		 * no datagram is pending."
		 */
		*karg = sk_msg_first_length(sk);
		return 0;
	}

	/* we only care about FIONREAD; defer everything else */
	return udp_ioctl(sk, cmd, karg);
}

/* Build the sockmap variant of a UDP proto.
 *
 * @prot: destination; overwritten with a copy of @base.
 * @base: proto to start from.
 *
 * After the copy, the close, recvmsg, sock_is_readable and ioctl hooks
 * are pointed at the sockmap-aware implementations. The copy and the
 * hook assignments are done exactly once.
 */
static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
	*prot			= *base;
	prot->close		= sock_map_close;
	prot->recvmsg		= udp_bpf_recvmsg;
	prot->sock_is_readable	= sk_msg_is_readable;
	prot->ioctl		= udp_bpf_ioctl;
}

static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
Expand Down
Loading
Loading