|
| 1 | +use std::any::TypeId; |
| 2 | +use std::mem; |
| 3 | +use std::pin::{Pin, pin}; |
| 4 | +use std::task::{Context, Poll, ready}; |
| 5 | + |
| 6 | +use bytes::{Buf, Bytes}; |
| 7 | +use futures_util::stream::StreamExt; |
| 8 | +use http_body_util::BodyExt as _; |
| 9 | +use hyper::body::{Body as HttpBody, Frame, Incoming, SizeHint}; |
| 10 | +use mlua::{ |
| 11 | + Error, ExternalError, FromLua, Lua, Result as LuaResult, UserData, UserDataMethods, UserDataRegistry, |
| 12 | + Value, |
| 13 | +}; |
| 14 | + |
| 15 | +/// A Lua-accessible HTTP body |
| 16 | +/// |
| 17 | +/// This can wrap various body types, including raw bytes, Hyper incoming bodies, |
| 18 | +/// and Reqwest bodies (if the `reqwest` feature is enabled). |
| 19 | +pub struct LuaBody(Inner); |
| 20 | + |
| 21 | +enum Inner { |
| 22 | + Bytes(Bytes), |
| 23 | + Incoming { |
| 24 | + incoming: Incoming, |
| 25 | + // If Some, the maximum number of bytes allowed to be read from the body |
| 26 | + remaining: Option<usize>, |
| 27 | + }, |
| 28 | + #[cfg(feature = "reqwest")] |
| 29 | + Reqwest { |
| 30 | + body: reqwest::Body, |
| 31 | + // If Some, the maximum number of bytes allowed to be read from the body |
| 32 | + remaining: Option<usize>, |
| 33 | + }, |
| 34 | +} |
| 35 | + |
| 36 | +impl Default for LuaBody { |
| 37 | + #[inline] |
| 38 | + fn default() -> Self { |
| 39 | + LuaBody::new() |
| 40 | + } |
| 41 | +} |
| 42 | + |
| 43 | +impl LuaBody { |
| 44 | + pub const fn new() -> Self { |
| 45 | + LuaBody(Inner::Bytes(Bytes::new())) |
| 46 | + } |
| 47 | + |
| 48 | + async fn buffer(&mut self) -> Result<(), Error> { |
| 49 | + match self { |
| 50 | + LuaBody(Inner::Bytes(_)) => Ok(()), |
| 51 | + _ => { |
| 52 | + let collect = self.collect().await?; |
| 53 | + *self = LuaBody::from(collect.to_bytes()); |
| 54 | + Ok(()) |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + fn consume_if_unbuffered(&mut self) -> Self { |
| 60 | + match self { |
| 61 | + LuaBody(Inner::Bytes(bytes)) => LuaBody(Inner::Bytes(bytes.clone())), |
| 62 | + _ => mem::take(self), |
| 63 | + } |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +impl From<Bytes> for LuaBody { |
| 68 | + fn from(bytes: Bytes) -> Self { |
| 69 | + LuaBody(Inner::Bytes(bytes)) |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +impl From<Incoming> for LuaBody { |
| 74 | + fn from(incoming: Incoming) -> Self { |
| 75 | + LuaBody(Inner::Incoming { |
| 76 | + incoming, |
| 77 | + remaining: None, |
| 78 | + }) |
| 79 | + } |
| 80 | +} |
| 81 | + |
| 82 | +#[cfg(feature = "reqwest")] |
| 83 | +impl From<reqwest::Body> for LuaBody { |
| 84 | + fn from(body: reqwest::Body) -> Self { |
| 85 | + LuaBody(Inner::Reqwest { |
| 86 | + body, |
| 87 | + remaining: None, |
| 88 | + }) |
| 89 | + } |
| 90 | +} |
| 91 | + |
| 92 | +#[cfg(feature = "reqwest")] |
| 93 | +impl From<LuaBody> for reqwest::Body { |
| 94 | + fn from(body: LuaBody) -> Self { |
| 95 | + match body { |
| 96 | + LuaBody(Inner::Bytes(bytes)) => reqwest::Body::from(bytes), |
| 97 | + LuaBody(Inner::Incoming { incoming, .. }) => reqwest::Body::wrap(incoming), |
| 98 | + #[cfg(feature = "reqwest")] |
| 99 | + LuaBody(Inner::Reqwest { body, .. }) => body, |
| 100 | + } |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +impl FromLua for LuaBody { |
| 105 | + fn from_lua(value: Value, _: &Lua) -> LuaResult<Self> { |
| 106 | + match value { |
| 107 | + Value::String(s) => Ok(LuaBody::from(Bytes::copy_from_slice(&s.as_bytes()))), |
| 108 | + Value::UserData(ud) => match ud.type_id() { |
| 109 | + Some(id) if id == TypeId::of::<Bytes>() => Ok(LuaBody::from(ud.borrow::<Bytes>()?.clone())), |
| 110 | + Some(id) if id == TypeId::of::<LuaBody>() => ud.take::<LuaBody>(), |
| 111 | + _ => Err(mlua::Error::FromLuaConversionError { |
| 112 | + from: "UserData", |
| 113 | + to: "Body".to_string(), |
| 114 | + message: Some("expected Bytes or Body userdata".to_string()), |
| 115 | + }), |
| 116 | + }, |
| 117 | + _ => Err(Error::FromLuaConversionError { |
| 118 | + from: value.type_name(), |
| 119 | + to: "Body".to_string(), |
| 120 | + message: Some("expected String".to_string()), |
| 121 | + }), |
| 122 | + } |
| 123 | + } |
| 124 | +} |
| 125 | + |
| 126 | +impl HttpBody for LuaBody { |
| 127 | + type Data = Bytes; |
| 128 | + type Error = Error; |
| 129 | + |
| 130 | + fn poll_frame( |
| 131 | + self: Pin<&mut Self>, |
| 132 | + cx: &mut Context<'_>, |
| 133 | + ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { |
| 134 | + fn process_frame(frame: Frame<Bytes>, remaining: &mut Option<usize>) -> Result<Frame<Bytes>, Error> { |
| 135 | + if let (Some(data), Some(remaining)) = (frame.data_ref(), remaining.as_mut()) { |
| 136 | + if data.remaining() > *remaining { |
| 137 | + *remaining = 0; |
| 138 | + Err("body size limit exceeded".into_lua_err()) |
| 139 | + } else { |
| 140 | + *remaining -= data.remaining(); |
| 141 | + Ok(frame) |
| 142 | + } |
| 143 | + } else { |
| 144 | + Ok(frame) |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + let this = self.get_mut(); |
| 149 | + match &mut this.0 { |
| 150 | + Inner::Bytes(bytes) if bytes.is_empty() => Poll::Ready(None), |
| 151 | + Inner::Bytes(bytes) => { |
| 152 | + let chunk = mem::take(bytes); |
| 153 | + Poll::Ready(Some(Ok(Frame::data(chunk)))) |
| 154 | + } |
| 155 | + Inner::Incoming { incoming, remaining } => match ready!(pin!(incoming).poll_frame(cx)) { |
| 156 | + Some(Ok(frame)) => Poll::Ready(Some(process_frame(frame, remaining))), |
| 157 | + Some(Err(e)) => Poll::Ready(Some(Err(e.into_lua_err()))), |
| 158 | + None => Poll::Ready(None), |
| 159 | + }, |
| 160 | + #[cfg(feature = "reqwest")] |
| 161 | + Inner::Reqwest { body, remaining } => match ready!(pin!(body).poll_frame(cx)) { |
| 162 | + Some(Ok(frame)) => Poll::Ready(Some(process_frame(frame, remaining))), |
| 163 | + Some(Err(e)) => Poll::Ready(Some(Err(e.into_lua_err()))), |
| 164 | + None => Poll::Ready(None), |
| 165 | + }, |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + fn is_end_stream(&self) -> bool { |
| 170 | + match &self.0 { |
| 171 | + Inner::Bytes(bytes) => bytes.is_empty(), |
| 172 | + Inner::Incoming { incoming, .. } => incoming.is_end_stream(), |
| 173 | + #[cfg(feature = "reqwest")] |
| 174 | + Inner::Reqwest { body, .. } => body.is_end_stream(), |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + fn size_hint(&self) -> SizeHint { |
| 179 | + match &self.0 { |
| 180 | + Inner::Bytes(bytes) => SizeHint::with_exact(bytes.len() as u64), |
| 181 | + Inner::Incoming { incoming, .. } => incoming.size_hint(), |
| 182 | + #[cfg(feature = "reqwest")] |
| 183 | + Inner::Reqwest { body, .. } => body.size_hint(), |
| 184 | + } |
| 185 | + } |
| 186 | +} |
| 187 | + |
| 188 | +impl UserData for LuaBody { |
| 189 | + fn register(registry: &mut UserDataRegistry<Self>) { |
| 190 | + // Get the (upper) size hint for the body |
| 191 | + registry.add_method("size_hint", |_, this, ()| { |
| 192 | + let hint = this.size_hint(); |
| 193 | + match hint.upper() { |
| 194 | + Some(upper) => Ok(Some(upper)), |
| 195 | + None => Ok(None), |
| 196 | + } |
| 197 | + }); |
| 198 | + |
| 199 | + // Set a size limit for reading the body from an incoming stream |
| 200 | + registry.add_method_mut("set_size_limit", |_, this, limit| match &mut this.0 { |
| 201 | + Inner::Bytes(_) => Ok(()), |
| 202 | + Inner::Incoming { remaining, .. } => { |
| 203 | + *remaining = Some(limit); |
| 204 | + Ok(()) |
| 205 | + } |
| 206 | + #[cfg(feature = "reqwest")] |
| 207 | + Inner::Reqwest { remaining, .. } => { |
| 208 | + *remaining = Some(limit); |
| 209 | + Ok(()) |
| 210 | + } |
| 211 | + }); |
| 212 | + |
| 213 | + // Buffer the body fully into memory |
| 214 | + registry.add_async_method_mut("buffer", |_, mut this, ()| async move { |
| 215 | + lua_try!(this.buffer().await); |
| 216 | + Ok(Ok(())) |
| 217 | + }); |
| 218 | + |
| 219 | + // Discard the body without reading it |
| 220 | + registry.add_method_mut("discard", |_, this, ()| { |
| 221 | + *this = LuaBody(Inner::Bytes(Bytes::new())); |
| 222 | + Ok(()) |
| 223 | + }); |
| 224 | + |
| 225 | + // Read the full body as bytes |
| 226 | + // |
| 227 | + // Consumes the body if it is not buffered |
| 228 | + registry.add_async_method_mut("read", |lua, mut this, ()| async move { |
| 229 | + let body = this.consume_if_unbuffered(); |
| 230 | + let body = lua_try!(body.collect().await); |
| 231 | + let bytes = body.to_bytes(); |
| 232 | + Ok(Ok(lua.create_any_userdata(bytes)?)) |
| 233 | + }); |
| 234 | + |
| 235 | + // Get an async reader for the body |
| 236 | + // |
| 237 | + // Consumes the body if it is not buffered, returns a function that can be |
| 238 | + // called to get the next chunk of data |
| 239 | + registry.add_method_mut("reader", |lua, this, ()| { |
| 240 | + use std::cell::RefCell; |
| 241 | + use std::rc::Rc; |
| 242 | + |
| 243 | + let body_stream = this.consume_if_unbuffered().into_data_stream(); |
| 244 | + let body_stream = Rc::new(RefCell::new(body_stream)); |
| 245 | + lua.create_async_function(move |lua, ()| { |
| 246 | + let body_stream = body_stream.clone(); |
| 247 | + #[allow(clippy::await_holding_refcell_ref)] |
| 248 | + async move { |
| 249 | + let mut body_stream = lua_try!(body_stream.try_borrow_mut()); |
| 250 | + match body_stream.next().await { |
| 251 | + Some(Ok(data)) => { |
| 252 | + let data = lua.create_any_userdata(data)?; |
| 253 | + Ok(Ok(Value::UserData(data))) |
| 254 | + } |
| 255 | + Some(Err(e)) => Ok(Err(e.to_string())), |
| 256 | + None => Ok(Ok(Value::Nil)), |
| 257 | + } |
| 258 | + } |
| 259 | + }) |
| 260 | + }); |
| 261 | + |
| 262 | + // Read the full body as text |
| 263 | + // |
| 264 | + // Consumes the body if it is not buffered |
| 265 | + registry.add_async_method_mut("text", |lua, mut this, ()| async move { |
| 266 | + let body = this.consume_if_unbuffered(); |
| 267 | + let body = lua_try!(body.collect().await); |
| 268 | + let text = lua.create_string(body.to_bytes())?; |
| 269 | + Ok(Ok(text)) |
| 270 | + }); |
| 271 | + |
| 272 | + // Read the full body as JSON |
| 273 | + // |
| 274 | + // Consumes the body if it is not buffered |
| 275 | + #[cfg(feature = "json")] |
| 276 | + registry.add_async_method_mut("json", |_, mut this, ()| async move { |
| 277 | + let body = this.consume_if_unbuffered(); |
| 278 | + let body = lua_try!(body.collect().await); |
| 279 | + let body_reader = body.aggregate().reader(); |
| 280 | + let json = lua_try!(serde_json::from_reader::<_, serde_json::Value>(body_reader)); |
| 281 | + Ok(Ok(crate::json::JsonObject::from(json))) |
| 282 | + }); |
| 283 | + } |
| 284 | +} |
0 commit comments