Skip to content

Commit 813f82d

Browse files
committed
cache dns-server reverse map
1 parent 68350d3 commit 813f82d

File tree

5 files changed

+164
-24
lines changed

5 files changed

+164
-24
lines changed

ytflow/src/config/plugin/dns_server.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use serde::Deserialize;
44

55
use crate::config::factory::*;
66
use crate::config::*;
7+
use crate::data::PluginId;
78

89
#[cfg_attr(not(feature = "plugins"), allow(dead_code))]
910
#[derive(Deserialize)]
@@ -16,12 +17,17 @@ pub struct DnsServerFactory<'a> {
1617
tcp_map_back: HashSet<&'a str>,
1718
#[serde(borrow)]
1819
udp_map_back: HashSet<&'a str>,
20+
#[serde(skip)]
21+
plugin_id: Option<PluginId>,
1922
}
2023

2124
impl<'de> DnsServerFactory<'de> {
2225
pub(in super::super) fn parse(plugin: &'de Plugin) -> ConfigResult<ParsedPlugin<'de, Self>> {
23-
let Plugin { name, param, .. } = plugin;
24-
let config: Self = parse_param(name, param)?;
26+
let Plugin {
27+
name, param, id, ..
28+
} = plugin;
29+
let mut config: Self = parse_param(name, param)?;
30+
config.plugin_id = *id;
2531
let resolver = config.resolver;
2632
Ok(ParsedPlugin {
2733
requires: [Descriptor {
@@ -61,10 +67,24 @@ impl<'de> DnsServerFactory<'de> {
6167
impl<'de> Factory for DnsServerFactory<'de> {
6268
#[cfg(feature = "plugins")]
6369
fn load(&mut self, plugin_name: String, set: &mut PartialPluginSet) -> LoadResult<()> {
70+
use crate::data::PluginCache;
6471
use crate::plugin::dns_server;
6572
use crate::plugin::null::Null;
6673
use crate::plugin::reject::RejectHandler;
6774

75+
let db = set
76+
.db
77+
.ok_or_else(|| LoadError::DatabaseRequired {
78+
plugin: plugin_name.clone(),
79+
})?
80+
.clone();
81+
let cache = PluginCache::new(
82+
self.plugin_id.ok_or_else(|| LoadError::DatabaseRequired {
83+
plugin: plugin_name.clone(),
84+
})?,
85+
Some(db.clone()),
86+
);
87+
6888
let mut err = None;
6989
let factory = Arc::new_cyclic(|weak| {
7090
set.datagram_handlers
@@ -75,7 +95,7 @@ impl<'de> Factory for DnsServerFactory<'de> {
7595
err = Some(e);
7696
Arc::downgrade(&(Arc::new(Null) as _))
7797
});
78-
dns_server::DnsDatagramHandler::new(self.concurrency_limit as usize, resolver, self.ttl)
98+
dns_server::DnsServer::new(self.concurrency_limit as usize, resolver, self.ttl, cache)
7999
});
80100
if let Some(e) = err {
81101
set.errors.push(e);
@@ -127,7 +147,10 @@ impl<'de> Factory for DnsServerFactory<'de> {
127147

128148
set.fully_constructed
129149
.datagram_handlers
130-
.insert(plugin_name + ".udp", factory);
150+
.insert(plugin_name + ".udp", factory.clone());
151+
set.fully_constructed
152+
.long_running_tasks
153+
.push(tokio::spawn(dns_server::cache_writer(factory)));
131154
Ok(())
132155
}
133156
}

ytflow/src/plugin/dns_server/datagram.rs

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,100 @@
1+
use std::collections::BTreeMap;
2+
use std::hash::Hash;
13
use std::net::{Ipv4Addr, Ipv6Addr};
24
use std::num::NonZeroUsize;
35
use std::sync::{Arc, Mutex, Weak};
46

57
use futures::future::poll_fn;
68
use lru::LruCache;
7-
use tokio::sync::Semaphore;
9+
use serde::{Deserialize, Serialize};
10+
use tokio::sync::{Notify, Semaphore};
811
use trust_dns_resolver::proto::op::{Message as DnsMessage, MessageType, ResponseCode};
912
use trust_dns_resolver::proto::rr::{RData, Record, RecordType};
1013
use trust_dns_resolver::proto::serialize::binary::BinDecodable;
1114

15+
use crate::data::PluginCache;
1216
use crate::flow::*;
1317

14-
const CACHE_CAPAICTY: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1024) };
18+
const CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(1024).unwrap();
19+
const REVERSE_MAPPING_V4_CACHE_KEY: &str = "rev_v4";
20+
const REVERSE_MAPPING_V6_CACHE_KEY: &str = "rev_v6";
1521

16-
pub struct DnsDatagramHandler {
22+
pub struct DnsServer {
1723
concurrency_limit: Arc<Semaphore>,
1824
resolver: Weak<dyn Resolver>,
1925
ttl: u32,
2026
pub(super) reverse_mapping_v4: Arc<Mutex<LruCache<Ipv4Addr, String>>>,
2127
pub(super) reverse_mapping_v6: Arc<Mutex<LruCache<Ipv6Addr, String>>>,
28+
plugin_cache: PluginCache,
29+
pub(super) new_notify: Arc<Notify>,
2230
}
2331

24-
impl DnsDatagramHandler {
25-
pub fn new(concurrency_limit: usize, resolver: Weak<dyn Resolver>, ttl: u32) -> Self {
32+
#[derive(Debug, Clone, Default, PartialOrd, Ord, PartialEq, Eq, Serialize, Deserialize)]
33+
#[serde(transparent)]
34+
struct ReverseMappingCache<T: Ord>(BTreeMap<T, String>);
35+
36+
impl DnsServer {
37+
pub fn new(
38+
concurrency_limit: usize,
39+
resolver: Weak<dyn Resolver>,
40+
ttl: u32,
41+
plugin_cache: PluginCache,
42+
) -> Self {
2643
let concurrency_limit = Arc::new(Semaphore::new(concurrency_limit));
27-
DnsDatagramHandler {
44+
let mut reverse_mapping_v4 = LruCache::new(CACHE_CAPACITY);
45+
let mut reverse_mapping_v6 = LruCache::new(CACHE_CAPACITY);
46+
if let Some(reverse_mapping_v4_cache) = plugin_cache
47+
.get::<ReverseMappingCache<_>>(REVERSE_MAPPING_V4_CACHE_KEY)
48+
.ok()
49+
.flatten()
50+
{
51+
for (k, v) in reverse_mapping_v4_cache.0 {
52+
reverse_mapping_v4.put(k, v);
53+
}
54+
}
55+
if let Some(reverse_mapping_v6_cache) = plugin_cache
56+
.get::<ReverseMappingCache<_>>(REVERSE_MAPPING_V6_CACHE_KEY)
57+
.ok()
58+
.flatten()
59+
{
60+
for (k, v) in reverse_mapping_v6_cache.0 {
61+
reverse_mapping_v6.put(k, v);
62+
}
63+
}
64+
DnsServer {
2865
concurrency_limit,
2966
resolver,
3067
ttl,
31-
reverse_mapping_v4: Arc::new(Mutex::new(LruCache::new(CACHE_CAPAICTY))),
32-
reverse_mapping_v6: Arc::new(Mutex::new(LruCache::new(CACHE_CAPAICTY))),
68+
reverse_mapping_v4: Arc::new(Mutex::new(reverse_mapping_v4)),
69+
reverse_mapping_v6: Arc::new(Mutex::new(reverse_mapping_v6)),
70+
plugin_cache,
71+
new_notify: Arc::new(Notify::new()),
3372
}
3473
}
74+
75+
fn save_reverse_mapping_cache<T: Serialize + Hash + Eq + Ord + Clone>(
76+
&self,
77+
cache: &Mutex<LruCache<T, String>>,
78+
key: &str,
79+
) {
80+
let cache = {
81+
let inner = cache.lock().unwrap();
82+
ReverseMappingCache(
83+
(&*inner)
84+
.iter()
85+
.map(|(k, v)| (k.clone(), v.clone()))
86+
.collect(),
87+
)
88+
};
89+
self.plugin_cache.set(key, &cache).ok();
90+
}
91+
pub(crate) fn save_cache(&self) {
92+
self.save_reverse_mapping_cache(&self.reverse_mapping_v4, REVERSE_MAPPING_V4_CACHE_KEY);
93+
self.save_reverse_mapping_cache(&self.reverse_mapping_v6, REVERSE_MAPPING_V6_CACHE_KEY);
94+
}
3595
}
3696

37-
impl DatagramSessionHandler for DnsDatagramHandler {
97+
impl DatagramSessionHandler for DnsServer {
3898
fn on_session(&self, mut session: Box<dyn DatagramSession>, _context: Box<FlowContext>) {
3999
let resolver = match self.resolver.upgrade() {
40100
Some(resolver) => resolver,
@@ -44,6 +104,7 @@ impl DatagramSessionHandler for DnsDatagramHandler {
44104
let ttl = self.ttl;
45105
let reverse_mapping_v4 = self.reverse_mapping_v4.clone();
46106
let reverse_mapping_v6 = self.reverse_mapping_v6.clone();
107+
let new_notify = self.new_notify.clone();
47108
tokio::spawn(async move {
48109
let mut send_ready = true;
49110
while let Some((dest, buf)) = poll_fn(|cx| {
@@ -65,18 +126,25 @@ impl DatagramSessionHandler for DnsDatagramHandler {
65126
};
66127
let mut res_code = ResponseCode::NoError;
67128
let mut ans_records = Vec::with_capacity(msg.queries().len());
129+
let mut notify_cache_update = false;
68130
for query in msg.queries() {
69131
let name = query.name();
70132
let name_str = name.to_lowercase().to_ascii();
71-
#[allow(unreachable_code)]
72133
match query.query_type() {
73134
RecordType::A => {
74135
let ips = match resolver.resolve_ipv4(name_str.clone()).await {
75136
Ok(addrs) => addrs,
76-
Err(_) => (res_code = ResponseCode::NXDomain, continue).1,
137+
Err(_) => {
138+
res_code = ResponseCode::NXDomain;
139+
continue;
140+
}
77141
};
78142
let mut reverse_mapping = reverse_mapping_v4.lock().unwrap();
79143
for ip in &ips {
144+
notify_cache_update |= reverse_mapping
145+
.peek_mut(ip)
146+
.filter(|n| *n == &name_str)
147+
.is_none();
80148
reverse_mapping.get_or_insert(*ip, || name_str.clone());
81149
}
82150
ans_records.extend(
@@ -88,20 +156,34 @@ impl DatagramSessionHandler for DnsDatagramHandler {
88156
RecordType::AAAA => {
89157
let ips = match resolver.resolve_ipv6(name_str.clone()).await {
90158
Ok(addrs) => addrs,
91-
Err(_) => (res_code = ResponseCode::NXDomain, continue).1,
159+
Err(_) => {
160+
res_code = ResponseCode::NXDomain;
161+
continue;
162+
}
92163
};
93164
let mut reverse_mapping = reverse_mapping_v6.lock().unwrap();
94165
for ip in &ips {
166+
notify_cache_update |= reverse_mapping
167+
.peek_mut(ip)
168+
.filter(|n| *n == &name_str)
169+
.is_none();
95170
reverse_mapping.get_or_insert(*ip, || name_str.clone());
96171
}
97172
ans_records.extend(ips.into_iter().map(|addr| {
98173
Record::from_rdata(name.clone(), ttl, RData::AAAA(addr))
99174
}))
100175
}
101176
// TODO: SRV
102-
_ => (res_code = ResponseCode::NotImp, continue).1,
177+
_ => {
178+
res_code = ResponseCode::NotImp;
179+
continue;
180+
}
103181
}
104182
}
183+
if notify_cache_update {
184+
new_notify.notify_one();
185+
}
186+
105187
*msg.set_message_type(MessageType::Response)
106188
.set_response_code(res_code)
107189
.answers_mut() = ans_records;

ytflow/src/plugin/dns_server/map_back.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::task::{ready, Context, Poll};
55

66
use lru::LruCache;
77

8-
use super::DnsDatagramHandler;
8+
use super::DnsServer;
99
use crate::flow::*;
1010

1111
#[derive(Clone)]
@@ -46,8 +46,7 @@ pub struct MapBackStreamHandler {
4646
}
4747

4848
impl MapBackStreamHandler {
49-
pub fn new(handler: &DnsDatagramHandler, next: Weak<dyn StreamHandler>) -> Self {
50-
// TODO: persist mapping into cache
49+
pub fn new(handler: &DnsServer, next: Weak<dyn StreamHandler>) -> Self {
5150
Self {
5251
back_mapper: BackMapper {
5352
reverse_mapping_v4: handler.reverse_mapping_v4.clone(),
@@ -104,7 +103,7 @@ struct MapBackDatagramSession {
104103
}
105104

106105
impl MapBackDatagramSessionHandler {
107-
pub fn new(handler: &DnsDatagramHandler, next: Weak<dyn DatagramSessionHandler>) -> Self {
106+
pub fn new(handler: &DnsServer, next: Weak<dyn DatagramSessionHandler>) -> Self {
108107
Self {
109108
back_mapper: BackMapper {
110109
reverse_mapping_v4: handler.reverse_mapping_v4.clone(),
Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,41 @@
11
mod datagram;
22
mod map_back;
33

4-
pub use datagram::DnsDatagramHandler;
4+
use std::sync::Arc;
5+
6+
pub use datagram::DnsServer;
57
pub use map_back::{MapBackDatagramSessionHandler, MapBackStreamHandler};
8+
9+
pub async fn cache_writer(plugin: Arc<DnsServer>) {
10+
let (plugin, notify) = {
11+
let notify = plugin.new_notify.clone();
12+
let weak = Arc::downgrade(&plugin);
13+
drop(plugin);
14+
(weak, notify)
15+
};
16+
if plugin.strong_count() == 0 {
17+
panic!("dns-server has no strong reference left for cache_writer");
18+
}
19+
20+
use tokio::select;
21+
use tokio::time::{sleep, Duration};
22+
loop {
23+
let mut notified_fut = notify.notified();
24+
let mut sleep_fut = sleep(Duration::from_secs(3600));
25+
'debounce: loop {
26+
select! {
27+
_ = notified_fut => {
28+
notified_fut = notify.notified();
29+
sleep_fut = sleep(Duration::from_secs(3));
30+
}
31+
_ = sleep_fut => {
32+
break 'debounce;
33+
}
34+
}
35+
}
36+
match plugin.upgrade() {
37+
Some(plugin) => plugin.save_cache(),
38+
None => break,
39+
}
40+
}
41+
}

ytflow/src/plugin/fakeip.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use tokio::sync::Notify;
1111
use crate::data::PluginCache;
1212
use crate::flow::*;
1313

14-
const CACHE_SIZE: NonZeroUsize = NonZeroUsize::new(1000).unwrap();
14+
const CACHE_CAPACITY: NonZeroUsize = NonZeroUsize::new(1000).unwrap();
1515
const PLUGIN_CACHE_KEY: &str = "map";
1616

1717
struct Inner {
@@ -35,7 +35,7 @@ pub struct FakeIp {
3535

3636
impl FakeIp {
3737
pub fn new(prefix_v4: [u8; 2], prefix_v6: [u8; 14], plugin_cache: PluginCache) -> Self {
38-
let mut lru = LruCache::new(CACHE_SIZE);
38+
let mut lru = LruCache::new(CACHE_CAPACITY);
3939
let inner = match plugin_cache
4040
.get::<InnerCache>(PLUGIN_CACHE_KEY)
4141
.ok()

0 commit comments

Comments
 (0)