@@ -9,10 +9,8 @@ import (
99 "time"
1010
1111 "github.com/sagernet/sing-box/adapter"
12- "github.com/sagernet/sing-box/common/dialer"
1312 C "github.com/sagernet/sing-box/constant"
1413 "github.com/sagernet/sing-box/dns"
15- "github.com/sagernet/sing-box/dns/transport"
1614 "github.com/sagernet/sing-box/log"
1715 "github.com/sagernet/sing-box/option"
1816 "github.com/sagernet/sing-tun"
@@ -29,6 +27,7 @@ import (
2927
3028 "github.com/insomniacslk/dhcp/dhcpv4"
3129 mDNS "github.com/miekg/dns"
30+ "golang.org/x/exp/slices"
3231)
3332
3433func RegisterTransport (registry * dns.TransportRegistry ) {
@@ -45,9 +44,12 @@ type Transport struct {
4544 networkManager adapter.NetworkManager
4645 interfaceName string
4746 interfaceCallback * list.Element [tun.DefaultInterfaceUpdateCallback ]
48- transports []adapter.DNSTransport
49- updateAccess sync.Mutex
47+ transportLock sync.RWMutex
5048 updatedAt time.Time
49+ servers []M.Socksaddr
50+ search []string
51+ ndots int
52+ attempts int
5153}
5254
5355func NewTransport (ctx context.Context , logger log.ContextLogger , tag string , options option.DHCPDNSServerOptions ) (adapter.DNSTransport , error ) {
@@ -62,16 +64,28 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt
6264 logger : logger ,
6365 networkManager : service.FromContext [adapter.NetworkManager ](ctx ),
6466 interfaceName : options .Interface ,
67+ ndots : 1 ,
68+ attempts : 2 ,
6569 }, nil
6670}
6771
72+ func NewRawTransport (transportAdapter dns.TransportAdapter , ctx context.Context , dialer N.Dialer , logger log.ContextLogger ) * Transport {
73+ return & Transport {
74+ TransportAdapter : transportAdapter ,
75+ ctx : ctx ,
76+ dialer : dialer ,
77+ logger : logger ,
78+ networkManager : service.FromContext [adapter.NetworkManager ](ctx ),
79+ }
80+ }
81+
6882func (t * Transport ) Start (stage adapter.StartStage ) error {
6983 if stage != adapter .StartStateStart {
7084 return nil
7185 }
72- err := t .fetchServers ()
86+ _ , err := t .Fetch ()
7387 if err != nil {
74- return err
88+ t . logger . Error ( E . Cause ( err , "fetch DNS servers" ))
7589 }
7690 if t .interfaceName == "" {
7791 t .interfaceCallback = t .networkManager .InterfaceMonitor ().RegisterCallback (t .interfaceUpdated )
@@ -80,33 +94,51 @@ func (t *Transport) Start(stage adapter.StartStage) error {
8094}
8195
8296func (t * Transport ) Close () error {
83- for _ , transport := range t .transports {
84- transport .Close ()
85- }
8697 if t .interfaceCallback != nil {
8798 t .networkManager .InterfaceMonitor ().UnregisterCallback (t .interfaceCallback )
8899 }
89100 return nil
90101}
91102
92103func (t * Transport ) Exchange (ctx context.Context , message * mDNS.Msg ) (* mDNS.Msg , error ) {
93- err := t .fetchServers ()
104+ servers , err := t .Fetch ()
94105 if err != nil {
95106 return nil , err
96107 }
97-
98- if len (t .transports ) == 0 {
108+ if len (servers ) == 0 {
99109 return nil , E .New ("dhcp: empty DNS servers from response" )
100110 }
111+ return t .Exchange0 (ctx , message , servers )
112+ }
101113
102- var response * mDNS.Msg
103- for _ , transport := range t .transports {
104- response , err = transport .Exchange (ctx , message )
105- if err == nil {
106- return response , nil
107- }
114+ func (t * Transport ) Exchange0 (ctx context.Context , message * mDNS.Msg , servers []M.Socksaddr ) (* mDNS.Msg , error ) {
115+ question := message .Question [0 ]
116+ domain := dns .FqdnToDomain (question .Name )
117+ if len (servers ) == 1 || ! (message .Question [0 ].Qtype == mDNS .TypeA || message .Question [0 ].Qtype == mDNS .TypeAAAA ) {
118+ return t .exchangeSingleRequest (ctx , servers , message , domain )
119+ } else {
120+ return t .exchangeParallel (ctx , servers , message , domain )
108121 }
109- return nil , err
122+ }
123+
124+ func (t * Transport ) Fetch () ([]M.Socksaddr , error ) {
125+ t .transportLock .RLock ()
126+ updatedAt := t .updatedAt
127+ servers := t .servers
128+ t .transportLock .RUnlock ()
129+ if time .Since (updatedAt ) < C .DHCPTTL {
130+ return servers , nil
131+ }
132+ t .transportLock .Lock ()
133+ defer t .transportLock .Unlock ()
134+ if time .Since (t .updatedAt ) < C .DHCPTTL {
135+ return t .servers , nil
136+ }
137+ err := t .updateServers ()
138+ if err != nil {
139+ return nil , err
140+ }
141+ return t .servers , nil
110142}
111143
112144func (t * Transport ) fetchInterface () (* control.Interface , error ) {
@@ -124,18 +156,6 @@ func (t *Transport) fetchInterface() (*control.Interface, error) {
124156 }
125157}
126158
127- func (t * Transport ) fetchServers () error {
128- if time .Since (t .updatedAt ) < C .DHCPTTL {
129- return nil
130- }
131- t .updateAccess .Lock ()
132- defer t .updateAccess .Unlock ()
133- if time .Since (t .updatedAt ) < C .DHCPTTL {
134- return nil
135- }
136- return t .updateServers ()
137- }
138-
139159func (t * Transport ) updateServers () error {
140160 iface , err := t .fetchInterface ()
141161 if err != nil {
@@ -148,7 +168,7 @@ func (t *Transport) updateServers() error {
148168 cancel ()
149169 if err != nil {
150170 return err
151- } else if len (t .transports ) == 0 {
171+ } else if len (t .servers ) == 0 {
152172 return E .New ("dhcp: empty DNS servers response" )
153173 } else {
154174 t .updatedAt = time .Now ()
@@ -177,7 +197,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *control.Interface)
177197 }
178198 defer packetConn .Close ()
179199
180- discovery , err := dhcpv4 .NewDiscovery (iface .HardwareAddr , dhcpv4 .WithBroadcast (true ), dhcpv4 .WithRequestedOptions (dhcpv4 .OptionDomainNameServer ))
200+ discovery , err := dhcpv4 .NewDiscovery (iface .HardwareAddr , dhcpv4 .WithBroadcast (true ), dhcpv4 .WithRequestedOptions (dhcpv4 .OptionDomainNameServer , dhcpv4 . OptionDNSDomainSearchList ))
181201 if err != nil {
182202 return err
183203 }
@@ -223,31 +243,21 @@ func (t *Transport) fetchServersResponse(iface *control.Interface, packetConn ne
223243 continue
224244 }
225245
226- dns := dhcpPacket .DNS ()
227- if len (dns ) == 0 {
228- return nil
229- }
230- return t .recreateServers (iface , common .Map (dns , func (it net.IP ) M.Socksaddr {
231- return M .SocksaddrFrom (M .AddrFromIP (it ), 53 )
232- }))
246+ return t .recreateServers (iface , dhcpPacket )
233247 }
234248}
235249
236- func (t * Transport ) recreateServers (iface * control.Interface , serverAddrs []M.Socksaddr ) error {
237- if len (serverAddrs ) > 0 {
238- t .logger .Info ("dhcp: updated DNS servers from " , iface .Name , ": [" , strings .Join (common .Map (serverAddrs , M .Socksaddr .String ), "," ), "]" )
250+ func (t * Transport ) recreateServers (iface * control.Interface , dhcpPacket * dhcpv4.DHCPv4 ) error {
251+ searchList := dhcpPacket .DomainSearch ()
252+ if searchList != nil {
253+ t .search = searchList .Labels
239254 }
240- serverDialer := common .Must1 (dialer .NewDefault (t .ctx , option.DialerOptions {
241- BindInterface : iface .Name ,
242- UDPFragmentDefault : true ,
243- }))
244- var transports []adapter.DNSTransport
245- for _ , serverAddr := range serverAddrs {
246- transports = append (transports , transport .NewUDPRaw (t .logger , t .TransportAdapter , serverDialer , serverAddr ))
247- }
248- for _ , transport := range t .transports {
249- transport .Close ()
255+ serverAddrs := common .Map (dhcpPacket .DNS (), func (it net.IP ) M.Socksaddr {
256+ return M .SocksaddrFrom (M .AddrFromIP (it ), 53 )
257+ })
258+ if len (serverAddrs ) > 0 && ! slices .Equal (t .servers , serverAddrs ) {
259+ t .logger .Info ("dhcp: updated DNS servers from " , iface .Name , ": [" , strings .Join (common .Map (serverAddrs , M .Socksaddr .String ), "," ), "], search: [" , strings .Join (t .search , "," ), "]" )
250260 }
251- t .transports = transports
261+ t .servers = serverAddrs
252262 return nil
253263}
0 commit comments