Skip to content
186 changes: 167 additions & 19 deletions crates/ironrdp-dvc/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use alloc::boxed::Box;
use alloc::collections::btree_map::BTreeMap;
use alloc::vec::Vec;
use core::any::TypeId;
use core::fmt;

use crate::alloc::borrow::ToOwned as _;
use ironrdp_core::{Decode as _, DecodeResult, ReadCursor, impl_as_any};
use ironrdp_pdu::{self as pdu, decode_err, encode_err, pdu_other_err};
use ironrdp_svc::{ChannelFlags, CompressionCondition, SvcClientProcessor, SvcMessage, SvcProcessor};
Expand All @@ -13,10 +16,45 @@ use crate::pdu::{
CapabilitiesResponsePdu, CapsVersion, ClosePdu, CreateResponsePdu, CreationStatus, DrdynvcClientPdu,
DrdynvcServerPdu,
};
use crate::{DvcProcessor, DynamicChannelSet, DynamicVirtualChannel, encode_dvc_messages};
use crate::{DvcProcessor, DynamicChannelId, DynamicChannelName, DynamicVirtualChannel, encode_dvc_messages};

pub trait DvcClientProcessor: DvcProcessor {}

pub trait DvcChannelListener: Send {
fn channel_name(&self) -> &str;
/// Called for each incoming DYNVC_CREATE_REQ matching this name.
/// Return `None` to reject (NO_LISTENER).
fn create(&mut self) -> Option<Box<dyn DvcProcessor>>;
}

pub type DynamicChannelListener = Box<dyn DvcChannelListener>;

/// For pre-registered DVC
struct OnceListener {
inner: Option<Box<dyn DvcProcessor>>,
}

impl OnceListener {
fn new(dvc_processor: impl DvcProcessor + 'static) -> Self {
Self {
inner: Some(Box::new(dvc_processor)),
}
}
}

impl DvcChannelListener for OnceListener {
fn channel_name(&self) -> &str {
self.inner
.as_ref()
.expect("channel name called after created")
.channel_name()
}

fn create(&mut self) -> Option<Box<dyn DvcProcessor>> {
self.inner.take()
}
}

/// DRDYNVC Static Virtual Channel (the Remote Desktop Protocol: Dynamic Virtual Channel Extension)
///
/// It adds support for dynamic virtual channels (DVC).
Expand Down Expand Up @@ -51,22 +89,40 @@ impl DrdynvcClient {
}
}

// FIXME(#61): it’s likely we want to enable adding dynamic channels at any point during the session (message passing? other approach?)

/// Registers a pre-initialized dynamic virtual channel with the DrdynvcClient,
/// making it available for immediate use when the session starts.
#[must_use]
pub fn with_dynamic_channel<T>(mut self, channel: T) -> Self
where
T: DvcProcessor + 'static,
{
self.dynamic_channels.insert(channel);
self.dynamic_channels.register_once(channel);
self
}

/// Bind a listener. Doesn't support type id look up
#[must_use]
pub fn with_listener<T>(mut self, listener: T) -> Self
where
T: DvcChannelListener + 'static,
{
self.dynamic_channels.register_listener(listener);
self
}

/// Doesn't support type id look up
pub fn attach_listener<T>(&mut self, listener: T)
where
T: DvcChannelListener + 'static,
{
self.dynamic_channels.register_listener(listener);
}

pub fn attach_dynamic_channel<T>(&mut self, channel: T)
where
T: DvcProcessor + 'static,
{
self.dynamic_channels.insert(channel);
self.dynamic_channels.register_once(channel);
}

pub fn get_dvc_by_type_id<T>(&self) -> Option<&DynamicVirtualChannel>
Expand Down Expand Up @@ -127,20 +183,12 @@ impl SvcProcessor for DrdynvcClient {
responses.push(self.create_capabilities_response(CapsVersion::V2));
}

let channel_exists = self.dynamic_channels.get_by_channel_name(&channel_name).is_some();
let (creation_status, start_messages) = if channel_exists {
// If we have a handler for this channel, attach the channel ID
// and get any start messages.
self.dynamic_channels
.attach_channel_id(channel_name.clone(), channel_id);
let dynamic_channel = self
.dynamic_channels
.get_by_channel_name_mut(&channel_name)
.expect("channel exists");
(CreationStatus::OK, dynamic_channel.start()?)
} else {
(CreationStatus::NO_LISTENER, Vec::new())
};
let (creation_status, start_messages) =
if let Some(dvc) = self.dynamic_channels.try_create_channel(&channel_name, channel_id) {
(CreationStatus::OK, dvc.start()?)
} else {
(CreationStatus::NO_LISTENER, Vec::new())
};

let create_response = DrdynvcClientPdu::Create(CreateResponsePdu::new(channel_id, creation_status));
debug!("Send DVC Create Response PDU: {create_response:?}");
Expand Down Expand Up @@ -182,6 +230,106 @@ impl SvcProcessor for DrdynvcClient {
}
}

struct ListenerEntry {
listener: DynamicChannelListener,
/// `Some` only for channels registered via `with_dynamic_channel<T>()`.
type_id: Option<TypeId>,
}

struct DynamicChannelSet {
listeners: BTreeMap<DynamicChannelName, ListenerEntry>,
active_channels: BTreeMap<DynamicChannelId, DynamicVirtualChannel>,
type_id_to_channel_id: BTreeMap<TypeId, DynamicChannelId>,
}

impl DynamicChannelSet {
#[inline]
fn new() -> Self {
Self {
listeners: BTreeMap::new(),
active_channels: BTreeMap::new(),
type_id_to_channel_id: BTreeMap::new(),
}
}

fn register_listener<T: DvcChannelListener + 'static>(&mut self, listener: T) {
let name = listener.channel_name().to_owned();
self.listeners.insert(
name,
ListenerEntry {
listener: Box::new(listener),
type_id: None,
},
);
}

fn register_once<T: DvcProcessor + 'static>(&mut self, channel: T) {
let name = channel.channel_name().to_owned();
self.listeners.insert(
name,
ListenerEntry {
listener: Box::new(OnceListener::new(channel)),
type_id: Some(TypeId::of::<T>()),
},
);
}

fn try_create_channel(
&mut self,
name: &DynamicChannelName,
channel_id: DynamicChannelId,
) -> Option<&mut DynamicVirtualChannel> {
let entry = self.listeners.get_mut(name)?;
let processor = entry.listener.create()?;

if let Some(type_id) = entry.type_id {
self.type_id_to_channel_id.insert(type_id, channel_id);
}

let mut dvc = DynamicVirtualChannel::from_boxed(processor);
dvc.channel_id = Some(channel_id);
let dvc = match self.active_channels.entry(channel_id) {
alloc::collections::btree_map::Entry::Occupied(mut e) => {
e.insert(dvc);
e.into_mut()
}
alloc::collections::btree_map::Entry::Vacant(e) => e.insert(dvc),
};
Some(dvc)
}
Comment on lines +302 to +324

fn get_by_type_id(&self, type_id: TypeId) -> Option<&DynamicVirtualChannel> {
self.type_id_to_channel_id
.get(&type_id)
.and_then(|id| self.active_channels.get(id))
}

fn get_by_channel_id(&self, id: DynamicChannelId) -> Option<&DynamicVirtualChannel> {
self.active_channels.get(&id)
}

fn get_by_channel_id_mut(&mut self, id: DynamicChannelId) -> Option<&mut DynamicVirtualChannel> {
self.active_channels.get_mut(&id)
}

fn remove_by_channel_id(&mut self, id: DynamicChannelId) {
if let Some(dvc) = self.active_channels.remove(&id) {
let type_id = dvc.processor_type_id();

// Only matters for pre-registered channels
if let alloc::collections::btree_map::Entry::Occupied(entry) = self.type_id_to_channel_id.entry(type_id)
&& entry.get() == &id
{
entry.remove();
}
}
}

#[inline]
fn values(&self) -> impl Iterator<Item = &DynamicVirtualChannel> {
self.active_channels.values()
}
}
impl SvcClientProcessor for DrdynvcClient {}

fn decode_dvc_message(user_data: &[u8]) -> DecodeResult<DrdynvcServerPdu> {
Expand Down
87 changes: 8 additions & 79 deletions crates/ironrdp-dvc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

extern crate alloc;

use core::any::TypeId;

use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use core::any::TypeId;

use pdu::DrdynvcDataPdu;

use crate::alloc::borrow::ToOwned as _;
// Re-export ironrdp_pdu crate for convenience
#[rustfmt::skip] // do not re-order this pub use
pub use ironrdp_pdu;
Expand Down Expand Up @@ -111,14 +110,18 @@ pub struct DynamicVirtualChannel {
}

impl DynamicVirtualChannel {
fn new<T: DvcProcessor + 'static>(handler: T) -> Self {
fn from_boxed(processor: Box<dyn DvcProcessor + Send>) -> Self {
Self {
channel_processor: Box::new(handler),
channel_processor: processor,
complete_data: CompleteData::new(),
channel_id: None,
}
}

fn processor_type_id(&self) -> TypeId {
self.channel_processor.as_any().type_id()
}

pub fn is_open(&self) -> bool {
self.channel_id.is_some()
}
Expand Down Expand Up @@ -154,79 +157,5 @@ impl DynamicVirtualChannel {
}
}

struct DynamicChannelSet {
channels: BTreeMap<DynamicChannelName, DynamicVirtualChannel>,
name_to_channel_id: BTreeMap<DynamicChannelName, DynamicChannelId>,
channel_id_to_name: BTreeMap<DynamicChannelId, DynamicChannelName>,
type_id_to_name: BTreeMap<TypeId, DynamicChannelName>,
}

impl DynamicChannelSet {
#[inline]
fn new() -> Self {
Self {
channels: BTreeMap::new(),
name_to_channel_id: BTreeMap::new(),
channel_id_to_name: BTreeMap::new(),
type_id_to_name: BTreeMap::new(),
}
}

fn insert<T: DvcProcessor + 'static>(&mut self, channel: T) -> Option<DynamicVirtualChannel> {
let name = channel.channel_name().to_owned();
self.type_id_to_name.insert(TypeId::of::<T>(), name.clone());
self.channels.insert(name, DynamicVirtualChannel::new(channel))
}

fn attach_channel_id(&mut self, name: DynamicChannelName, id: DynamicChannelId) -> Option<DynamicChannelId> {
self.channel_id_to_name.insert(id, name.clone());
self.name_to_channel_id.insert(name.clone(), id);
let dvc = self.get_by_channel_name_mut(&name)?;
let old_id = dvc.channel_id;
dvc.channel_id = Some(id);
old_id
}

fn get_by_type_id(&self, type_id: TypeId) -> Option<&DynamicVirtualChannel> {
self.type_id_to_name
.get(&type_id)
.and_then(|name| self.channels.get(name))
}

fn get_by_channel_name(&self, name: &DynamicChannelName) -> Option<&DynamicVirtualChannel> {
self.channels.get(name)
}

fn get_by_channel_name_mut(&mut self, name: &DynamicChannelName) -> Option<&mut DynamicVirtualChannel> {
self.channels.get_mut(name)
}

fn get_by_channel_id(&self, id: DynamicChannelId) -> Option<&DynamicVirtualChannel> {
self.channel_id_to_name
.get(&id)
.and_then(|name| self.channels.get(name))
}

fn get_by_channel_id_mut(&mut self, id: DynamicChannelId) -> Option<&mut DynamicVirtualChannel> {
self.channel_id_to_name
.get(&id)
.and_then(|name| self.channels.get_mut(name))
}

fn remove_by_channel_id(&mut self, id: DynamicChannelId) -> Option<DynamicChannelId> {
if let Some(name) = self.channel_id_to_name.remove(&id) {
return self.name_to_channel_id.remove(&name);
// Channels are retained in the `self.channels` and `self.type_id_to_name` map to allow potential
// dynamic re-addition by the server.
}
None
}

#[inline]
fn values(&self) -> impl Iterator<Item = &DynamicVirtualChannel> {
self.channels.values()
}
}

pub type DynamicChannelName = String;
pub type DynamicChannelId = u32;
23 changes: 21 additions & 2 deletions crates/ironrdp-dvc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ impl DrdynvcServer {
.is_some_and(|c| c.state == ChannelState::Opened)
}

// FIXME(#61): it's likely we want to enable adding dynamic channels at any point during the session (message passing? other approach?)

/// Registers a dynamic channel with the server.
///
/// # Panics
Expand All @@ -121,6 +119,27 @@ impl DrdynvcServer {
.get_mut(id)
.ok_or_else(|| invalid_field_err!("DRDYNVC", "", "invalid channel id"))
}

/// Creates a new DVC, returns CreateRequest PDU to send to client.
///
/// # Panics
///
/// Panics if the number of registered dynamic channels exceeds `u32::MAX`.
pub fn create_channel<T>(&mut self, channel: T) -> PduResult<SvcMessage>
where
T: DvcServerProcessor + 'static,
{
let channel_name = channel.channel_name().into();
let mut dvc = DynamicChannel::new(channel);
dvc.state = ChannelState::Creation;

let id = self.dynamic_channels.insert(dvc);
// The slab index is used as the DVC channel ID (a u32).
let channel_id = u32::try_from(id).expect("DVC channel count should not exceed u32::MAX");

let req = DrdynvcServerPdu::Create(CreateRequestPdu::new(channel_id, channel_name));
as_svc_msg_with_flag(req)
}
}

impl_as_any!(DrdynvcServer);
Expand Down
Loading