|
| 1 | +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +// Portions Copyright 2019 Intel Corporation. All Rights Reserved. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 |
| 6 | + |
| 7 | +use std::os::fd::AsRawFd; |
| 8 | +use std::os::unix::net::UnixStream; |
| 9 | + |
| 10 | +use thiserror::Error; |
| 11 | +use utils::eventfd::EventFd; |
| 12 | +use vhost::vhost_user::message::*; |
| 13 | +use vhost::vhost_user::{Frontend, VhostUserFrontend}; |
| 14 | +use vhost::{Error as VhostError, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; |
| 15 | +use vm_memory::{Address, Error as MmapError, GuestMemory, GuestMemoryError, GuestMemoryRegion}; |
| 16 | + |
| 17 | +use crate::devices::virtio::queue::Queue; |
| 18 | +use crate::devices::virtio::IrqTrigger; |
| 19 | +use crate::vstate::memory::GuestMemoryMmap; |
| 20 | + |
| 21 | +/// vhost-user error. |
| 22 | +#[derive(Error, Debug, displaydoc::Display)] |
| 23 | +pub enum VhostUserError { |
| 24 | + /// Invalid available address |
| 25 | + AvailAddress(GuestMemoryError), |
| 26 | + /// Failed to connect to UDS Unix stream: {0} |
| 27 | + Connect(#[from] std::io::Error), |
| 28 | + /// Invalid descriptor table address |
| 29 | + DescriptorTableAddress(GuestMemoryError), |
| 30 | + /// Get features failed: {0} |
| 31 | + VhostUserGetFeatures(VhostError), |
| 32 | + /// Get protocol features failed: {0} |
| 33 | + VhostUserGetProtocolFeatures(VhostError), |
| 34 | + /// Set owner failed: {0} |
| 35 | + VhostUserSetOwner(VhostError), |
| 36 | + /// Set features failed: {0} |
| 37 | + VhostUserSetFeatures(VhostError), |
| 38 | + /// Set protocol features failed: {0} |
| 39 | + VhostUserSetProtocolFeatures(VhostError), |
| 40 | + /// Set mem table failed: {0} |
| 41 | + VhostUserSetMemTable(VhostError), |
| 42 | + /// Set vring num failed: {0} |
| 43 | + VhostUserSetVringNum(VhostError), |
| 44 | + /// Set vring addr failed: {0} |
| 45 | + VhostUserSetVringAddr(VhostError), |
| 46 | + /// Set vring base failed: {0} |
| 47 | + VhostUserSetVringBase(VhostError), |
| 48 | + /// Set vring call failed: {0} |
| 49 | + VhostUserSetVringCall(VhostError), |
| 50 | + /// Set vring kick failed: {0} |
| 51 | + VhostUserSetVringKick(VhostError), |
| 52 | + /// Set vring enable failed: {0} |
| 53 | + VhostUserSetVringEnable(VhostError), |
| 54 | + /// Failed to read vhost eventfd: {0} |
| 55 | + VhostUserMemoryRegion(MmapError), |
| 56 | + /// Invalid used address |
| 57 | + UsedAddress(GuestMemoryError), |
| 58 | +} |
| 59 | + |
| 60 | +/// vhost-user socket handle |
| 61 | +#[derive(Clone)] |
| 62 | +pub struct VhostUserHandle { |
| 63 | + pub vu: Frontend, |
| 64 | + pub socket_path: String, |
| 65 | +} |
| 66 | + |
| 67 | +impl std::fmt::Debug for VhostUserHandle { |
| 68 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 69 | + f.debug_struct("VhostUserHandle") |
| 70 | + .field("socket_path", &self.socket_path) |
| 71 | + .finish() |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +impl VhostUserHandle { |
| 76 | + /// Connect to the vhost-user backend socket and mark self as an |
| 77 | + /// owner of the session. |
| 78 | + pub fn new(socket_path: &str, num_queues: u64) -> Result<Self, VhostUserError> { |
| 79 | + let stream = UnixStream::connect(socket_path).map_err(VhostUserError::Connect)?; |
| 80 | + |
| 81 | + let vu = Frontend::from_stream(stream, num_queues); |
| 82 | + vu.set_owner().map_err(VhostUserError::VhostUserSetOwner)?; |
| 83 | + |
| 84 | + Ok(Self { |
| 85 | + vu, |
| 86 | + socket_path: socket_path.to_string(), |
| 87 | + }) |
| 88 | + } |
| 89 | + |
| 90 | + /// Set vhost-user features to the backend. |
| 91 | + pub fn set_features(&self, features: u64) -> Result<(), VhostUserError> { |
| 92 | + self.vu |
| 93 | + .set_features(features) |
| 94 | + .map_err(VhostUserError::VhostUserSetFeatures) |
| 95 | + } |
| 96 | + |
| 97 | + /// Set vhost-user protocol features to the backend. |
| 98 | + pub fn set_protocol_features( |
| 99 | + &mut self, |
| 100 | + acked_features: u64, |
| 101 | + acked_protocol_features: u64, |
| 102 | + ) -> Result<(), VhostUserError> { |
| 103 | + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { |
| 104 | + if let Some(acked_protocol_features) = |
| 105 | + VhostUserProtocolFeatures::from_bits(acked_protocol_features) |
| 106 | + { |
| 107 | + self.vu |
| 108 | + .set_protocol_features(acked_protocol_features) |
| 109 | + .map_err(VhostUserError::VhostUserSetProtocolFeatures)?; |
| 110 | + |
| 111 | + if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { |
| 112 | + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + Ok(()) |
| 118 | + } |
| 119 | + |
| 120 | + /// Negotiate virtio and protocol features with the backend. |
| 121 | + pub fn negotiate_features( |
| 122 | + &mut self, |
| 123 | + avail_features: u64, |
| 124 | + avail_protocol_features: VhostUserProtocolFeatures, |
| 125 | + ) -> Result<(u64, u64), VhostUserError> { |
| 126 | + // Get features from backend, do negotiation to get a feature collection which |
| 127 | + // both VMM and backend support. |
| 128 | + let backend_features = self |
| 129 | + .vu |
| 130 | + .get_features() |
| 131 | + .map_err(VhostUserError::VhostUserGetFeatures)?; |
| 132 | + let acked_features = avail_features & backend_features; |
| 133 | + |
| 134 | + let acked_protocol_features = |
| 135 | + // If frontend can negotiate protocol features. |
| 136 | + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { |
| 137 | + let backend_protocol_features = self |
| 138 | + .vu |
| 139 | + .get_protocol_features() |
| 140 | + .map_err(VhostUserError::VhostUserGetProtocolFeatures)?; |
| 141 | + |
| 142 | + let acked_protocol_features = avail_protocol_features & backend_protocol_features; |
| 143 | + |
| 144 | + self.vu |
| 145 | + .set_protocol_features(acked_protocol_features) |
| 146 | + .map_err(VhostUserError::VhostUserSetProtocolFeatures)?; |
| 147 | + |
| 148 | + acked_protocol_features |
| 149 | + } else { |
| 150 | + VhostUserProtocolFeatures::empty() |
| 151 | + }; |
| 152 | + |
| 153 | + if acked_protocol_features.contains(VhostUserProtocolFeatures::REPLY_ACK) { |
| 154 | + self.vu.set_hdr_flags(VhostUserHeaderFlag::NEED_REPLY); |
| 155 | + } |
| 156 | + |
| 157 | + Ok((acked_features, acked_protocol_features.bits())) |
| 158 | + } |
| 159 | + |
| 160 | + /// Update guest memory table to the backend. |
| 161 | + fn update_mem_table(&mut self, mem: &GuestMemoryMmap) -> Result<(), VhostUserError> { |
| 162 | + let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new(); |
| 163 | + |
| 164 | + for region in mem.iter() { |
| 165 | + let (mmap_handle, mmap_offset) = match region.file_offset() { |
| 166 | + Some(_file_offset) => (_file_offset.file().as_raw_fd(), _file_offset.start()), |
| 167 | + None => { |
| 168 | + return Err(VhostUserError::VhostUserMemoryRegion( |
| 169 | + MmapError::NoMemoryRegion, |
| 170 | + )) |
| 171 | + } |
| 172 | + }; |
| 173 | + |
| 174 | + let vhost_user_net_reg = VhostUserMemoryRegionInfo { |
| 175 | + guest_phys_addr: region.start_addr().raw_value(), |
| 176 | + memory_size: region.len(), |
| 177 | + userspace_addr: region.as_ptr() as u64, |
| 178 | + mmap_offset, |
| 179 | + mmap_handle, |
| 180 | + }; |
| 181 | + regions.push(vhost_user_net_reg); |
| 182 | + } |
| 183 | + |
| 184 | + self.vu |
| 185 | + .set_mem_table(regions.as_slice()) |
| 186 | + .map_err(VhostUserError::VhostUserSetMemTable)?; |
| 187 | + |
| 188 | + Ok(()) |
| 189 | + } |
| 190 | + |
| 191 | + /// Set up vhost-user backend. This includes updating memory table, |
| 192 | + /// sending information about virtio rings and enabling them. |
| 193 | + pub fn setup_backend( |
| 194 | + &mut self, |
| 195 | + mem: &GuestMemoryMmap, |
| 196 | + queues: &[(usize, &Queue, &EventFd)], |
| 197 | + irq_trigger: &IrqTrigger, |
| 198 | + ) -> Result<(), VhostUserError> { |
| 199 | + // Provide the memory table to the backend. |
| 200 | + self.update_mem_table(mem)?; |
| 201 | + |
| 202 | + // Send set_vring_num here, since it could tell backends, like SPDK, |
| 203 | + // how many virt queues to be handled, which backend required to know |
| 204 | + // at early stage. |
| 205 | + for (queue_index, queue, _) in queues.iter() { |
| 206 | + self.vu |
| 207 | + .set_vring_num(*queue_index, queue.actual_size()) |
| 208 | + .map_err(VhostUserError::VhostUserSetVringNum)?; |
| 209 | + } |
| 210 | + |
| 211 | + for (queue_index, queue, queue_evt) in queues.iter() { |
| 212 | + let config_data = VringConfigData { |
| 213 | + queue_max_size: queue.get_max_size(), |
| 214 | + queue_size: queue.actual_size(), |
| 215 | + flags: 0u32, |
| 216 | + desc_table_addr: mem |
| 217 | + .get_host_address(queue.desc_table) |
| 218 | + .map_err(VhostUserError::DescriptorTableAddress)? |
| 219 | + as u64, |
| 220 | + used_ring_addr: mem |
| 221 | + .get_host_address(queue.used_ring) |
| 222 | + .map_err(VhostUserError::UsedAddress)? as u64, |
| 223 | + avail_ring_addr: mem |
| 224 | + .get_host_address(queue.avail_ring) |
| 225 | + .map_err(VhostUserError::AvailAddress)? as u64, |
| 226 | + log_addr: None, |
| 227 | + }; |
| 228 | + |
| 229 | + self.vu |
| 230 | + .set_vring_addr(*queue_index, &config_data) |
| 231 | + .map_err(VhostUserError::VhostUserSetVringAddr)?; |
| 232 | + self.vu |
| 233 | + .set_vring_base(*queue_index, queue.avail_idx(mem).0) |
| 234 | + .map_err(VhostUserError::VhostUserSetVringBase)?; |
| 235 | + |
| 236 | + // No matter the queue, we set irq_evt for signaling the guest that buffers were |
| 237 | + // consumed. |
| 238 | + self.vu |
| 239 | + .set_vring_call(*queue_index, &irq_trigger.irq_evt) |
| 240 | + .map_err(VhostUserError::VhostUserSetVringCall)?; |
| 241 | + |
| 242 | + self.vu |
| 243 | + .set_vring_kick(*queue_index, queue_evt) |
| 244 | + .map_err(VhostUserError::VhostUserSetVringKick)?; |
| 245 | + |
| 246 | + self.vu |
| 247 | + .set_vring_enable(*queue_index, true) |
| 248 | + .map_err(VhostUserError::VhostUserSetVringEnable)?; |
| 249 | + } |
| 250 | + |
| 251 | + Ok(()) |
| 252 | + } |
| 253 | +} |
0 commit comments