Skip to content

Commit 21fe5dc

Browse files
committed
dma.flush
1 parent c40f400 commit 21fe5dc

File tree

2 files changed

+200
-37
lines changed

2 files changed

+200
-37
lines changed

src/common/base/src/base/dma.rs

Lines changed: 195 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use std::fmt;
2020
use std::io;
2121
use std::io::IoSlice;
2222
use std::io::SeekFrom;
23+
use std::io::Write;
2324
use std::ops::Range;
2425
use std::os::fd::AsFd;
2526
use std::os::fd::BorrowedFd;
@@ -199,7 +200,7 @@ pub struct DmaFile<F> {
199200
fd: F,
200201
alignment: Alignment,
201202
buf: Option<DmaBuffer>,
202-
length: usize,
203+
written: usize,
203204
}
204205

205206
impl<F: AsFd> DmaFile<F> {
@@ -232,19 +233,37 @@ impl<F: AsFd> DmaFile<F> {
232233
}
233234

234235
fn write_direct(&mut self) -> io::Result<usize> {
235-
let buf = self.buffer();
236-
let buf_size = buf.len();
237-
match rustix::io::write(&self.fd, buf) {
238-
Ok(n) => {
239-
self.length += n;
240-
if n != buf_size {
241-
return Err(io::Error::other("short write"));
236+
let buf = self.buf.as_ref().unwrap().as_slice();
237+
let mut written = 0;
238+
239+
while written < buf.len() {
240+
match rustix::io::write(&self.fd, &buf[written..]) {
241+
Ok(0) => {
242+
return Err(io::Error::new(
243+
io::ErrorKind::WriteZero,
244+
"write returned zero bytes",
245+
));
246+
}
247+
Ok(n) => {
248+
written += n;
249+
}
250+
Err(err) => {
251+
if err.kind() == io::ErrorKind::Interrupted {
252+
continue;
253+
}
254+
return Err(err.into());
242255
}
243-
self.mut_buffer().clear();
244-
Ok(n)
245256
}
246-
Err(e) => Err(e.into()),
247257
}
258+
self.inc_written(written);
259+
self.mut_buffer().clear();
260+
Ok(written)
261+
}
262+
263+
fn inc_written(&mut self, n: usize) {
264+
debug_assert!(n >= self.alignment.as_usize());
265+
debug_assert_eq!(n, self.alignment.align_down(n));
266+
self.written = self.align_down(self.written) + n;
248267
}
249268

250269
fn read_direct(&mut self, n: usize) -> io::Result<usize> {
@@ -273,7 +292,7 @@ impl<F: AsFd> DmaFile<F> {
273292
}
274293

275294
pub fn length(&self) -> usize {
276-
self.length
295+
self.written
277296
}
278297
}
279298

@@ -344,7 +363,7 @@ impl AsyncDmaFile {
344363
fd: file,
345364
alignment,
346365
buf: None,
347-
length: 0,
366+
written: 0,
348367
})
349368
}
350369

@@ -380,7 +399,7 @@ impl AsyncDmaFile {
380399
fd: unsafe { BorrowedFd::borrow_raw(fd) },
381400
alignment,
382401
buf: Some(buf),
383-
length: 0,
402+
written: 0,
384403
};
385404
file.read_direct(remain).map(|n| (file.buf.unwrap(), n))
386405
})
@@ -411,12 +430,12 @@ impl SyncDmaFile {
411430

412431
fn create_fd(path: impl rustix::path::Arg, dio: bool) -> io::Result<OwnedFd> {
413432
let flags = if cfg!(target_os = "linux") && dio {
414-
OFlags::EXCL | OFlags::CREATE | OFlags::TRUNC | OFlags::DIRECT
433+
OFlags::EXCL | OFlags::CREATE | OFlags::TRUNC | OFlags::RDWR | OFlags::DIRECT
415434
} else {
416-
OFlags::EXCL | OFlags::CREATE | OFlags::TRUNC
435+
OFlags::EXCL | OFlags::CREATE | OFlags::TRUNC | OFlags::RDWR
417436
};
418437

419-
rustix::fs::open(path, flags, rustix::fs::Mode::empty()).map_err(|e| e.into())
438+
rustix::fs::open(path, flags, rustix::fs::Mode::from_raw_mode(0o666)).map_err(|e| e.into())
420439
}
421440

422441
fn open_dma(fd: OwnedFd) -> io::Result<DmaFile<OwnedFd>> {
@@ -427,7 +446,7 @@ impl SyncDmaFile {
427446
fd,
428447
alignment,
429448
buf: None,
430-
length: 0,
449+
written: 0,
431450
})
432451
}
433452

@@ -485,7 +504,7 @@ impl DmaWriteBuf {
485504
fd: AsyncDmaFile::create_fd(path, dio).await?,
486505
alignment: self.allocator.0,
487506
buf: None,
488-
length: 0,
507+
written: 0,
489508
};
490509

491510
let file_length = self.size();
@@ -572,8 +591,8 @@ impl DmaWriteBuf {
572591

573592
let len = data.len() * self.chunk;
574593

575-
let bufs = data.iter().map(|buf| IoSlice::new(buf)).collect::<Vec<_>>();
576-
let written = rustix::io::writev(&file.fd, &bufs)?;
594+
let mut io_slices: Vec<_> = data.iter().map(|buf| IoSlice::new(buf)).collect();
595+
let written = writev_all(&file.fd, &mut io_slices)?;
577596

578597
let last = self.data.pop();
579598
self.data.clear();
@@ -584,7 +603,7 @@ impl DmaWriteBuf {
584603
_ => (),
585604
}
586605

587-
file.length += written;
606+
file.inc_written(written);
588607

589608
if written != len {
590609
Err(io::Error::other("short write"))
@@ -618,26 +637,108 @@ impl DmaWriteBuf {
618637
None => unreachable!(),
619638
};
620639
let len = self.data.len() * self.chunk - diff;
621-
let bufs = self
622-
.data
623-
.iter()
624-
.map(|buf| IoSlice::new(buf))
625-
.collect::<Vec<_>>();
626640

627-
let written = rustix::io::writev(&file.fd, &bufs)?;
641+
let mut io_slices: Vec<_> = self.data.iter().map(|buf| IoSlice::new(buf)).collect();
642+
let written = writev_all(&file.fd, &mut io_slices)?;
628643
if written != len {
629644
return Err(io::Error::other("short write"));
630645
}
631646

632647
if to_truncate == 0 {
633-
file.length += written;
648+
file.inc_written(written);
634649
return Ok(written);
635650
}
636651

637-
file.length -= to_truncate;
638-
file.truncate(file.length)?;
652+
file.written -= to_truncate;
653+
file.truncate(file.written)?;
639654
Ok(written - to_truncate)
640655
}
656+
657+
pub fn flush(&mut self, file: &mut SyncDmaFile) -> io::Result<()> {
658+
debug_assert_eq!(self.allocator.0, file.alignment);
659+
660+
if self.data.is_empty() {
661+
return Ok(());
662+
}
663+
664+
let last = self
665+
.data
666+
.pop_if(|last| file.align_up(last.len()) > last.len());
667+
668+
let last = if let Some(mut last) = last {
669+
if self.data.is_empty() {
670+
use std::cmp::Ordering::*;
671+
match (file.written - file.align_down(file.written)).cmp(&last.len()) {
672+
Equal => return Ok(()),
673+
Greater => unreachable!(),
674+
Less => {}
675+
}
676+
}
677+
let len = last.len();
678+
let align_up = file.align_up(len);
679+
let pad = align_up - len;
680+
debug_assert!(pad != 0);
681+
unsafe { last.set_len(align_up) };
682+
Some((last, len, pad))
683+
} else {
684+
None
685+
};
686+
687+
let mut slices: Vec<_> = self
688+
.data
689+
.iter()
690+
.map(|buf| IoSlice::new(buf))
691+
.chain(last.as_ref().map(|last| IoSlice::new(&last.0)))
692+
.collect();
693+
let written = writev_all(&file.fd, &mut slices[..])?;
694+
self.data.clear();
695+
696+
file.inc_written(written);
697+
698+
if let Some((last, len, pad)) = last.as_ref() {
699+
let len = *len;
700+
let pad = *pad;
701+
file.written -= pad;
702+
703+
file.truncate(file.written)?;
704+
let last_align = file.align_down(file.written);
705+
rustix::fs::seek(&file.fd, rustix::fs::SeekFrom::Start(last_align as _))
706+
.map_err(io::Error::from)?;
707+
708+
debug_assert_eq!(pad, file.align_up(file.written) - file.written);
709+
710+
self.write(&last[file.align_down(len)..(file.align_up(len) - pad)])?;
711+
}
712+
713+
Ok(())
714+
}
715+
}
716+
717+
fn writev_all(fd: impl AsFd, mut slices: &mut [IoSlice<'_>]) -> io::Result<usize> {
718+
let mut written = 0;
719+
720+
while !slices.is_empty() {
721+
let n = match rustix::io::writev(fd.as_fd(), slices) {
722+
Ok(0) => {
723+
return Err(io::Error::new(
724+
io::ErrorKind::WriteZero,
725+
"writev returned zero bytes",
726+
));
727+
}
728+
Ok(n) => n,
729+
Err(err) => {
730+
if err.kind() == io::ErrorKind::Interrupted {
731+
continue;
732+
}
733+
return Err(err.into());
734+
}
735+
};
736+
737+
written += n;
738+
IoSlice::advance_slices(&mut slices, n);
739+
}
740+
741+
Ok(written)
641742
}
642743

643744
impl io::Write for DmaWriteBuf {
@@ -768,6 +869,7 @@ pub async fn dma_read_file_range(
768869

769870
#[cfg(test)]
770871
mod tests {
872+
use std::io::Read;
771873
use std::io::Write;
772874

773875
use super::*;
@@ -928,4 +1030,66 @@ mod tests {
9281030
let buf = got.to_vec();
9291031
println!("{:?} {}", buf.as_ptr(), buf.capacity());
9301032
}
1033+
1034+
#[test]
1035+
fn test_write() -> io::Result<()> {
1036+
let filename = "test_file";
1037+
let _ = std::fs::remove_file(filename);
1038+
let mut file = SyncDmaFile::create(filename, true)?;
1039+
1040+
let mut buf = DmaWriteBuf::new(file.alignment, file.alignment.as_usize() * 2);
1041+
1042+
{
1043+
buf.write(b"1")?;
1044+
buf.flush(&mut file)?;
1045+
1046+
assert_eq!(file.written, 1);
1047+
1048+
let mut got = Vec::new();
1049+
let mut read = std::fs::File::open(filename)?;
1050+
let n = read.read_to_end(&mut got)?;
1051+
assert_eq!(n, 1);
1052+
1053+
assert_eq!(b"1".as_slice(), got.as_slice());
1054+
}
1055+
1056+
{
1057+
buf.write(b"2")?;
1058+
buf.write(b"3")?;
1059+
buf.flush(&mut file)?;
1060+
1061+
assert_eq!(file.written, 3);
1062+
1063+
let mut got = Vec::new();
1064+
let mut read = std::fs::File::open(filename)?;
1065+
let n = read.read_to_end(&mut got)?;
1066+
assert_eq!(n, 3);
1067+
1068+
assert_eq!(b"123".as_slice(), got.as_slice());
1069+
}
1070+
1071+
{
1072+
let data: Vec<_> = b"123"
1073+
.iter()
1074+
.copied()
1075+
.cycle()
1076+
.take(file.alignment.as_usize() * 3)
1077+
.collect();
1078+
1079+
buf.write(&data)?;
1080+
buf.flush(&mut file)?;
1081+
1082+
assert_eq!(file.written, 3 + data.len());
1083+
1084+
let mut got = Vec::new();
1085+
let mut read = std::fs::File::open(filename)?;
1086+
let n = read.read_to_end(&mut got)?;
1087+
assert_eq!(n, 3 + data.len());
1088+
1089+
let want: Vec<_> = [&b"123"[..], &data].concat();
1090+
assert_eq!(want.as_slice(), got.as_slice());
1091+
}
1092+
1093+
Ok(())
1094+
}
9311095
}

src/query/service/src/spillers/union_file.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,10 @@ impl UnionFileWriter {
279279
) => {
280280
let dma = local.buf.as_mut().unwrap();
281281

282-
let file = local.file.take().unwrap();
283-
let file_size = file.length() + dma.size();
284-
dma.flush_and_close(file)?;
282+
let mut file = local.file.take().unwrap();
283+
dma.flush(&mut file)?;
284+
let file_size = file.length();
285+
drop(file);
285286

286287
local.path.set_size(file_size).unwrap();
287288

@@ -368,9 +369,7 @@ impl io::Write for UnionFileWriter {
368369
..
369370
}) = &mut self.local
370371
{
371-
// warning: not completely flushed, data may be lost
372-
dma.flush_full_buffer(file)?;
373-
return Ok(());
372+
return dma.flush(file);
374373
}
375374

376375
self.remote_writer.as_mut().unwrap().flush()

0 commit comments

Comments
 (0)