1
+ use std:: collections:: BTreeMap ;
2
+ use std:: hash:: Hash ;
1
3
use std:: net:: { Ipv4Addr , Ipv6Addr } ;
2
4
use std:: num:: NonZeroUsize ;
3
5
use std:: sync:: { Arc , Mutex , Weak } ;
4
6
5
7
use futures:: future:: poll_fn;
6
8
use lru:: LruCache ;
7
- use tokio:: sync:: Semaphore ;
9
+ use serde:: { Deserialize , Serialize } ;
10
+ use tokio:: sync:: { Notify , Semaphore } ;
8
11
use trust_dns_resolver:: proto:: op:: { Message as DnsMessage , MessageType , ResponseCode } ;
9
12
use trust_dns_resolver:: proto:: rr:: { RData , Record , RecordType } ;
10
13
use trust_dns_resolver:: proto:: serialize:: binary:: BinDecodable ;
11
14
15
+ use crate :: data:: PluginCache ;
12
16
use crate :: flow:: * ;
13
17
14
- const CACHE_CAPAICTY : NonZeroUsize = unsafe { NonZeroUsize :: new_unchecked ( 1024 ) } ;
18
+ const CACHE_CAPACITY : NonZeroUsize = NonZeroUsize :: new ( 1024 ) . unwrap ( ) ;
19
+ const REVERSE_MAPPING_V4_CACHE_KEY : & str = "rev_v4" ;
20
+ const REVERSE_MAPPING_V6_CACHE_KEY : & str = "rev_v6" ;
15
21
16
- pub struct DnsDatagramHandler {
22
+ pub struct DnsServer {
17
23
concurrency_limit : Arc < Semaphore > ,
18
24
resolver : Weak < dyn Resolver > ,
19
25
ttl : u32 ,
20
26
pub ( super ) reverse_mapping_v4 : Arc < Mutex < LruCache < Ipv4Addr , String > > > ,
21
27
pub ( super ) reverse_mapping_v6 : Arc < Mutex < LruCache < Ipv6Addr , String > > > ,
28
+ plugin_cache : PluginCache ,
29
+ pub ( super ) new_notify : Arc < Notify > ,
22
30
}
23
31
24
- impl DnsDatagramHandler {
25
- pub fn new ( concurrency_limit : usize , resolver : Weak < dyn Resolver > , ttl : u32 ) -> Self {
32
+ #[ derive( Debug , Clone , Default , PartialOrd , Ord , PartialEq , Eq , Serialize , Deserialize ) ]
33
+ #[ serde( transparent) ]
34
+ struct ReverseMappingCache < T : Ord > ( BTreeMap < T , String > ) ;
35
+
36
+ impl DnsServer {
37
+ pub fn new (
38
+ concurrency_limit : usize ,
39
+ resolver : Weak < dyn Resolver > ,
40
+ ttl : u32 ,
41
+ plugin_cache : PluginCache ,
42
+ ) -> Self {
26
43
let concurrency_limit = Arc :: new ( Semaphore :: new ( concurrency_limit) ) ;
27
- DnsDatagramHandler {
44
+ let mut reverse_mapping_v4 = LruCache :: new ( CACHE_CAPACITY ) ;
45
+ let mut reverse_mapping_v6 = LruCache :: new ( CACHE_CAPACITY ) ;
46
+ if let Some ( reverse_mapping_v4_cache) = plugin_cache
47
+ . get :: < ReverseMappingCache < _ > > ( REVERSE_MAPPING_V4_CACHE_KEY )
48
+ . ok ( )
49
+ . flatten ( )
50
+ {
51
+ for ( k, v) in reverse_mapping_v4_cache. 0 {
52
+ reverse_mapping_v4. put ( k, v) ;
53
+ }
54
+ }
55
+ if let Some ( reverse_mapping_v6_cache) = plugin_cache
56
+ . get :: < ReverseMappingCache < _ > > ( REVERSE_MAPPING_V6_CACHE_KEY )
57
+ . ok ( )
58
+ . flatten ( )
59
+ {
60
+ for ( k, v) in reverse_mapping_v6_cache. 0 {
61
+ reverse_mapping_v6. put ( k, v) ;
62
+ }
63
+ }
64
+ DnsServer {
28
65
concurrency_limit,
29
66
resolver,
30
67
ttl,
31
- reverse_mapping_v4 : Arc :: new ( Mutex :: new ( LruCache :: new ( CACHE_CAPAICTY ) ) ) ,
32
- reverse_mapping_v6 : Arc :: new ( Mutex :: new ( LruCache :: new ( CACHE_CAPAICTY ) ) ) ,
68
+ reverse_mapping_v4 : Arc :: new ( Mutex :: new ( reverse_mapping_v4) ) ,
69
+ reverse_mapping_v6 : Arc :: new ( Mutex :: new ( reverse_mapping_v6) ) ,
70
+ plugin_cache,
71
+ new_notify : Arc :: new ( Notify :: new ( ) ) ,
33
72
}
34
73
}
74
+
75
+ fn save_reverse_mapping_cache < T : Serialize + Hash + Eq + Ord + Clone > (
76
+ & self ,
77
+ cache : & Mutex < LruCache < T , String > > ,
78
+ key : & str ,
79
+ ) {
80
+ let cache = {
81
+ let inner = cache. lock ( ) . unwrap ( ) ;
82
+ ReverseMappingCache (
83
+ ( & * inner)
84
+ . iter ( )
85
+ . map ( |( k, v) | ( k. clone ( ) , v. clone ( ) ) )
86
+ . collect ( ) ,
87
+ )
88
+ } ;
89
+ self . plugin_cache . set ( key, & cache) . ok ( ) ;
90
+ }
91
+ pub ( crate ) fn save_cache ( & self ) {
92
+ self . save_reverse_mapping_cache ( & self . reverse_mapping_v4 , REVERSE_MAPPING_V4_CACHE_KEY ) ;
93
+ self . save_reverse_mapping_cache ( & self . reverse_mapping_v6 , REVERSE_MAPPING_V6_CACHE_KEY ) ;
94
+ }
35
95
}
36
96
37
- impl DatagramSessionHandler for DnsDatagramHandler {
97
+ impl DatagramSessionHandler for DnsServer {
38
98
fn on_session ( & self , mut session : Box < dyn DatagramSession > , _context : Box < FlowContext > ) {
39
99
let resolver = match self . resolver . upgrade ( ) {
40
100
Some ( resolver) => resolver,
@@ -44,6 +104,7 @@ impl DatagramSessionHandler for DnsDatagramHandler {
44
104
let ttl = self . ttl ;
45
105
let reverse_mapping_v4 = self . reverse_mapping_v4 . clone ( ) ;
46
106
let reverse_mapping_v6 = self . reverse_mapping_v6 . clone ( ) ;
107
+ let new_notify = self . new_notify . clone ( ) ;
47
108
tokio:: spawn ( async move {
48
109
let mut send_ready = true ;
49
110
while let Some ( ( dest, buf) ) = poll_fn ( |cx| {
@@ -65,18 +126,25 @@ impl DatagramSessionHandler for DnsDatagramHandler {
65
126
} ;
66
127
let mut res_code = ResponseCode :: NoError ;
67
128
let mut ans_records = Vec :: with_capacity ( msg. queries ( ) . len ( ) ) ;
129
+ let mut notify_cache_update = false ;
68
130
for query in msg. queries ( ) {
69
131
let name = query. name ( ) ;
70
132
let name_str = name. to_lowercase ( ) . to_ascii ( ) ;
71
- #[ allow( unreachable_code) ]
72
133
match query. query_type ( ) {
73
134
RecordType :: A => {
74
135
let ips = match resolver. resolve_ipv4 ( name_str. clone ( ) ) . await {
75
136
Ok ( addrs) => addrs,
76
- Err ( _) => ( res_code = ResponseCode :: NXDomain , continue ) . 1 ,
137
+ Err ( _) => {
138
+ res_code = ResponseCode :: NXDomain ;
139
+ continue ;
140
+ }
77
141
} ;
78
142
let mut reverse_mapping = reverse_mapping_v4. lock ( ) . unwrap ( ) ;
79
143
for ip in & ips {
144
+ notify_cache_update |= reverse_mapping
145
+ . peek_mut ( ip)
146
+ . filter ( |n| * n == & name_str)
147
+ . is_none ( ) ;
80
148
reverse_mapping. get_or_insert ( * ip, || name_str. clone ( ) ) ;
81
149
}
82
150
ans_records. extend (
@@ -88,20 +156,34 @@ impl DatagramSessionHandler for DnsDatagramHandler {
88
156
RecordType :: AAAA => {
89
157
let ips = match resolver. resolve_ipv6 ( name_str. clone ( ) ) . await {
90
158
Ok ( addrs) => addrs,
91
- Err ( _) => ( res_code = ResponseCode :: NXDomain , continue ) . 1 ,
159
+ Err ( _) => {
160
+ res_code = ResponseCode :: NXDomain ;
161
+ continue ;
162
+ }
92
163
} ;
93
164
let mut reverse_mapping = reverse_mapping_v6. lock ( ) . unwrap ( ) ;
94
165
for ip in & ips {
166
+ notify_cache_update |= reverse_mapping
167
+ . peek_mut ( ip)
168
+ . filter ( |n| * n == & name_str)
169
+ . is_none ( ) ;
95
170
reverse_mapping. get_or_insert ( * ip, || name_str. clone ( ) ) ;
96
171
}
97
172
ans_records. extend ( ips. into_iter ( ) . map ( |addr| {
98
173
Record :: from_rdata ( name. clone ( ) , ttl, RData :: AAAA ( addr) )
99
174
} ) )
100
175
}
101
176
// TODO: SRV
102
- _ => ( res_code = ResponseCode :: NotImp , continue ) . 1 ,
177
+ _ => {
178
+ res_code = ResponseCode :: NotImp ;
179
+ continue ;
180
+ }
103
181
}
104
182
}
183
+ if notify_cache_update {
184
+ new_notify. notify_one ( ) ;
185
+ }
186
+
105
187
* msg. set_message_type ( MessageType :: Response )
106
188
. set_response_code ( res_code)
107
189
. answers_mut ( ) = ans_records;
0 commit comments