Skip to content

Commit 28e8628

Browse files
authored
feat(dvc): add DvcChannelListener for multi-instance DVC support (#1142)
Signed-off-by: uchouT <i@uchout.moe>
1 parent 9a1ac30 commit 28e8628

File tree

3 files changed

+221
-100
lines changed

3 files changed

+221
-100
lines changed

crates/ironrdp-dvc/src/client.rs

Lines changed: 192 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
use alloc::boxed::Box;
2+
use alloc::collections::btree_map::BTreeMap;
13
use alloc::vec::Vec;
24
use core::any::TypeId;
35
use core::fmt;
46

7+
use crate::alloc::borrow::ToOwned as _;
58
use ironrdp_core::{Decode as _, DecodeResult, ReadCursor, impl_as_any};
69
use ironrdp_pdu::{self as pdu, decode_err, encode_err, pdu_other_err};
710
use ironrdp_svc::{ChannelFlags, CompressionCondition, SvcClientProcessor, SvcMessage, SvcProcessor};
@@ -13,10 +16,46 @@ use crate::pdu::{
1316
CapabilitiesResponsePdu, CapsVersion, ClosePdu, CreateResponsePdu, CreationStatus, DrdynvcClientPdu,
1417
DrdynvcServerPdu,
1518
};
16-
use crate::{DvcProcessor, DynamicChannelSet, DynamicVirtualChannel, encode_dvc_messages};
19+
use crate::{DvcProcessor, DynamicChannelId, DynamicChannelName, DynamicVirtualChannel, encode_dvc_messages};
1720

1821
pub trait DvcClientProcessor: DvcProcessor {}
1922

23+
pub trait DvcChannelListener: Send {
24+
fn channel_name(&self) -> &str;
25+
26+
/// Called for each incoming DYNVC_CREATE_REQ matching this name.
27+
/// Return `None` to reject (NO_LISTENER).
28+
fn create(&mut self) -> Option<Box<dyn DvcProcessor>>;
29+
}
30+
31+
pub type DynamicChannelListener = Box<dyn DvcChannelListener>;
32+
33+
/// For pre-registered DVC
34+
struct OnceListener {
35+
inner: Option<Box<dyn DvcProcessor>>,
36+
}
37+
38+
impl OnceListener {
39+
fn new(dvc_processor: impl DvcProcessor + 'static) -> Self {
40+
Self {
41+
inner: Some(Box::new(dvc_processor)),
42+
}
43+
}
44+
}
45+
46+
impl DvcChannelListener for OnceListener {
47+
fn channel_name(&self) -> &str {
48+
self.inner
49+
.as_ref()
50+
.expect("channel name called after created")
51+
.channel_name()
52+
}
53+
54+
fn create(&mut self) -> Option<Box<dyn DvcProcessor>> {
55+
self.inner.take()
56+
}
57+
}
58+
2059
/// DRDYNVC Static Virtual Channel (the Remote Desktop Protocol: Dynamic Virtual Channel Extension)
2160
///
2261
/// It adds support for dynamic virtual channels (DVC).
@@ -51,22 +90,64 @@ impl DrdynvcClient {
5190
}
5291
}
5392

54-
// FIXME(#61): it’s likely we want to enable adding dynamic channels at any point during the session (message passing? other approach?)
55-
93+
/// Registers a pre-initialized dynamic virtual channel with the [`DrdynvcClient`],
94+
/// making it available for immediate use when the session starts.
95+
///
96+
/// # Note
97+
///
98+
/// If a listener or a pre-registered channel with the same name already exists,
99+
/// it will be silently overwritten.
56100
#[must_use]
57101
pub fn with_dynamic_channel<T>(mut self, channel: T) -> Self
58102
where
59103
T: DvcProcessor + 'static,
60104
{
61-
self.dynamic_channels.insert(channel);
105+
self.dynamic_channels.register_once(channel);
62106
self
63107
}
64108

109+
/// Attaches a pre-initialized dynamic virtual channel with the [`DrdynvcClient`],
110+
/// making it available for immediate use when the session starts.
111+
///
112+
/// # Note
113+
///
114+
/// If a listener or a pre-registered channel with the same name already exists,
115+
/// it will be silently overwritten.
65116
pub fn attach_dynamic_channel<T>(&mut self, channel: T)
66117
where
67118
T: DvcProcessor + 'static,
68119
{
69-
self.dynamic_channels.insert(channel);
120+
self.dynamic_channels.register_once(channel);
121+
}
122+
123+
/// Bind a listener.
124+
///
125+
/// # Note
126+
///
127+
/// * Doesn't support [TypeId] lookup via [DrdynvcClient::get_dvc_by_type_id].
128+
/// * If a listener or a pre-registered channel with the same name already exists,
129+
/// it will be silently overwritten.
130+
#[must_use]
131+
pub fn with_listener<T>(mut self, listener: T) -> Self
132+
where
133+
T: DvcChannelListener + 'static,
134+
{
135+
self.dynamic_channels.register_listener(listener);
136+
self
137+
}
138+
139+
/// Attaches a listener.
140+
///
141+
/// # Note
142+
///
143+
/// * Doesn't support [TypeId] lookup via [DrdynvcClient::get_dvc_by_type_id].
144+
/// * If a listener or a pre-registered channel with the same name already exists,
145+
/// it will be silently overwritten.
146+
pub fn attach_listener<T>(&mut self, listener: T)
147+
where
148+
T: DvcChannelListener + 'static,
149+
{
150+
self.dynamic_channels.register_listener(listener);
70151
}
71152

72153
pub fn get_dvc_by_type_id<T>(&self) -> Option<&DynamicVirtualChannel>
@@ -127,20 +208,12 @@ impl SvcProcessor for DrdynvcClient {
127208
responses.push(self.create_capabilities_response(CapsVersion::V2));
128209
}
129210

130-
let channel_exists = self.dynamic_channels.get_by_channel_name(&channel_name).is_some();
131-
let (creation_status, start_messages) = if channel_exists {
132-
// If we have a handler for this channel, attach the channel ID
133-
// and get any start messages.
134-
self.dynamic_channels
135-
.attach_channel_id(channel_name.clone(), channel_id);
136-
let dynamic_channel = self
137-
.dynamic_channels
138-
.get_by_channel_name_mut(&channel_name)
139-
.expect("channel exists");
140-
(CreationStatus::OK, dynamic_channel.start()?)
141-
} else {
142-
(CreationStatus::NO_LISTENER, Vec::new())
143-
};
211+
let (creation_status, start_messages) =
212+
if let Some(dvc) = self.dynamic_channels.try_create_channel(&channel_name, channel_id) {
213+
(CreationStatus::OK, dvc.start()?)
214+
} else {
215+
(CreationStatus::NO_LISTENER, Vec::new())
216+
};
144217

145218
let create_response = DrdynvcClientPdu::Create(CreateResponsePdu::new(channel_id, creation_status));
146219
debug!("Send DVC Create Response PDU: {create_response:?}");
@@ -182,6 +255,106 @@ impl SvcProcessor for DrdynvcClient {
182255
}
183256
}
184257

258+
struct ListenerEntry {
259+
listener: DynamicChannelListener,
260+
/// `Some` only for channels registered via `with_dynamic_channel<T>()`.
261+
type_id: Option<TypeId>,
262+
}
263+
264+
struct DynamicChannelSet {
265+
listeners: BTreeMap<DynamicChannelName, ListenerEntry>,
266+
active_channels: BTreeMap<DynamicChannelId, DynamicVirtualChannel>,
267+
type_id_to_channel_id: BTreeMap<TypeId, DynamicChannelId>,
268+
}
269+
270+
impl DynamicChannelSet {
271+
#[inline]
272+
fn new() -> Self {
273+
Self {
274+
listeners: BTreeMap::new(),
275+
active_channels: BTreeMap::new(),
276+
type_id_to_channel_id: BTreeMap::new(),
277+
}
278+
}
279+
280+
fn register_listener<T: DvcChannelListener + 'static>(&mut self, listener: T) {
281+
let name = listener.channel_name().to_owned();
282+
self.listeners.insert(
283+
name,
284+
ListenerEntry {
285+
listener: Box::new(listener),
286+
type_id: None,
287+
},
288+
);
289+
}
290+
291+
fn register_once<T: DvcProcessor + 'static>(&mut self, channel: T) {
292+
let name = channel.channel_name().to_owned();
293+
self.listeners.insert(
294+
name,
295+
ListenerEntry {
296+
listener: Box::new(OnceListener::new(channel)),
297+
type_id: Some(TypeId::of::<T>()),
298+
},
299+
);
300+
}
301+
302+
fn try_create_channel(
303+
&mut self,
304+
name: &DynamicChannelName,
305+
channel_id: DynamicChannelId,
306+
) -> Option<&mut DynamicVirtualChannel> {
307+
let entry = self.listeners.get_mut(name)?;
308+
let processor = entry.listener.create()?;
309+
310+
if let Some(type_id) = entry.type_id {
311+
self.type_id_to_channel_id.insert(type_id, channel_id);
312+
}
313+
314+
let mut dvc = DynamicVirtualChannel::from_boxed(processor);
315+
dvc.channel_id = Some(channel_id);
316+
let dvc = match self.active_channels.entry(channel_id) {
317+
alloc::collections::btree_map::Entry::Occupied(mut e) => {
318+
e.insert(dvc);
319+
e.into_mut()
320+
}
321+
alloc::collections::btree_map::Entry::Vacant(e) => e.insert(dvc),
322+
};
323+
Some(dvc)
324+
}
325+
326+
fn get_by_type_id(&self, type_id: TypeId) -> Option<&DynamicVirtualChannel> {
327+
self.type_id_to_channel_id
328+
.get(&type_id)
329+
.and_then(|id| self.active_channels.get(id))
330+
}
331+
332+
fn get_by_channel_id(&self, id: DynamicChannelId) -> Option<&DynamicVirtualChannel> {
333+
self.active_channels.get(&id)
334+
}
335+
336+
fn get_by_channel_id_mut(&mut self, id: DynamicChannelId) -> Option<&mut DynamicVirtualChannel> {
337+
self.active_channels.get_mut(&id)
338+
}
339+
340+
fn remove_by_channel_id(&mut self, id: DynamicChannelId) {
341+
if let Some(dvc) = self.active_channels.remove(&id) {
342+
let type_id = dvc.processor_type_id();
343+
344+
// Only matters for pre-registered channels
345+
if let alloc::collections::btree_map::Entry::Occupied(entry) = self.type_id_to_channel_id.entry(type_id)
346+
&& entry.get() == &id
347+
{
348+
entry.remove();
349+
}
350+
}
351+
}
352+
353+
#[inline]
354+
fn values(&self) -> impl Iterator<Item = &DynamicVirtualChannel> {
355+
self.active_channels.values()
356+
}
357+
}
185358
impl SvcClientProcessor for DrdynvcClient {}
186359

187360
fn decode_dvc_message(user_data: &[u8]) -> DecodeResult<DrdynvcServerPdu> {

crates/ironrdp-dvc/src/lib.rs

Lines changed: 8 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
extern crate alloc;
66

7+
use core::any::TypeId;
8+
79
use alloc::boxed::Box;
8-
use alloc::collections::BTreeMap;
910
use alloc::string::String;
1011
use alloc::vec::Vec;
11-
use core::any::TypeId;
1212

1313
use pdu::DrdynvcDataPdu;
1414

15-
use crate::alloc::borrow::ToOwned as _;
1615
// Re-export ironrdp_pdu crate for convenience
1716
#[rustfmt::skip] // do not re-order this pub use
1817
pub use ironrdp_pdu;
@@ -111,14 +110,18 @@ pub struct DynamicVirtualChannel {
111110
}
112111

113112
impl DynamicVirtualChannel {
114-
fn new<T: DvcProcessor + 'static>(handler: T) -> Self {
113+
fn from_boxed(processor: Box<dyn DvcProcessor + Send>) -> Self {
115114
Self {
116-
channel_processor: Box::new(handler),
115+
channel_processor: processor,
117116
complete_data: CompleteData::new(),
118117
channel_id: None,
119118
}
120119
}
121120

121+
fn processor_type_id(&self) -> TypeId {
122+
self.channel_processor.as_any().type_id()
123+
}
124+
122125
pub fn is_open(&self) -> bool {
123126
self.channel_id.is_some()
124127
}
@@ -154,79 +157,5 @@ impl DynamicVirtualChannel {
154157
}
155158
}
156159

157-
struct DynamicChannelSet {
158-
channels: BTreeMap<DynamicChannelName, DynamicVirtualChannel>,
159-
name_to_channel_id: BTreeMap<DynamicChannelName, DynamicChannelId>,
160-
channel_id_to_name: BTreeMap<DynamicChannelId, DynamicChannelName>,
161-
type_id_to_name: BTreeMap<TypeId, DynamicChannelName>,
162-
}
163-
164-
impl DynamicChannelSet {
165-
#[inline]
166-
fn new() -> Self {
167-
Self {
168-
channels: BTreeMap::new(),
169-
name_to_channel_id: BTreeMap::new(),
170-
channel_id_to_name: BTreeMap::new(),
171-
type_id_to_name: BTreeMap::new(),
172-
}
173-
}
174-
175-
fn insert<T: DvcProcessor + 'static>(&mut self, channel: T) -> Option<DynamicVirtualChannel> {
176-
let name = channel.channel_name().to_owned();
177-
self.type_id_to_name.insert(TypeId::of::<T>(), name.clone());
178-
self.channels.insert(name, DynamicVirtualChannel::new(channel))
179-
}
180-
181-
fn attach_channel_id(&mut self, name: DynamicChannelName, id: DynamicChannelId) -> Option<DynamicChannelId> {
182-
self.channel_id_to_name.insert(id, name.clone());
183-
self.name_to_channel_id.insert(name.clone(), id);
184-
let dvc = self.get_by_channel_name_mut(&name)?;
185-
let old_id = dvc.channel_id;
186-
dvc.channel_id = Some(id);
187-
old_id
188-
}
189-
190-
fn get_by_type_id(&self, type_id: TypeId) -> Option<&DynamicVirtualChannel> {
191-
self.type_id_to_name
192-
.get(&type_id)
193-
.and_then(|name| self.channels.get(name))
194-
}
195-
196-
fn get_by_channel_name(&self, name: &DynamicChannelName) -> Option<&DynamicVirtualChannel> {
197-
self.channels.get(name)
198-
}
199-
200-
fn get_by_channel_name_mut(&mut self, name: &DynamicChannelName) -> Option<&mut DynamicVirtualChannel> {
201-
self.channels.get_mut(name)
202-
}
203-
204-
fn get_by_channel_id(&self, id: DynamicChannelId) -> Option<&DynamicVirtualChannel> {
205-
self.channel_id_to_name
206-
.get(&id)
207-
.and_then(|name| self.channels.get(name))
208-
}
209-
210-
fn get_by_channel_id_mut(&mut self, id: DynamicChannelId) -> Option<&mut DynamicVirtualChannel> {
211-
self.channel_id_to_name
212-
.get(&id)
213-
.and_then(|name| self.channels.get_mut(name))
214-
}
215-
216-
fn remove_by_channel_id(&mut self, id: DynamicChannelId) -> Option<DynamicChannelId> {
217-
if let Some(name) = self.channel_id_to_name.remove(&id) {
218-
return self.name_to_channel_id.remove(&name);
219-
// Channels are retained in the `self.channels` and `self.type_id_to_name` map to allow potential
220-
// dynamic re-addition by the server.
221-
}
222-
None
223-
}
224-
225-
#[inline]
226-
fn values(&self) -> impl Iterator<Item = &DynamicVirtualChannel> {
227-
self.channels.values()
228-
}
229-
}
230-
231160
pub type DynamicChannelName = String;
232161
pub type DynamicChannelId = u32;

0 commit comments

Comments
 (0)