Skip to content

Commit 4487174

Browse files
committed
proxy: Reintroduce structures
Instead of a generic array, use a trait on different structs. This ensures even more type safety for callers. Signed-off-by: Colin Walters <[email protected]>
1 parent f2118a3 commit 4487174

File tree

1 file changed

+120
-96
lines changed

1 file changed

+120
-96
lines changed

src/imageproxy.rs

Lines changed: 120 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -341,30 +341,81 @@ pub struct ConvertedLayerInfo {
341341
pub media_type: oci_spec::image::MediaType,
342342
}
343343

344-
// Consumes an iterable and tries to convert it to a fixed-size array. Returns Ok([T; N]) if the
345-
// number of items in the iterable was correct, else an error describing the mismatch.
346-
fn fixed_from_iterable<T, const N: usize>(
347-
iterable: impl IntoIterator<IntoIter: FusedIterator, Item = T>,
348-
) -> Result<[T; N]> {
349-
let mut iter = iterable.into_iter();
350-
// We make use of the fact that [_; N].map() returns [_; N]. That makes this a bit more
351-
// awkward than it would otherwise be, but it's not too bad.
352-
let collected = [(); N].map(|_| iter.next());
353-
// Count the Some() in `collected` plus leftovers in the iter.
354-
let actual = collected.iter().flatten().count() + iter.count();
355-
if actual == N {
356-
// SAFETY: This is a fused iter, so all N items are in our array
357-
Ok(collected.map(Option::unwrap))
358-
} else {
359-
let type_name = std::any::type_name::<T>();
360-
let basename = type_name
361-
.rsplit_once("::")
362-
.map(|(_path, name)| name)
363-
.unwrap_or(type_name);
364-
365-
Err(Error::Other(
366-
format!("Expected {N} {basename} but got {actual}").into(),
367-
))
344+
/// A single fd; requires invoking FinishPipe
345+
#[derive(Debug)]
346+
struct FinishPipe {
347+
pipeid: PipeId,
348+
datafd: OwnedFd,
349+
}
350+
351+
/// There is a data FD and an error FD. The error FD will be JSON.
352+
#[derive(Debug)]
353+
struct DualFds {
354+
datafd: OwnedFd,
355+
errfd: OwnedFd,
356+
}
357+
358+
/// Helper trait for parsing the pipeid and/or file descriptors of a reply
359+
trait FromReplyFds: Send + 'static
360+
where
361+
Self: Sized,
362+
{
363+
fn from_reply(
364+
iterable: impl IntoIterator<IntoIter: FusedIterator, Item = OwnedFd>,
365+
pipeid: u32,
366+
) -> Result<Self>;
367+
}
368+
369+
/// No file descriptors or pipeid expected
370+
impl FromReplyFds for () {
371+
fn from_reply(fds: impl IntoIterator<Item = OwnedFd>, pipeid: u32) -> Result<Self> {
372+
if fds.into_iter().next().is_some() {
373+
return Err(Error::Other("expected no fds".into()));
374+
}
375+
if pipeid != 0 {
376+
return Err(Error::Other("unexpected pipeid".into()));
377+
}
378+
Ok(())
379+
}
380+
}
381+
382+
/// A FinishPipe instance
383+
impl FromReplyFds for FinishPipe {
384+
fn from_reply(fds: impl IntoIterator<Item = OwnedFd>, pipeid: u32) -> Result<Self> {
385+
let mut fds = fds.into_iter();
386+
let Some(first_fd) = fds.next() else {
387+
return Err(Error::Other("Expected fd for FinishPipe".into()));
388+
};
389+
if fds.next().is_some() {
390+
return Err(Error::Other("More than one fd for FinishPipe".into()));
391+
}
392+
let Some(pipeid) = PipeId::try_new(pipeid) else {
393+
return Err(Error::Other("Expected pipeid for FinishPipe".into()));
394+
};
395+
Ok(Self {
396+
pipeid,
397+
datafd: first_fd,
398+
})
399+
}
400+
}
401+
402+
/// A DualFds instance
403+
impl FromReplyFds for DualFds {
404+
fn from_reply(fds: impl IntoIterator<Item = OwnedFd>, pipeid: u32) -> Result<Self> {
405+
let mut fds = fds.into_iter();
406+
let Some(datafd) = fds.next() else {
407+
return Err(Error::Other("Expected data fd for DualFds".into()));
408+
};
409+
let Some(errfd) = fds.next() else {
410+
return Err(Error::Other("Expected err fd for DualFds".into()));
411+
};
412+
if fds.next().is_some() {
413+
return Err(Error::Other("More than two fds for DualFds".into()));
414+
}
415+
if pipeid != 0 {
416+
return Err(Error::Other("Unexpected pipeid with DualFds".into()));
417+
}
418+
Ok(Self { datafd, errfd })
368419
}
369420
}
370421

@@ -404,7 +455,7 @@ impl ImageProxy {
404455
};
405456

406457
// Verify semantic version
407-
let (protover, [], []): (String, _, _) = r.impl_request("Initialize", [(); 0]).await?;
458+
let (protover, _): (String, ()) = r.impl_request("Initialize", [(); 0]).await?;
408459
tracing::debug!("Remote protocol version: {protover}");
409460
let protover = semver::Version::parse(protover.as_str())?;
410461
// Previously we had a feature to opt-in to requiring newer versions using `if cfg!()`.
@@ -420,14 +471,10 @@ impl ImageProxy {
420471
Ok(r)
421472
}
422473

423-
async fn impl_request_raw<
424-
T: serde::de::DeserializeOwned + Send + 'static,
425-
const N: usize,
426-
const M: usize,
427-
>(
474+
async fn impl_request_raw<T: serde::de::DeserializeOwned + Send + 'static, F: FromReplyFds>(
428475
sockfd: Arc<Mutex<OwnedFd>>,
429476
req: Request,
430-
) -> Result<(T, [OwnedFd; N], [PipeId; M])> {
477+
) -> Result<(T, F)> {
431478
tracing::trace!("sending request {}", req.method.as_str());
432479
// TODO: Investigate https://crates.io/crates/uds for SOCK_SEQPACKET tokio
433480
let r = tokio::task::spawn_blocking(move || {
@@ -464,11 +511,8 @@ impl ImageProxy {
464511
error: reply.error.into(),
465512
});
466513
}
467-
Ok((
468-
serde_json::from_value(reply.value)?,
469-
fixed_from_iterable(fdret)?,
470-
fixed_from_iterable(PipeId::try_new(reply.pipeid))?,
471-
))
514+
let fds = FromReplyFds::from_reply(fdret, reply.pipeid)?;
515+
Ok((serde_json::from_value(reply.value)?, fds))
472516
})
473517
.await
474518
.map_err(|e| Error::Other(e.to_string().into()))??;
@@ -477,15 +521,11 @@ impl ImageProxy {
477521
}
478522

479523
#[instrument(skip(args))]
480-
async fn impl_request<
481-
T: serde::de::DeserializeOwned + Send + 'static,
482-
const N: usize,
483-
const M: usize,
484-
>(
524+
async fn impl_request<T: serde::de::DeserializeOwned + Send + 'static, F: FromReplyFds>(
485525
&self,
486526
method: &str,
487527
args: impl IntoIterator<Item = impl Into<serde_json::Value>>,
488-
) -> Result<(T, [OwnedFd; N], [PipeId; M])> {
528+
) -> Result<(T, F)> {
489529
let req = Self::impl_request_raw(Arc::clone(&self.sockfd), Request::new(method, args));
490530
let mut childwait = self.childwait.lock().await;
491531
tokio::select! {
@@ -501,21 +541,21 @@ impl ImageProxy {
501541
#[instrument]
502542
async fn finish_pipe(&self, pipeid: PipeId) -> Result<()> {
503543
tracing::debug!("closing pipe");
504-
let (r, [], []) = self.impl_request("FinishPipe", [pipeid.0.get()]).await?;
544+
let (r, ()) = self.impl_request("FinishPipe", [pipeid.0.get()]).await?;
505545
Ok(r)
506546
}
507547

508548
#[instrument]
509549
pub async fn open_image(&self, imgref: &str) -> Result<OpenedImage> {
510550
tracing::debug!("opening image");
511-
let (imgid, [], []) = self.impl_request("OpenImage", [imgref]).await?;
551+
let (imgid, ()) = self.impl_request("OpenImage", [imgref]).await?;
512552
Ok(OpenedImage(imgid))
513553
}
514554

515555
#[instrument]
516556
pub async fn open_image_optional(&self, imgref: &str) -> Result<Option<OpenedImage>> {
517557
tracing::debug!("opening image");
518-
let (imgid, [], []) = self.impl_request("OpenImageOptional", [imgref]).await?;
558+
let (imgid, ()) = self.impl_request("OpenImageOptional", [imgref]).await?;
519559
if imgid == 0 {
520560
Ok(None)
521561
} else {
@@ -526,16 +566,16 @@ impl ImageProxy {
526566
#[instrument]
527567
pub async fn close_image(&self, img: &OpenedImage) -> Result<()> {
528568
tracing::debug!("closing image");
529-
let (r, [], []) = self.impl_request("CloseImage", [img.0]).await?;
569+
let (r, ()) = self.impl_request("CloseImage", [img.0]).await?;
530570
Ok(r)
531571
}
532572

533-
async fn read_all_fd(&self, datafd: OwnedFd, pipeid: PipeId) -> Result<Vec<u8>> {
534-
let fd = tokio::fs::File::from_std(std::fs::File::from(datafd));
573+
async fn read_finish_pipe(&self, pipe: FinishPipe) -> Result<Vec<u8>> {
574+
let fd = tokio::fs::File::from_std(std::fs::File::from(pipe.datafd));
535575
let mut fd = tokio::io::BufReader::new(fd);
536576
let mut r = Vec::new();
537577
let reader = fd.read_to_end(&mut r);
538-
let (nbytes, finish) = tokio::join!(reader, self.finish_pipe(pipeid));
578+
let (nbytes, finish) = tokio::join!(reader, self.finish_pipe(pipe.pipeid));
539579
finish?;
540580
assert_eq!(nbytes?, r.len());
541581
Ok(r)
@@ -545,8 +585,8 @@ impl ImageProxy {
545585
/// The original digest of the unconverted manifest is also returned.
546586
/// For more information on OCI manifests, see <https://github.com/opencontainers/image-spec/blob/main/manifest.md>
547587
pub async fn fetch_manifest_raw_oci(&self, img: &OpenedImage) -> Result<(String, Vec<u8>)> {
548-
let (digest, [datafd], [pipeid]) = self.impl_request("GetManifest", [img.0]).await?;
549-
Ok((digest, self.read_all_fd(datafd, pipeid).await?))
588+
let (digest, pipefd) = self.impl_request("GetManifest", [img.0]).await?;
589+
Ok((digest, self.read_finish_pipe(pipefd).await?))
550590
}
551591

552592
/// Fetch the manifest.
@@ -563,8 +603,8 @@ impl ImageProxy {
563603
/// Fetch the config.
564604
/// For more information on OCI config, see <https://github.com/opencontainers/image-spec/blob/main/config.md>
565605
pub async fn fetch_config_raw(&self, img: &OpenedImage) -> Result<Vec<u8>> {
566-
let ((), [datafd], [pipeid]) = self.impl_request("GetFullConfig", [img.0]).await?;
567-
self.read_all_fd(datafd, pipeid).await
606+
let ((), pipe) = self.impl_request("GetFullConfig", [img.0]).await?;
607+
self.read_finish_pipe(pipe).await
568608
}
569609

570610
/// Fetch the config.
@@ -601,11 +641,11 @@ impl ImageProxy {
601641
tracing::debug!("fetching blob");
602642
let args: Vec<serde_json::Value> =
603643
vec![img.0.into(), digest.to_string().into(), size.into()];
604-
let (bloblen, [datafd], [pipeid]) = self.impl_request("GetBlob", args).await?;
644+
let (bloblen, pipe): (u64, FinishPipe) = self.impl_request("GetBlob", args).await?;
605645
let _: u64 = bloblen;
606-
let fd = tokio::fs::File::from_std(std::fs::File::from(datafd));
646+
let fd = tokio::fs::File::from_std(std::fs::File::from(pipe.datafd));
607647
let fd = tokio::io::BufReader::new(fd);
608-
let finish = Box::pin(self.finish_pipe(pipeid));
648+
let finish = Box::pin(self.finish_pipe(pipe.pipeid));
609649
Ok((fd, finish))
610650
}
611651

@@ -650,9 +690,9 @@ impl ImageProxy {
650690
)> {
651691
tracing::debug!("fetching blob");
652692
let args: Vec<serde_json::Value> = vec![img.0.into(), digest.to_string().into()];
653-
let (bloblen, [datafd, errfd], []) = self.impl_request("GetRawBlob", args).await?;
654-
let fd = tokio::fs::File::from_std(std::fs::File::from(datafd));
655-
let err = Self::read_blob_error(errfd).boxed();
693+
let (bloblen, fds): (u64, DualFds) = self.impl_request("GetRawBlob", args).await?;
694+
let fd = tokio::fs::File::from_std(std::fs::File::from(fds.datafd));
695+
let err = Self::read_blob_error(fds.errfd).boxed();
656696
Ok((bloblen, fd, err))
657697
}
658698

@@ -678,14 +718,14 @@ impl ImageProxy {
678718
) -> Result<Option<Vec<ConvertedLayerInfo>>> {
679719
tracing::debug!("Getting layer info");
680720
if layer_info_piped_proto_version().matches(&self.protover) {
681-
let ((), [datafd], [pipeid]) = self.impl_request("GetLayerInfoPiped", [img.0]).await?;
682-
let buf = self.read_all_fd(datafd, pipeid).await?;
721+
let ((), pipe) = self.impl_request("GetLayerInfoPiped", [img.0]).await?;
722+
let buf = self.read_finish_pipe(pipe).await?;
683723
return Ok(Some(serde_json::from_slice(&buf)?));
684724
}
685725
if !layer_info_proto_version().matches(&self.protover) {
686726
return Ok(None);
687727
}
688-
let (layers, [], []) = self.impl_request("GetLayerInfo", [img.0]).await?;
728+
let (layers, ()) = self.impl_request("GetLayerInfo", [img.0]).await?;
689729
Ok(Some(layers))
690730
}
691731

@@ -893,31 +933,15 @@ mod tests {
893933
memfd_create(c"test-fd", MemfdFlags::CLOEXEC).unwrap()
894934
}
895935

896-
fn fds_and_pipeid<const N: usize, const M: usize>(
897-
fds: impl IntoIterator<IntoIter: FusedIterator, Item = OwnedFd>,
898-
pipeid: u32,
899-
) -> Result<([OwnedFd; N], [PipeId; M])> {
900-
Ok((
901-
fixed_from_iterable(fds)?,
902-
fixed_from_iterable(PipeId::try_new(pipeid))?,
903-
))
904-
}
905-
906-
#[test]
907-
fn test_new_from_raw_values_no_fds_no_pipeid() {
908-
let ([], []) = fds_and_pipeid([], 0).unwrap();
909-
}
910-
911936
#[test]
912937
fn test_new_from_raw_values_finish_pipe() {
913938
let datafd = create_dummy_fd();
914939
// Keep a raw fd to compare later, as fds_and_pipeid consumes datafd
915940
let raw_datafd_val = datafd.as_raw_fd();
916941
let fds = vec![datafd];
917-
let pipeid = PipeId::try_new(1).unwrap();
918-
let ([res_datafd], [res_pipeid]) = fds_and_pipeid(fds, pipeid.0.get()).unwrap();
919-
assert_eq!(res_pipeid, pipeid);
920-
assert_eq!(res_datafd.as_raw_fd(), raw_datafd_val);
942+
let v = FinishPipe::from_reply(fds, 1).unwrap();
943+
assert_eq!(v.pipeid.0.get(), 1);
944+
assert_eq!(v.datafd.as_raw_fd(), raw_datafd_val);
921945
}
922946

923947
#[test]
@@ -927,18 +951,18 @@ mod tests {
927951
let raw_datafd_val = datafd.as_raw_fd();
928952
let raw_errfd_val = errfd.as_raw_fd();
929953
let fds = vec![datafd, errfd];
930-
let ([res_datafd, res_errfd], []) = fds_and_pipeid(fds, 0).unwrap();
931-
assert_eq!(res_datafd.as_raw_fd(), raw_datafd_val);
932-
assert_eq!(res_errfd.as_raw_fd(), raw_errfd_val);
954+
let v = DualFds::from_reply(fds, 0).unwrap();
955+
assert_eq!(v.datafd.as_raw_fd(), raw_datafd_val);
956+
assert_eq!(v.errfd.as_raw_fd(), raw_errfd_val);
933957
}
934958

935959
#[test]
936960
fn test_new_from_raw_values_error_too_many_fds() {
937961
let fds = vec![create_dummy_fd(), create_dummy_fd(), create_dummy_fd()];
938-
match fds_and_pipeid(fds, 0) {
939-
Ok(([datafd, errfd], [])) => unreachable!("{datafd:?} {errfd:?}"),
962+
match DualFds::from_reply(fds, 0) {
963+
Ok(v) => unreachable!("{v:?}"),
940964
Err(Error::Other(msg)) => {
941-
assert_eq!(msg.as_ref(), "Expected 2 OwnedFd but got 3")
965+
assert_eq!(msg.as_ref(), "More than two fds for DualFds")
942966
}
943967
Err(other) => unreachable!("{other}"),
944968
}
@@ -947,10 +971,10 @@ mod tests {
947971
#[test]
948972
fn test_new_from_raw_values_error_fd_with_zero_pipeid() {
949973
let fds = vec![create_dummy_fd()];
950-
match fds_and_pipeid(fds, 0) {
951-
Ok(([datafd], [pipeid])) => unreachable!("{datafd:?} {pipeid:?}"),
974+
match FinishPipe::from_reply(fds, 0) {
975+
Ok(v) => unreachable!("{v:?}"),
952976
Err(Error::Other(msg)) => {
953-
assert_eq!(msg.as_ref(), "Expected 1 PipeId but got 0")
977+
assert_eq!(msg.as_ref(), "Expected pipeid for FinishPipe")
954978
}
955979
Err(other) => unreachable!("{other}"),
956980
}
@@ -959,10 +983,10 @@ mod tests {
959983
#[test]
960984
fn test_new_from_raw_values_error_pipeid_with_both_fds() {
961985
let fds = vec![create_dummy_fd(), create_dummy_fd()];
962-
match fds_and_pipeid(fds, 1) {
963-
Ok(([datafd, errfd], [])) => unreachable!("{datafd:?} {errfd:?}"),
986+
match DualFds::from_reply(fds, 1) {
987+
Ok(v) => unreachable!("{v:?}"),
964988
Err(Error::Other(msg)) => {
965-
assert_eq!(msg.as_ref(), "Expected 0 PipeId but got 1")
989+
assert_eq!(msg.as_ref(), "Unexpected pipeid with DualFds")
966990
}
967991
Err(other) => unreachable!("{other}"),
968992
}
@@ -971,10 +995,10 @@ mod tests {
971995
#[test]
972996
fn test_new_from_raw_values_error_no_fd_with_pipeid() {
973997
let fds: Vec<OwnedFd> = vec![];
974-
match fds_and_pipeid(fds, 1) {
975-
Ok(([datafd], [pipeid])) => unreachable!("{datafd:?} {pipeid:?}"),
998+
match FinishPipe::from_reply(fds, 1) {
999+
Ok(v) => unreachable!("{v:?}"),
9761000
Err(Error::Other(msg)) => {
977-
assert_eq!(msg.as_ref(), "Expected 1 OwnedFd but got 0")
1001+
assert_eq!(msg.as_ref(), "Expected fd for FinishPipe")
9781002
}
9791003
Err(other) => unreachable!("{other}"),
9801004
}

0 commit comments

Comments
 (0)