Skip to content

Commit d30e7bc

Browse files
authored
GODRIVER-2045 Propose an example of custom dialer with DNS cache. (#1088)
Propose an example of custom dialer with DNS cache.
1 parent cb7150a commit d30e7bc

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

examples/example_customdns_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright (C) MongoDB, Inc. 2022-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package examples
8+
9+
import (
10+
"context"
11+
"log"
12+
"net"
13+
"sync"
14+
"testing"
15+
"time"
16+
17+
"github.com/miekg/dns"
18+
"go.mongodb.org/mongo-driver/bson"
19+
"go.mongodb.org/mongo-driver/mongo"
20+
"go.mongodb.org/mongo-driver/mongo/options"
21+
)
22+
23+
func resolve(ctx context.Context, cache *dnsCache, in *dns.Conn, out *dns.Conn) {
24+
for ctx.Err() == nil {
25+
q, err := in.ReadMsg()
26+
if err != nil {
27+
// TODO: Handle error.
28+
log.Fatalf("Unhandled error in ReadMsg: %v", err)
29+
}
30+
if len(q.Question) != 1 {
31+
// Multiple questions in a single query is not actually used in real life.
32+
continue
33+
}
34+
35+
a, err := func() (*dns.Msg, error) {
36+
cache.lock.Lock()
37+
defer cache.lock.Unlock()
38+
39+
now := time.Now()
40+
if rr, ok := cache.records[q.Question[0]]; ok && rr.exp.After(now) {
41+
a := new(dns.Msg)
42+
a.SetReply(q)
43+
a.Compress = false
44+
a.Answer = append(a.Answer, rr.record)
45+
return a, nil
46+
}
47+
48+
err := out.WriteMsg(q)
49+
if err != nil {
50+
return nil, err
51+
}
52+
53+
m, err := out.ReadMsg()
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
l := len(m.Answer)
59+
for i, q := range m.Question {
60+
if i >= l {
61+
break
62+
}
63+
a := m.Answer[i]
64+
cache.records[q] = &RR{
65+
a,
66+
now.Add(time.Second * time.Duration(a.Header().Ttl)),
67+
}
68+
}
69+
return m, nil
70+
}()
71+
if err != nil {
72+
// TODO: Handle error.
73+
log.Fatalf("Unhandled error in record retrieval: %v", err)
74+
}
75+
76+
if err := in.WriteMsg(a); err != nil {
77+
// TODO: Handle error.
78+
log.Fatalf("Unhandled error in WriteMsg: %v", err)
79+
}
80+
}
81+
}
82+
83+
type RR struct {
84+
record dns.RR
85+
exp time.Time
86+
}
87+
88+
type dnsCache struct {
89+
records map[dns.Question]*RR
90+
lock sync.Mutex
91+
}
92+
93+
type dialer struct {
94+
*net.Dialer
95+
cache *dnsCache
96+
}
97+
98+
func NewDialer() dialer {
99+
cache := &dnsCache{
100+
records: make(map[dns.Question]*RR),
101+
lock: sync.Mutex{},
102+
}
103+
return dialer{
104+
Dialer: &net.Dialer{
105+
Resolver: &net.Resolver{
106+
PreferGo: true,
107+
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
108+
var d net.Dialer
109+
outConn, err := d.DialContext(ctx, network, address)
110+
conn, inConn := net.Pipe()
111+
if err == nil {
112+
go resolve(ctx, cache, &dns.Conn{Conn: inConn}, &dns.Conn{Conn: outConn})
113+
}
114+
return conn, err
115+
},
116+
},
117+
},
118+
cache: cache,
119+
}
120+
}
121+
122+
func TestCustomDialer(t *testing.T) {
123+
client, err := mongo.NewClient(options.Client().ApplyURI("mongodb://testurl:27017").SetDialer(NewDialer()))
124+
if err != nil {
125+
t.Fatalf("error creating client: %v", err)
126+
}
127+
ctx := context.Background()
128+
err = client.Connect(ctx)
129+
if err != nil {
130+
t.Fatalf("error connecting: %v", err)
131+
}
132+
defer client.Disconnect(context.Background())
133+
coll := client.Database("test").Collection("test")
134+
_, err = coll.InsertOne(context.Background(), bson.D{{"text", "text"}})
135+
if err != nil {
136+
t.Fatalf("error inserting: %v", err)
137+
}
138+
}

0 commit comments

Comments
 (0)