@@ -2090,18 +2090,18 @@ static unsigned int run_filter(struct sk_buff *skb,
2090
2090
}
2091
2091
2092
2092
static int packet_rcv_vnet (struct msghdr * msg , const struct sk_buff * skb ,
2093
- size_t * len )
2093
+ size_t * len , int vnet_hdr_sz )
2094
2094
{
2095
- struct virtio_net_hdr vnet_hdr ;
2095
+ struct virtio_net_hdr_mrg_rxbuf vnet_hdr = { . num_buffers = 0 } ;
2096
2096
2097
- if (* len < sizeof ( vnet_hdr ) )
2097
+ if (* len < vnet_hdr_sz )
2098
2098
return - EINVAL ;
2099
- * len -= sizeof ( vnet_hdr ) ;
2099
+ * len -= vnet_hdr_sz ;
2100
2100
2101
- if (virtio_net_hdr_from_skb (skb , & vnet_hdr , vio_le (), true, 0 ))
2101
+ if (virtio_net_hdr_from_skb (skb , ( struct virtio_net_hdr * ) & vnet_hdr , vio_le (), true, 0 ))
2102
2102
return - EINVAL ;
2103
2103
2104
- return memcpy_to_msg (msg , (void * )& vnet_hdr , sizeof ( vnet_hdr ) );
2104
+ return memcpy_to_msg (msg , (void * )& vnet_hdr , vnet_hdr_sz );
2105
2105
}
2106
2106
2107
2107
/*
@@ -2250,7 +2250,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
2250
2250
__u32 ts_status ;
2251
2251
bool is_drop_n_account = false;
2252
2252
unsigned int slot_id = 0 ;
2253
- bool do_vnet = false ;
2253
+ int vnet_hdr_sz = 0 ;
2254
2254
2255
2255
/* struct tpacket{2,3}_hdr is aligned to a multiple of TPACKET_ALIGNMENT.
2256
2256
* We may add members to them until current aligned size without forcing
@@ -2308,10 +2308,9 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
2308
2308
netoff = TPACKET_ALIGN (po -> tp_hdrlen +
2309
2309
(maclen < 16 ? 16 : maclen )) +
2310
2310
po -> tp_reserve ;
2311
- if (packet_sock_flag (po , PACKET_SOCK_HAS_VNET_HDR )) {
2312
- netoff += sizeof (struct virtio_net_hdr );
2313
- do_vnet = true;
2314
- }
2311
+ vnet_hdr_sz = READ_ONCE (po -> vnet_hdr_sz );
2312
+ if (vnet_hdr_sz )
2313
+ netoff += vnet_hdr_sz ;
2315
2314
macoff = netoff - maclen ;
2316
2315
}
2317
2316
if (netoff > USHRT_MAX ) {
@@ -2337,7 +2336,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
2337
2336
snaplen = po -> rx_ring .frame_size - macoff ;
2338
2337
if ((int )snaplen < 0 ) {
2339
2338
snaplen = 0 ;
2340
- do_vnet = false ;
2339
+ vnet_hdr_sz = 0 ;
2341
2340
}
2342
2341
}
2343
2342
} else if (unlikely (macoff + snaplen >
@@ -2351,7 +2350,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
2351
2350
if (unlikely ((int )snaplen < 0 )) {
2352
2351
snaplen = 0 ;
2353
2352
macoff = GET_PBDQC_FROM_RB (& po -> rx_ring )-> max_frame_len ;
2354
- do_vnet = false ;
2353
+ vnet_hdr_sz = 0 ;
2355
2354
}
2356
2355
}
2357
2356
spin_lock (& sk -> sk_receive_queue .lock );
@@ -2367,7 +2366,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
2367
2366
__set_bit (slot_id , po -> rx_ring .rx_owner_map );
2368
2367
}
2369
2368
2370
- if (do_vnet &&
2369
+ if (vnet_hdr_sz &&
2371
2370
virtio_net_hdr_from_skb (skb , h .raw + macoff -
2372
2371
sizeof (struct virtio_net_hdr ),
2373
2372
vio_le (), true, 0 )) {
@@ -2551,16 +2550,26 @@ static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len)
2551
2550
}
2552
2551
2553
2552
static int packet_snd_vnet_parse (struct msghdr * msg , size_t * len ,
2554
- struct virtio_net_hdr * vnet_hdr )
2553
+ struct virtio_net_hdr * vnet_hdr , int vnet_hdr_sz )
2555
2554
{
2556
- if (* len < sizeof (* vnet_hdr ))
2555
+ int ret ;
2556
+
2557
+ if (* len < vnet_hdr_sz )
2557
2558
return - EINVAL ;
2558
- * len -= sizeof ( * vnet_hdr ) ;
2559
+ * len -= vnet_hdr_sz ;
2559
2560
2560
2561
if (!copy_from_iter_full (vnet_hdr , sizeof (* vnet_hdr ), & msg -> msg_iter ))
2561
2562
return - EFAULT ;
2562
2563
2563
- return __packet_snd_vnet_parse (vnet_hdr , * len );
2564
+ ret = __packet_snd_vnet_parse (vnet_hdr , * len );
2565
+ if (ret )
2566
+ return ret ;
2567
+
2568
+ /* move iter to point to the start of mac header */
2569
+ if (vnet_hdr_sz != sizeof (struct virtio_net_hdr ))
2570
+ iov_iter_advance (& msg -> msg_iter , vnet_hdr_sz - sizeof (struct virtio_net_hdr ));
2571
+
2572
+ return 0 ;
2564
2573
}
2565
2574
2566
2575
static int tpacket_fill_skb (struct packet_sock * po , struct sk_buff * skb ,
@@ -2722,6 +2731,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
2722
2731
void * ph ;
2723
2732
DECLARE_SOCKADDR (struct sockaddr_ll * , saddr , msg -> msg_name );
2724
2733
bool need_wait = !(msg -> msg_flags & MSG_DONTWAIT );
2734
+ int vnet_hdr_sz = READ_ONCE (po -> vnet_hdr_sz );
2725
2735
unsigned char * addr = NULL ;
2726
2736
int tp_len , size_max ;
2727
2737
void * data ;
@@ -2779,8 +2789,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
2779
2789
size_max = po -> tx_ring .frame_size
2780
2790
- (po -> tp_hdrlen - sizeof (struct sockaddr_ll ));
2781
2791
2782
- if ((size_max > dev -> mtu + reserve + VLAN_HLEN ) &&
2783
- !packet_sock_flag (po , PACKET_SOCK_HAS_VNET_HDR ))
2792
+ if ((size_max > dev -> mtu + reserve + VLAN_HLEN ) && !vnet_hdr_sz )
2784
2793
size_max = dev -> mtu + reserve + VLAN_HLEN ;
2785
2794
2786
2795
reinit_completion (& po -> skb_completion );
@@ -2809,10 +2818,10 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
2809
2818
status = TP_STATUS_SEND_REQUEST ;
2810
2819
hlen = LL_RESERVED_SPACE (dev );
2811
2820
tlen = dev -> needed_tailroom ;
2812
- if (packet_sock_flag ( po , PACKET_SOCK_HAS_VNET_HDR ) ) {
2821
+ if (vnet_hdr_sz ) {
2813
2822
vnet_hdr = data ;
2814
- data += sizeof ( * vnet_hdr ) ;
2815
- tp_len -= sizeof ( * vnet_hdr ) ;
2823
+ data += vnet_hdr_sz ;
2824
+ tp_len -= vnet_hdr_sz ;
2816
2825
if (tp_len < 0 ||
2817
2826
__packet_snd_vnet_parse (vnet_hdr , tp_len )) {
2818
2827
tp_len = - EINVAL ;
@@ -2837,7 +2846,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
2837
2846
addr , hlen , copylen , & sockc );
2838
2847
if (likely (tp_len >= 0 ) &&
2839
2848
tp_len > dev -> mtu + reserve &&
2840
- !packet_sock_flag ( po , PACKET_SOCK_HAS_VNET_HDR ) &&
2849
+ !vnet_hdr_sz &&
2841
2850
!packet_extra_vlan_len_allowed (dev , skb ))
2842
2851
tp_len = - EMSGSIZE ;
2843
2852
@@ -2856,7 +2865,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
2856
2865
}
2857
2866
}
2858
2867
2859
- if (packet_sock_flag ( po , PACKET_SOCK_HAS_VNET_HDR ) ) {
2868
+ if (vnet_hdr_sz ) {
2860
2869
if (virtio_net_hdr_to_skb (skb , vnet_hdr , vio_le ())) {
2861
2870
tp_len = - EINVAL ;
2862
2871
goto tpacket_error ;
@@ -2946,7 +2955,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
2946
2955
struct virtio_net_hdr vnet_hdr = { 0 };
2947
2956
int offset = 0 ;
2948
2957
struct packet_sock * po = pkt_sk (sk );
2949
- bool has_vnet_hdr = false ;
2958
+ int vnet_hdr_sz = READ_ONCE ( po -> vnet_hdr_sz ) ;
2950
2959
int hlen , tlen , linear ;
2951
2960
int extra_len = 0 ;
2952
2961
@@ -2990,11 +2999,10 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
2990
2999
2991
3000
if (sock -> type == SOCK_RAW )
2992
3001
reserve = dev -> hard_header_len ;
2993
- if (packet_sock_flag ( po , PACKET_SOCK_HAS_VNET_HDR ) ) {
2994
- err = packet_snd_vnet_parse (msg , & len , & vnet_hdr );
3002
+ if (vnet_hdr_sz ) {
3003
+ err = packet_snd_vnet_parse (msg , & len , & vnet_hdr , vnet_hdr_sz );
2995
3004
if (err )
2996
3005
goto out_unlock ;
2997
- has_vnet_hdr = true;
2998
3006
}
2999
3007
3000
3008
if (unlikely (sock_flag (sk , SOCK_NOFCS ))) {
@@ -3064,11 +3072,11 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
3064
3072
3065
3073
packet_parse_headers (skb , sock );
3066
3074
3067
- if (has_vnet_hdr ) {
3075
+ if (vnet_hdr_sz ) {
3068
3076
err = virtio_net_hdr_to_skb (skb , & vnet_hdr , vio_le ());
3069
3077
if (err )
3070
3078
goto out_free ;
3071
- len += sizeof ( vnet_hdr ) ;
3079
+ len += vnet_hdr_sz ;
3072
3080
virtio_net_hdr_set_proto (skb , & vnet_hdr );
3073
3081
}
3074
3082
@@ -3408,7 +3416,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
3408
3416
struct sock * sk = sock -> sk ;
3409
3417
struct sk_buff * skb ;
3410
3418
int copied , err ;
3411
- int vnet_hdr_len = 0 ;
3419
+ int vnet_hdr_len = READ_ONCE ( pkt_sk ( sk ) -> vnet_hdr_sz ) ;
3412
3420
unsigned int origlen = 0 ;
3413
3421
3414
3422
err = - EINVAL ;
@@ -3449,11 +3457,10 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
3449
3457
3450
3458
packet_rcv_try_clear_pressure (pkt_sk (sk ));
3451
3459
3452
- if (packet_sock_flag ( pkt_sk ( sk ), PACKET_SOCK_HAS_VNET_HDR ) ) {
3453
- err = packet_rcv_vnet (msg , skb , & len );
3460
+ if (vnet_hdr_len ) {
3461
+ err = packet_rcv_vnet (msg , skb , & len , vnet_hdr_len );
3454
3462
if (err )
3455
3463
goto out_free ;
3456
- vnet_hdr_len = sizeof (struct virtio_net_hdr );
3457
3464
}
3458
3465
3459
3466
/* You lose any data beyond the buffer you gave. If it worries
@@ -3915,8 +3922,9 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
3915
3922
return 0 ;
3916
3923
}
3917
3924
case PACKET_VNET_HDR :
3925
+ case PACKET_VNET_HDR_SZ :
3918
3926
{
3919
- int val ;
3927
+ int val , hdr_len ;
3920
3928
3921
3929
if (sock -> type != SOCK_RAW )
3922
3930
return - EINVAL ;
@@ -3925,11 +3933,19 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
3925
3933
if (copy_from_sockptr (& val , optval , sizeof (val )))
3926
3934
return - EFAULT ;
3927
3935
3936
+ if (optname == PACKET_VNET_HDR_SZ ) {
3937
+ if (val && val != sizeof (struct virtio_net_hdr ) &&
3938
+ val != sizeof (struct virtio_net_hdr_mrg_rxbuf ))
3939
+ return - EINVAL ;
3940
+ hdr_len = val ;
3941
+ } else {
3942
+ hdr_len = val ? sizeof (struct virtio_net_hdr ) : 0 ;
3943
+ }
3928
3944
lock_sock (sk );
3929
3945
if (po -> rx_ring .pg_vec || po -> tx_ring .pg_vec ) {
3930
3946
ret = - EBUSY ;
3931
3947
} else {
3932
- packet_sock_flag_set (po , PACKET_SOCK_HAS_VNET_HDR , val );
3948
+ WRITE_ONCE (po -> vnet_hdr_sz , hdr_len );
3933
3949
ret = 0 ;
3934
3950
}
3935
3951
release_sock (sk );
@@ -4062,7 +4078,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
4062
4078
val = packet_sock_flag (po , PACKET_SOCK_ORIGDEV );
4063
4079
break ;
4064
4080
case PACKET_VNET_HDR :
4065
- val = packet_sock_flag (po , PACKET_SOCK_HAS_VNET_HDR );
4081
+ val = !!READ_ONCE (po -> vnet_hdr_sz );
4082
+ break ;
4083
+ case PACKET_VNET_HDR_SZ :
4084
+ val = READ_ONCE (po -> vnet_hdr_sz );
4066
4085
break ;
4067
4086
case PACKET_VERSION :
4068
4087
val = po -> tp_version ;
0 commit comments