|
6 | 6 | // found in the THIRD-PARTY file. |
7 | 7 |
|
8 | 8 | use std::fs::File; |
9 | | -use std::io::{Read, Seek, SeekFrom}; |
| 9 | +use std::io::{Read, Seek, SeekFrom, Write}; |
10 | 10 | use std::sync::Arc; |
11 | 11 |
|
12 | 12 | use serde::{Deserialize, Serialize}; |
@@ -56,53 +56,132 @@ pub enum MemoryError { |
56 | 56 | /// Newtype that implements [`ReadVolatile`] and [`WriteVolatile`] if `T` implements `Read` or |
57 | 57 | /// `Write` respectively, by reading/writing using a bounce buffer, and memcpy-ing into the |
58 | 58 | /// [`VolatileSlice`]. |
| 59 | +/// |
| 60 | +/// Bounce buffers are allocated on the heap, as on-stack bounce buffers could cause stack |
| 61 | +/// overflows. If `N == 0` then bounce buffers will be allocated on demand. |
59 | 62 | #[derive(Debug)] |
60 | | -pub struct MaybeBounce<T>(pub T, pub bool); |
| 63 | +pub struct MaybeBounce<T, const N: usize = 0> { |
| 64 | + pub(crate) target: T, |
| 65 | + persistent_buffer: Option<Box<[u8; N]>>, |
| 66 | +} |
| 67 | + |
| 68 | +impl<T> MaybeBounce<T, 0> { |
| 69 | + /// Creates a new `MaybeBounce` that always allocates a bounce |
| 70 | + /// buffer on-demand |
| 71 | + pub fn new(target: T, should_bounce: bool) -> Self { |
| 72 | + MaybeBounce::new_persistent(target, should_bounce) |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +impl<T, const N: usize> MaybeBounce<T, N> { |
| 77 | + /// Creates a new `MaybeBounce` that uses a persistent, fixed size bounce buffer |
| 78 | + /// of size `N`. If a read/write request exceeds the size of this bounce buffer, it |
| 79 | + /// is split into multiple, `<= N`-size read/writes. |
| 80 | + pub fn new_persistent(target: T, should_bounce: bool) -> Self { |
| 81 | + let mut bounce = MaybeBounce { |
| 82 | + target, |
| 83 | + persistent_buffer: None, |
| 84 | + }; |
| 85 | + |
| 86 | + if should_bounce { |
| 87 | + bounce.activate() |
| 88 | + } |
| 89 | + |
| 90 | + bounce |
| 91 | + } |
| 92 | + |
| 93 | + /// Activates this [`MaybeBounce`] to start doing reads/writes via a bounce buffer, |
| 94 | + /// which is allocated on the heap by this function (e.g. if `activate()` is never called, |
| 95 | + /// no bounce buffer is ever allocated). |
| 96 | + pub fn activate(&mut self) { |
| 97 | + self.persistent_buffer = Some(vec![0u8; N].into_boxed_slice().try_into().unwrap()) |
| 98 | + } |
| 99 | +} |
61 | 100 |
|
62 | 101 | // FIXME: replace AsFd with ReadVolatile once &File: ReadVolatile in vm-memory. |
63 | | -impl<T: ReadVolatile> ReadVolatile for MaybeBounce<T> { |
| 102 | +impl<T: ReadVolatile, const N: usize> ReadVolatile for MaybeBounce<T, N> { |
64 | 103 | fn read_volatile<B: BitmapSlice>( |
65 | 104 | &mut self, |
66 | 105 | buf: &mut VolatileSlice<B>, |
67 | 106 | ) -> Result<usize, VolatileMemoryError> { |
68 | | - if self.1 { |
69 | | - let mut bbuf = vec![0; buf.len()]; |
70 | | - let n = self |
71 | | - .0 |
72 | | - .read_volatile(&mut VolatileSlice::from(bbuf.as_mut_slice()))?; |
73 | | - buf.copy_from(&bbuf[..n]); |
74 | | - Ok(n) |
| 107 | + if let Some(ref mut persistent) = self.persistent_buffer { |
| 108 | + let mut bbuf = (N == 0).then(|| vec![0u8; buf.len()]); |
| 109 | + let bbuf = bbuf.as_deref_mut().unwrap_or(persistent.as_mut_slice()); |
| 110 | + |
| 111 | + let mut buf = buf.offset(0)?; |
| 112 | + let mut total = 0; |
| 113 | + while !buf.is_empty() { |
| 114 | + let how_much = buf.len().min(bbuf.len()); |
| 115 | + let n = self |
| 116 | + .target |
| 117 | + .read_volatile(&mut VolatileSlice::from(&mut bbuf[..how_much]))?; |
| 118 | + buf.copy_from(&bbuf[..n]); |
| 119 | + |
| 120 | + buf = buf.offset(n)?; |
| 121 | + total += n; |
| 122 | + |
| 123 | + if n < how_much { |
| 124 | + break; |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + Ok(total) |
75 | 129 | } else { |
76 | | - self.0.read_volatile(buf) |
| 130 | + self.target.read_volatile(buf) |
77 | 131 | } |
78 | 132 | } |
79 | 133 | } |
80 | 134 |
|
81 | | -impl<T: WriteVolatile> WriteVolatile for MaybeBounce<T> { |
| 135 | +impl<T: WriteVolatile, const N: usize> WriteVolatile for MaybeBounce<T, N> { |
82 | 136 | fn write_volatile<B: BitmapSlice>( |
83 | 137 | &mut self, |
84 | 138 | buf: &VolatileSlice<B>, |
85 | 139 | ) -> Result<usize, VolatileMemoryError> { |
86 | | - if self.1 { |
87 | | - let mut bbuf = vec![0; buf.len()]; |
88 | | - buf.copy_to(bbuf.as_mut_slice()); |
89 | | - self.0 |
90 | | - .write_volatile(&VolatileSlice::from(bbuf.as_mut_slice())) |
| 140 | + if let Some(ref mut persistent) = self.persistent_buffer { |
| 141 | + let mut bbuf = (N == 0).then(|| vec![0u8; buf.len()]); |
| 142 | + let bbuf = bbuf.as_deref_mut().unwrap_or(persistent.as_mut_slice()); |
| 143 | + |
| 144 | + let mut buf = buf.offset(0)?; |
| 145 | + let mut total = 0; |
| 146 | + while !buf.is_empty() { |
| 147 | + let how_much = buf.copy_to(bbuf); |
| 148 | + let n = self |
| 149 | + .target |
| 150 | + .write_volatile(&VolatileSlice::from(&mut bbuf[..how_much]))?; |
| 151 | + buf = buf.offset(n)?; |
| 152 | + total += n; |
| 153 | + |
| 154 | + if n < how_much { |
| 155 | + break; |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + Ok(total) |
91 | 160 | } else { |
92 | | - self.0.write_volatile(buf) |
| 161 | + self.target.write_volatile(buf) |
93 | 162 | } |
94 | 163 | } |
95 | 164 | } |
96 | 165 |
|
97 | | -impl<R: Read> Read for MaybeBounce<R> { |
| 166 | +impl<R: Read, const N: usize> Read for MaybeBounce<R, N> { |
98 | 167 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { |
99 | | - self.0.read(buf) |
| 168 | + self.target.read(buf) |
| 169 | + } |
| 170 | +} |
| 171 | + |
| 172 | +impl<W: Write, const N: usize> Write for MaybeBounce<W, N> { |
| 173 | + fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { |
| 174 | + self.target.write(buf) |
| 175 | + } |
| 176 | + |
| 177 | + fn flush(&mut self) -> std::io::Result<()> { |
| 178 | + self.target.flush() |
100 | 179 | } |
101 | 180 | } |
102 | 181 |
|
103 | | -impl<S: Seek> Seek for MaybeBounce<S> { |
| 182 | +impl<S: Seek, const N: usize> Seek for MaybeBounce<S, N> { |
104 | 183 | fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> { |
105 | | - self.0.seek(pos) |
| 184 | + self.target.seek(pos) |
106 | 185 | } |
107 | 186 | } |
108 | 187 |
|
@@ -783,30 +862,45 @@ mod tests { |
783 | 862 | fn test_bounce() { |
784 | 863 | let file_direct = TempFile::new().unwrap(); |
785 | 864 | let file_bounced = TempFile::new().unwrap(); |
| 865 | + let file_persistent_bounced = TempFile::new().unwrap(); |
786 | 866 |
|
787 | 867 | let mut data = (0..=255).collect::<Vec<_>>(); |
788 | 868 |
|
789 | | - MaybeBounce(file_direct.as_file().as_fd(), false) |
| 869 | + MaybeBounce::new(file_direct.as_file().as_fd(), false) |
790 | 870 | .write_all_volatile(&VolatileSlice::from(data.as_mut_slice())) |
791 | 871 | .unwrap(); |
792 | | - MaybeBounce(file_bounced.as_file().as_fd(), true) |
| 872 | + MaybeBounce::new(file_bounced.as_file().as_fd(), true) |
| 873 | + .write_all_volatile(&VolatileSlice::from(data.as_mut_slice())) |
| 874 | + .unwrap(); |
| 875 | + MaybeBounce::<_, 7>::new_persistent(file_persistent_bounced.as_file().as_fd(), true) |
793 | 876 | .write_all_volatile(&VolatileSlice::from(data.as_mut_slice())) |
794 | 877 | .unwrap(); |
795 | 878 |
|
796 | 879 | let mut data_direct = vec![0u8; 256]; |
797 | 880 | let mut data_bounced = vec![0u8; 256]; |
| 881 | + let mut data_persistent_bounced = vec![0u8; 256]; |
798 | 882 |
|
799 | 883 | file_direct.as_file().seek(SeekFrom::Start(0)).unwrap(); |
800 | 884 | file_bounced.as_file().seek(SeekFrom::Start(0)).unwrap(); |
| 885 | + file_persistent_bounced |
| 886 | + .as_file() |
| 887 | + .seek(SeekFrom::Start(0)) |
| 888 | + .unwrap(); |
801 | 889 |
|
802 | | - MaybeBounce(file_direct.as_file().as_fd(), false) |
| 890 | + MaybeBounce::new(file_direct.as_file().as_fd(), false) |
803 | 891 | .read_exact_volatile(&mut VolatileSlice::from(data_direct.as_mut_slice())) |
804 | 892 | .unwrap(); |
805 | | - MaybeBounce(file_bounced.as_file().as_fd(), true) |
| 893 | + MaybeBounce::new(file_bounced.as_file().as_fd(), true) |
806 | 894 | .read_exact_volatile(&mut VolatileSlice::from(data_bounced.as_mut_slice())) |
807 | 895 | .unwrap(); |
| 896 | + MaybeBounce::<_, 7>::new_persistent(file_persistent_bounced.as_file().as_fd(), true) |
| 897 | + .read_exact_volatile(&mut VolatileSlice::from( |
| 898 | + data_persistent_bounced.as_mut_slice(), |
| 899 | + )) |
| 900 | + .unwrap(); |
808 | 901 |
|
809 | 902 | assert_eq!(data_direct, data_bounced); |
810 | 903 | assert_eq!(data_direct, data); |
| 904 | + assert_eq!(data_persistent_bounced, data); |
811 | 905 | } |
812 | 906 | } |
0 commit comments