-
Notifications
You must be signed in to change notification settings - Fork 170
Expand file tree
/
Copy pathconn.go
More file actions
670 lines (586 loc) · 19.3 KB
/
conn.go
File metadata and controls
670 lines (586 loc) · 19.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
// Copyright 2018 Google LLC. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nftables
import (
"errors"
"fmt"
"iter"
"math"
"os"
"sync"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/mdlayher/netlink"
"github.com/mdlayher/netlink/nltest"
"golang.org/x/sys/unix"
)
// A Conn represents a netlink connection of the nftables family.
//
// All methods return their input, so that variables can be defined from string
// literals when desired.
//
// Commands are buffered. Flush sends all buffered commands in a single batch.
//
// A Conn embeds a sync.Mutex and must therefore not be copied after first use.
type Conn struct {
TestDial nltest.Func // for testing only; passed to nltest.Dial
NetNS int // fd referencing the network namespace netlink will interact with.
lasting bool // establish a lasting connection to be used across multiple netlink operations.
mu sync.Mutex // protects the following state
// messages holds the commands buffered since the last flush.
messages []netlinkMessage
// err is the first error recorded through setErr while buffering commands;
// flush returns it instead of sending a non-empty batch.
err error
// nlconn is only set for a lasting connection (see New) and is cleared again
// by CloseLasting.
nlconn *netlink.Conn // netlink socket using NETLINK_NETFILTER protocol.
// sockOptions are applied to every netlink connection dialed by dialNetlink.
sockOptions []SockOption
// lastID is the most recent ID handed out by allocateTransactionID; it is
// not reset by flush.
lastID uint32
// allocatedIDs counts the IDs handed out since the last flush.
allocatedIDs uint32
}
// netlinkMessage is a single buffered nftables command: the netlink header and
// payload to send and, optionally, the Rule that is updated from the echo reply
// to this command.
type netlinkMessage struct {
// Header is overwritten by getSeqToMsgMap with the header that was actually
// sent, which carries the sequence number.
Header netlink.Header
Data []byte
// rule receives the echo reply through handleCreateReply (see
// handleEchoReply).
rule *Rule
}
// ConnOption is an option to change the behavior of the nftables Conn returned by New.
type ConnOption func(*Conn)
// SockOption is an option to change the behavior of the netlink socket used by the nftables Conn.
// Socket options are applied, in the order they were registered, to every
// netlink connection the Conn dials; a non-nil error makes the dial fail.
type SockOption func(*netlink.Conn) error
// New creates a Conn for querying and modifying nftables. The given options
// are applied in order; see WithNetNSFd, WithTestDial, WithSockOptions, and
// AsLasting.
//
// Without AsLasting no netlink socket is opened here. With AsLasting the
// socket is dialed immediately, after all options have been applied, and a
// dial failure is returned as the error.
//
// A lasting netlink connection should be closed by calling CloseLasting() to
// close the underlying lasting netlink connection, cancelling all pending
// operations using this connection.
func New(opts ...ConnOption) (*Conn, error) {
	cc := new(Conn)
	for i := range opts {
		opts[i](cc)
	}
	if cc.lasting {
		nlconn, err := cc.dialNetlink()
		if err != nil {
			return nil, err
		}
		cc.nlconn = nlconn
	}
	return cc, nil
}
// AsLasting makes New open one lasting netlink connection that is reused by
// all subsequent netlink operations, rather than opening and closing a
// connection for every single operation.
func AsLasting() ConnOption {
	// Only record the request. The socket itself is dialed by New once every
	// option has run, because options that come later in the chain may still
	// change how the connection has to be set up.
	return ConnOption(func(conn *Conn) { conn.lasting = true })
}
// WithNetNSFd selects the network namespace that new netlink connections are
// created in. The fd must reference a network namespace.
func WithNetNSFd(fd int) ConnOption {
	return ConnOption(func(conn *Conn) { conn.NetNS = fd })
}
// WithTestDial makes new netlink connections go through nltest.Dial with the
// given nltest.Func instead of talking to the kernel.
func WithTestDial(f nltest.Func) ConnOption {
	return ConnOption(func(conn *Conn) { conn.TestDial = f })
}
// WithSockOptions registers socket options that are applied to every new
// netlink connection. Options from repeated uses of WithSockOptions
// accumulate.
//
// Note that once any socket option is registered, the nftables package no
// longer enlarges the read and write buffers automatically: the caller is
// responsible for providing large-enough buffers.
func WithSockOptions(opts ...SockOption) ConnOption {
	return ConnOption(func(conn *Conn) {
		conn.sockOptions = append(conn.sockOptions, opts...)
	})
}
// netlinkCloser is returned by netlinkConn and netlinkConnUnderLock and must be
// called once the caller is done with the returned netlink connection. It
// closes the connection when it was dialed just for this caller and does
// nothing when the connection is the lasting one.
type netlinkCloser func() error
// netlinkConn hands out a netlink connection together with the netlinkCloser
// the caller must invoke once it is done with the connection. For a Conn
// created with AsLasting this is the lasting connection; otherwise a new
// "transient" connection is dialed, and the closer closes it.
//
// netlinkConn acquires Conn.mu itself, so it must not be called while that
// lock is already held (this would deadlock). Use netlinkConnUnderLock in
// such situations.
func (cc *Conn) netlinkConn() (*netlink.Conn, netlinkCloser, error) {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	conn, closer, err := cc.netlinkConnUnderLock()
	return conn, closer, err
}
// netlinkConnUnderLock is the variant of netlinkConn for callers that already
// hold the Conn.mu lock.
func (cc *Conn) netlinkConnUnderLock() (*netlink.Conn, netlinkCloser, error) {
	if lasting := cc.nlconn; lasting != nil {
		// The lasting connection outlives this caller: closing is a no-op.
		keepOpen := func() error { return nil }
		return lasting, keepOpen, nil
	}

	transient, err := cc.dialNetlink()
	if err != nil {
		return nil, nil, err
	}
	return transient, transient.Close, nil
}
// receiveSeq returns an iterator over the nftables messages that can be read
// from conn; messages of other netlink subsystems are dropped. Iteration ends
// when no more data is ready to be read, when checking for readiness fails, or
// when the consumer stops it. Errors returned by conn.Receive are yielded
// together with a zero-value message and do not end the iteration.
func (cc *Conn) receiveSeq(conn *netlink.Conn) iter.Seq2[netlink.Message, error] {
	return func(yield func(netlink.Message, error) bool) {
		var none netlink.Message // placeholder yielded alongside errors

		if conn == nil {
			yield(none, errors.New("netlink conn is not initialized"))
			return
		}

		for {
			pending, err := cc.isReadReady(conn)
			if err != nil {
				// Not being able to poll the socket is fatal.
				yield(none, err)
				return
			}
			// SendMessages blocks and netlink communication is synchronous, so
			// by the time SendMessages has returned the kernel has processed
			// the request and queued every response. Nothing being ready,
			// even on the very first check, therefore means nothing is coming
			// and it is safe to stop.
			if !pending {
				return
			}

			replies, err := conn.Receive()
			if err != nil {
				// Report the error, then keep draining unless told to stop.
				if yield(none, err) {
					continue
				}
				return
			}

			// A test dial function does not always produce a reply for each
			// sent message, and it has no buffer that could be polled for more
			// data, so an empty read ends the iteration.
			if cc.TestDial != nil && len(replies) == 0 {
				return
			}

			for i := range replies {
				reply := replies[i]
				// Skip non-nftables messages. In practice these would only be
				// netlink.Error messages, which the netlink library handles
				// itself and reports as errors from conn.Receive().
				if reply.Header.Type>>8 != unix.NFNL_SUBSYS_NFTABLES {
					continue
				}
				if !yield(reply, nil) {
					return
				}
			}
		}
	}
}
// receive drains the receive buffer of the provided netlink connection and
// returns all received nftables messages, along with the first error
// encountered, if any. Later errors are dropped; the messages that accompany
// errors are zero values and are never included in the result.
func (cc *Conn) receive(conn *netlink.Conn) ([]netlink.Message, error) {
	var (
		allReplies []netlink.Message
		firstErr   error
	)
	for msg, err := range cc.receiveSeq(conn) {
		if err != nil {
			// Every error is paired with a zero-value message, so skip the
			// append for all of them, not only for the first one.
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		allReplies = append(allReplies, msg)
	}
	return allReplies, firstErr
}
// CloseLasting closes the lasting netlink connection that was opened because
// the AsLasting option was passed to New. When there is no lasting connection,
// or it is already being closed or has been closed, CloseLasting returns nil
// right away.
//
// CloseLasting will terminate all pending netlink operations using the lasting
// connection.
//
// Once the lasting connection is closed, further netlink operations (such as
// GetTables) fall back to on-demand transient netlink connections.
func (cc *Conn) CloseLasting() error {
	// The lock is held only long enough to detach the socket from the Conn,
	// not for the duration of the close itself. Detaching guarantees that
	// only the first caller sees a non-nil socket and closes it. Since New()
	// is the only constructor and there is no Open() method, a lasting
	// connection cannot be reopened afterwards.
	cc.mu.Lock()
	detached := cc.nlconn
	cc.nlconn = nil
	cc.mu.Unlock()

	if detached == nil {
		return nil
	}
	return detached.Close()
}
// Flush sends all buffered commands to nftables as one batch, without
// attaching a generation ID.
func (cc *Conn) Flush() error {
	const noGenID uint32 = 0
	return cc.flush(noGenID)
}
// FlushWithGenID sends all buffered commands to nftables as one batch that
// carries the provided gen ID. If the ruleset has changed since the gen ID
// was retrieved, an ERESTART error will be returned. A genID of 0 is not
// attached to the batch, which makes the call equivalent to Flush.
func (cc *Conn) FlushWithGenID(genID uint32) error {
	err := cc.flush(genID)
	return err
}
// flush sends all buffered commands to nftables as one batch and processes
// the replies. A non-zero genID is included in the batch messages.
//
// The buffered commands, the recorded marshalling error and the per-transaction
// ID counter are discarded when flush returns, whether or not it succeeded.
// With nothing buffered, flush returns nil without opening a connection.
// Otherwise an error recorded while buffering is returned before anything is
// sent. After sending, all replies are drained and the first error seen, be it
// a receive error or an echo-handling error, is returned.
func (cc *Conn) flush(genID uint32) error {
	cc.mu.Lock()
	defer func() {
		// Start the next transaction from a clean slate.
		cc.err = nil
		cc.allocatedIDs = 0
		cc.messages = nil
		cc.mu.Unlock()
	}()

	switch {
	case len(cc.messages) == 0:
		// Messages were already programmed, returning nil
		return nil
	case cc.err != nil:
		return cc.err
	}

	nl, release, err := cc.netlinkConnUnderLock()
	if err != nil {
		return err
	}
	// Best effort: a failure to close a transient connection is not reported.
	defer func() { _ = release() }()

	if err = cc.enlargeWriteBuffer(nl); err != nil {
		return err
	}
	if err = cc.enlargeReadBuffer(nl); err != nil {
		return err
	}

	request, err := batch(cc.messages, genID)
	if err != nil {
		return err
	}
	sent, err := nl.SendMessages(request)
	if err != nil {
		return fmt.Errorf("SendMessages: %w", err)
	}

	bySeq := cc.getSeqToMsgMap(sent)

	var result error
	for reply, recvErr := range cc.receiveSeq(nl) {
		if recvErr != nil {
			// Keep only the first error, but continue receiving so that the
			// remaining replies are still drained.
			if result == nil {
				result = cc.handleReceiveError(bySeq, recvErr)
			}
			continue
		}
		echoErr := cc.handleEchoReply(bySeq, reply)
		if result == nil && echoErr != nil {
			result = echoErr
		}
	}
	return result
}
// withOpError passes err to fn when its chain contains a *netlink.OpError and
// returns fn's result. A nil err yields nil; any other error is returned
// unchanged.
func (cc *Conn) withOpError(err error, fn func(*netlink.OpError) error) error {
	var target *netlink.OpError
	switch {
	case err == nil:
		return nil
	case errors.As(err, &target):
		return fn(target)
	default:
		return err
	}
}
// handleReceiveError adds context to an error that was received while
// flushing. When err is a *netlink.OpError whose sequence number matches a
// sent message in msgs and that message's type can be parsed, the error is
// prefixed with the string form of the nftMsgType. Every non-nil error is
// additionally wrapped with a "receive: " prefix. A nil err yields nil.
func (cc *Conn) handleReceiveError(msgs map[uint32]netlinkMessage, err error) error {
	annotate := func(opErr *netlink.OpError) error {
		sent, found := msgs[opErr.Sequence]
		if !found {
			return opErr
		}
		msgType, typeErr := parseNftMsgType(sent.Header.Type)
		if typeErr != nil {
			// Unknown message type: report the plain operation error.
			return opErr
		}
		return fmt.Errorf("%s: %w", msgType.String(), opErr)
	}

	wrapped := cc.withOpError(err, annotate)
	if wrapped == nil {
		return nil
	}
	return fmt.Errorf("receive: %w", wrapped)
}
// getSeqToMsgMap indexes the sent cc.messages by the sequence number they
// carry in sentMsgs. The first and last entries of sentMsgs are taken to be
// the batch begin and end messages and are left out of the map.
//
// As a side effect, the header of each entry in cc.messages is replaced with
// the header that was actually sent.
func (cc *Conn) getSeqToMsgMap(sentMsgs []netlink.Message) map[uint32]netlinkMessage {
	bySeq := make(map[uint32]netlinkMessage)
	// Walk only the messages between batch begin and batch end.
	for sentIdx := 1; sentIdx < len(sentMsgs)-1; sentIdx++ {
		bufIdx := sentIdx - 1
		if bufIdx >= len(cc.messages) {
			// Should not happen, but be defensive.
			break
		}
		// The underlying netlink library has filled in the sequence number,
		// and possibly other header fields, so adopt the sent header.
		sentHeader := sentMsgs[sentIdx].Header
		cc.messages[bufIdx].Header = sentHeader
		bySeq[sentHeader.Sequence] = cc.messages[bufIdx]
	}
	return bySeq
}
// handleEchoReply processes one reply received during a flush. Replies whose
// sequence number does not match a sent message, replies to messages sent
// without the Echo flag, and replies of any type other than
// newRuleHeaderType are ignored. A new-rule reply is handed to the rule that
// was attached to the sent message.
func (cc *Conn) handleEchoReply(seqToMsgMap map[uint32]netlinkMessage, reply netlink.Message) error {
	sentMsg, known := seqToMsgMap[reply.Header.Sequence]
	if !known {
		// We don't have a record of sending this message, ignore.
		return nil
	}
	if sentMsg.Header.Flags&netlink.Echo == 0 {
		return nil
	}
	if reply.Header.Type != newRuleHeaderType {
		return nil
	}
	// The only messages which set the echo flag are rule create messages.
	return sentMsg.rule.handleCreateReply(reply)
}
// FlushRuleset buffers a command that flushes the entire ruleset; like every
// other command it only takes effect on the next Flush. See also
// https://wiki.nftables.org/wiki-nftables/index.php/Operations_at_ruleset_level
func (cc *Conn) FlushRuleset() {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	header := netlink.Header{
		Type:  netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE),
		Flags: netlink.Request | netlink.Create,
	}
	deleteTables := netlinkMessage{
		Header: header,
		Data:   extraHeader(0, 0),
	}
	cc.messages = append(cc.messages, deleteTables)
}
// dialNetlink opens a new netlink connection and applies the configured
// socket options to it, in order. When TestDial is set the connection is
// created through nltest.Dial; otherwise a NETLINK_NETFILTER socket is dialed
// in the network namespace referenced by NetNS.
//
// If dialing fails, or if any socket option returns an error, that error is
// returned and no connection is handed out. In the latter case the freshly
// dialed connection is closed first so that its socket is not leaked.
func (cc *Conn) dialNetlink() (*netlink.Conn, error) {
	var (
		conn *netlink.Conn
		err  error
	)
	if cc.TestDial != nil {
		conn = nltest.Dial(cc.TestDial)
	} else {
		conn, err = netlink.Dial(unix.NETLINK_NETFILTER, &netlink.Config{NetNS: cc.NetNS})
	}
	if err != nil {
		return nil, err
	}
	for _, opt := range cc.sockOptions {
		if err := opt(conn); err != nil {
			// The caller never sees conn, so it has to be released here. The
			// close error is deliberately dropped in favor of the option
			// error, which is the actual cause of the failure.
			_ = conn.Close()
			return nil, err
		}
	}
	return conn, nil
}
// setErr records err on the Conn unless an earlier error is already recorded,
// so that only the first error is kept. It does not acquire cc.mu itself.
func (cc *Conn) setErr(err error) {
	if cc.err == nil {
		cc.err = err
	}
}
// marshalAttr encodes attrs as netlink attributes. On failure the error is
// recorded on the Conn via setErr and nil is returned.
func (cc *Conn) marshalAttr(attrs []netlink.Attribute) []byte {
	encoded, err := netlink.MarshalAttributes(attrs)
	if err == nil {
		return encoded
	}
	cc.setErr(err)
	return nil
}
// marshalExpr encodes the expression e for the given family. On failure the
// error is recorded on the Conn via setErr and nil is returned.
func (cc *Conn) marshalExpr(fam byte, e expr.Any) []byte {
	encoded, err := expr.Marshal(fam, e)
	if err == nil {
		return encoded
	}
	cc.setErr(err)
	return nil
}
// batch converts messages into netlink messages and frames them with a batch
// begin message in front and a batch end message at the back. When genID is
// non-zero it is attached, as an NFNL_BATCH_GENID attribute, to both framing
// messages, which share the same payload.
func batch(messages []netlinkMessage, genID uint32) ([]netlink.Message, error) {
	payload := extraHeader(0, unix.NFNL_SUBSYS_NFTABLES)
	if genID != 0 {
		genAttr, err := netlink.MarshalAttributes([]netlink.Attribute{
			{Type: unix.NFNL_BATCH_GENID, Data: binaryutil.BigEndian.PutUint32(genID)},
		})
		if err != nil {
			return nil, err
		}
		payload = append(payload, genAttr...)
	}

	// frame builds the begin or end marker around the shared payload.
	frame := func(kind netlink.HeaderType) netlink.Message {
		return netlink.Message{
			Header: netlink.Header{Type: kind, Flags: netlink.Request},
			Data:   payload,
		}
	}

	framed := make([]netlink.Message, 0, len(messages)+2)
	framed = append(framed, frame(netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN)))
	for _, m := range messages {
		framed = append(framed, netlink.Message{Header: m.Header, Data: m.Data})
	}
	framed = append(framed, frame(netlink.HeaderType(unix.NFNL_MSG_BATCH_END)))
	return framed, nil
}
// allocateTransactionID allocates an identifier which is only valid in the
// current transaction. The returned ID is never 0. It panics when
// math.MaxUint32 IDs have already been allocated in the current transaction.
func (cc *Conn) allocateTransactionID() uint32 {
	if cc.allocatedIDs == math.MaxUint32 {
		// The explicit uint32 conversion keeps this compiling on platforms
		// with a 32-bit int, where the untyped constant would otherwise
		// default to int and overflow.
		panic(fmt.Sprintf("trying to allocate more than %d IDs in a single nftables transaction", uint32(math.MaxUint32)))
	}
	// To make it more likely to catch when a transaction ID is erroneously used
	// in a later transaction, cc.lastID is not reset after each transaction;
	// instead it is only reset once it rolls over from math.MaxUint32 to 0.
	cc.allocatedIDs++
	cc.lastID++
	if cc.lastID == 0 {
		// Skip 0 after a rollover so that 0 is never handed out.
		cc.lastID = 1
	}
	return cc.lastID
}
// getMessageSize returns the total size of all buffered messages: the payload
// length of each message plus one netlink header (unix.NLMSG_HDRLEN) per
// message.
func (cc *Conn) getMessageSize() int {
	size := len(cc.messages) * unix.NLMSG_HDRLEN
	for i := range cc.messages {
		size += len(cc.messages[i].Data)
	}
	return size
}
// canEnlargeBuffers reports whether the connection may automatically enlarge
// the write and read buffers of the netlink connection. This is not the case
// when socket options are configured, because the user is then assumed to
// have set the buffers to a fixed size, nor when a test dial function is in
// use.
func (cc *Conn) canEnlargeBuffers() bool {
	return len(cc.sockOptions) == 0 && cc.TestDial == nil
}
// enlargeWriteBuffer grows the write buffer of conn to the accumulated size
// of the buffered messages when the current write buffer is smaller than
// that. It does nothing when buffers must not be enlarged automatically (see
// canEnlargeBuffers).
//
// nftables actually handles this differently, it multiplies the number of
// iovec entries by 2MB. This is not possible to do here as our underlying
// netlink and socket libraries will only add a single iovec entry and
// won't expose the number of entries.
// https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n262
//
// TODO: Update this function to mimic the behavior of nftables once our
// socket library supports multiple iovec entries.
func (cc *Conn) enlargeWriteBuffer(conn *netlink.Conn) error {
	if !cc.canEnlargeBuffers() {
		return nil
	}

	needed := cc.getMessageSize()
	current, err := conn.WriteBuffer()
	switch {
	case err != nil:
		return err
	case current >= needed:
		// Already large enough.
		return nil
	}
	return conn.SetWriteBuffer(needed)
}
// getDefaultEchoReadBuffer returns the minimum read buffer size for batches
// with echo messages: 1024 times the larger of the system page size and 8192.
//
// See https://git.netfilter.org/libmnl/tree/include/libmnl/libmnl.h?id=03da98bcd284d55212bc79e91dfb63da0ef7b937#n20
// and https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n391
func (cc *Conn) getDefaultEchoReadBuffer() int {
	unit := os.Getpagesize()
	if unit < 8192 {
		unit = 8192
	}
	return unit * 1024
}
// enlargeReadBuffer automatically sets the read buffer of the given connection
// to the required size. This is only done if the current read buffer is smaller
// than the required size, and never when buffers must not be enlarged
// automatically (see canEnlargeBuffers).
//
// The required size is 1024 bytes per buffered message, raised to the default
// echo read buffer size when at least one buffered message has the Echo flag
// set.
//
// See https://git.netfilter.org/nftables/tree/src/mnl.c?id=713592c6008a8c589a00d3d3d2e49709ff2de62c#n426
func (cc *Conn) enlargeReadBuffer(conn *netlink.Conn) error {
	if !cc.canEnlargeBuffers() {
		return nil
	}

	var bufferSize int
	// If there are any messages with the Echo flag, we initialize the buffer size
	// to the default echo read buffer size.
	for _, msg := range cc.messages {
		if msg.Header.Flags&netlink.Echo != 0 {
			bufferSize = cc.getDefaultEchoReadBuffer()
			break
		}
	}

	// Just like nftables, we allocate 1024 bytes for each message in the batch.
	bufferSize = max(bufferSize, len(cc.messages)*1024)

	currSize, err := conn.ReadBuffer()
	if err != nil {
		return err
	}
	if currSize < bufferSize {
		return conn.SetReadBuffer(bufferSize)
	}
	return nil
}
// getPortIDUnderLock returns the netlink port ID of the connection handed out
// by netlinkConnUnderLock. It must be called while holding the Conn.mu lock.
//
// Without a lasting connection the ID belongs to a transient connection that
// is closed again before this method returns.
func (cc *Conn) getPortIDUnderLock() (uint32, error) {
	nl, release, err := cc.netlinkConnUnderLock()
	if err != nil {
		return 0, err
	}
	// Best effort: a failure to close a transient connection is not reported.
	defer func() { _ = release() }()

	return nl.PID(), nil
}
// GetPortID returns the netlink port ID associated with this connection. It
// acquires the Conn.mu lock for the duration of the call.
func (cc *Conn) GetPortID() (uint32, error) {
	cc.mu.Lock()
	defer cc.mu.Unlock()

	return cc.getPortIDUnderLock()
}