@@ -5,6 +5,7 @@ use std::{
55 time:: Duration ,
66} ;
77
8+ use arc_swap:: ArcSwapOption ;
89use hickory_server:: ServerFuture ;
910use landscape_common:: {
1011 config:: DnsRuntimeConfig , dns:: ChainDnsServerInitInfo , event:: DnsMetricMessage ,
@@ -25,7 +26,7 @@ pub(crate) mod rule;
2526#[ derive( Clone ) ]
2627pub struct LandscapeDnsServer {
2728 pub status : WatchService ,
28- flow_dns_server : Arc < Mutex < HashMap < u32 , ( DnsRequestHandler , CancellationToken ) > > > ,
29+ flow_dns_server : Arc < Mutex < HashMap < u32 , Arc < FlowServerEntry > > > > ,
2930 pub addr : SocketAddr ,
3031 pub msg_tx : Option < mpsc:: Sender < DnsMetricMessage > > ,
3132 pub doh : Option < DohListenerConfig > ,
@@ -40,6 +41,25 @@ pub struct DohListenerConfig {
4041 pub http_endpoint : String ,
4142}
4243
44+ struct FlowServerRuntime {
45+ handler : DnsRequestHandler ,
46+ _token : CancellationToken ,
47+ }
48+
49+ struct FlowServerEntry {
50+ refresh_lock : Mutex < ( ) > ,
51+ runtime : Arc < ArcSwapOption < FlowServerRuntime > > ,
52+ }
53+
54+ impl FlowServerEntry {
55+ fn new ( ) -> Self {
56+ Self {
57+ refresh_lock : Mutex :: new ( ( ) ) ,
58+ runtime : Arc :: new ( ArcSwapOption :: new ( None ) ) ,
59+ }
60+ }
61+ }
62+
4363impl LandscapeDnsServer {
4464 pub fn new (
4565 listen_port : u16 ,
@@ -67,33 +87,35 @@ impl LandscapeDnsServer {
6787 info : ChainDnsServerInitInfo ,
6888 dns_config : DnsRuntimeConfig ,
6989 ) {
70- {
90+ let entry = {
7191 let mut lock = self . flow_dns_server . lock ( ) . await ;
72- if let Some ( ( old_handler, _) ) = lock. get_mut ( & flow_id) {
73- old_handler. renew_rules ( info, dns_config) . await ;
74- return ;
75- }
92+ lock. entry ( flow_id) . or_insert_with ( || Arc :: new ( FlowServerEntry :: new ( ) ) ) . clone ( )
93+ } ;
94+
95+ let _refresh_guard = entry. refresh_lock . lock ( ) . await ;
96+ if let Some ( runtime) = entry. runtime . load_full ( ) {
97+ runtime. handler . renew_rules ( info, dns_config) . await ;
98+ return ;
7699 }
77100
78101 let handler = DnsRequestHandler :: new ( info, dns_config, flow_id, self . msg_tx . clone ( ) ) ;
79102 let token = start_dns_server ( flow_id, self . addr , self . doh . clone ( ) , handler. clone ( ) ) . await ;
80-
81- {
82- let mut lock = self . flow_dns_server . lock ( ) . await ;
83- lock. insert ( flow_id, ( handler, token) ) ;
103+ if token. is_cancelled ( ) {
104+ tracing:: error!( "[flow: {flow_id}]: DNS server start failed, runtime not registered" ) ;
105+ return ;
84106 }
107+
108+ entry. runtime . store ( Some ( Arc :: new ( FlowServerRuntime { handler, _token : token } ) ) ) ;
85109 }
86110
87111 pub async fn check_domain ( & self , req : CheckDnsReq ) -> CheckChainDnsResult {
88- let handler = {
112+ let entry = {
89113 let flow_server = self . flow_dns_server . lock ( ) . await ;
90- if let Some ( ( handler, _) ) = flow_server. get ( & req. flow_id ) {
91- Some ( handler. clone ( ) )
92- } else {
93- None
94- }
114+ flow_server. get ( & req. flow_id ) . cloned ( )
95115 } ;
96116
117+ let handler = entry
118+ . and_then ( |entry| entry. runtime . load_full ( ) . map ( |runtime| runtime. handler . clone ( ) ) ) ;
97119 if let Some ( handler) = handler {
98120 handler. check_domain ( & req. get_domain ( ) , convert_record_type ( req. record_type ) ) . await
99121 } else {
@@ -153,3 +175,56 @@ pub async fn start_dns_server(
153175
154176 token
155177}
178+
179+ #[ cfg( test) ]
180+ mod tests {
181+ use super :: * ;
182+ use landscape_common:: { config:: DnsRuntimeConfig , dns:: ChainDnsServerInitInfo } ;
183+
184+ fn run_async_test ( test : impl std:: future:: Future < Output = ( ) > ) {
185+ tokio:: runtime:: Builder :: new_current_thread ( ) . enable_all ( ) . build ( ) . unwrap ( ) . block_on ( test) ;
186+ }
187+
188+ fn test_dns_config ( ) -> DnsRuntimeConfig {
189+ DnsRuntimeConfig {
190+ cache_capacity : 16 ,
191+ cache_ttl : 60 ,
192+ negative_cache_ttl : 10 ,
193+ doh_listen_port : 0 ,
194+ doh_http_endpoint : String :: new ( ) ,
195+ }
196+ }
197+
198+ #[ test]
199+ fn flow_server_entry_runtime_reads_do_not_wait_on_refresh_lock ( ) {
200+ run_async_test ( async {
201+ let entry = FlowServerEntry :: new ( ) ;
202+ let handler = DnsRequestHandler :: new (
203+ ChainDnsServerInitInfo :: default ( ) ,
204+ test_dns_config ( ) ,
205+ 7 ,
206+ None ,
207+ ) ;
208+ entry. runtime . store ( Some ( Arc :: new ( FlowServerRuntime {
209+ handler,
210+ _token : CancellationToken :: new ( ) ,
211+ } ) ) ) ;
212+
213+ let _guard = entry. refresh_lock . lock ( ) . await ;
214+ let runtime = entry. runtime . load_full ( ) ;
215+
216+ assert ! ( runtime. is_some( ) ) ;
217+ assert_eq ! ( runtime. unwrap( ) . handler. flow_id, 7 ) ;
218+ } ) ;
219+ }
220+
221+ #[ test]
222+ fn flow_server_entry_allows_empty_runtime_while_refreshing ( ) {
223+ run_async_test ( async {
224+ let entry = FlowServerEntry :: new ( ) ;
225+ let _guard = entry. refresh_lock . lock ( ) . await ;
226+
227+ assert ! ( entry. runtime. load_full( ) . is_none( ) ) ;
228+ } ) ;
229+ }
230+ }
0 commit comments