@@ -20,6 +20,7 @@ use std::fmt;
20
20
use std:: io;
21
21
use std:: io:: IoSlice ;
22
22
use std:: io:: SeekFrom ;
23
+ use std:: io:: Write ;
23
24
use std:: ops:: Range ;
24
25
use std:: os:: fd:: AsFd ;
25
26
use std:: os:: fd:: BorrowedFd ;
@@ -199,7 +200,7 @@ pub struct DmaFile<F> {
199
200
fd : F ,
200
201
alignment : Alignment ,
201
202
buf : Option < DmaBuffer > ,
202
- length : usize ,
203
+ written : usize ,
203
204
}
204
205
205
206
impl < F : AsFd > DmaFile < F > {
@@ -232,19 +233,37 @@ impl<F: AsFd> DmaFile<F> {
232
233
}
233
234
234
235
fn write_direct ( & mut self ) -> io:: Result < usize > {
235
- let buf = self . buffer ( ) ;
236
- let buf_size = buf. len ( ) ;
237
- match rustix:: io:: write ( & self . fd , buf) {
238
- Ok ( n) => {
239
- self . length += n;
240
- if n != buf_size {
241
- return Err ( io:: Error :: other ( "short write" ) ) ;
236
+ let buf = self . buf . as_ref ( ) . unwrap ( ) . as_slice ( ) ;
237
+ let mut written = 0 ;
238
+
239
+ while written < buf. len ( ) {
240
+ match rustix:: io:: write ( & self . fd , & buf[ written..] ) {
241
+ Ok ( 0 ) => {
242
+ return Err ( io:: Error :: new (
243
+ io:: ErrorKind :: WriteZero ,
244
+ "write returned zero bytes" ,
245
+ ) ) ;
246
+ }
247
+ Ok ( n) => {
248
+ written += n;
249
+ }
250
+ Err ( err) => {
251
+ if err. kind ( ) == io:: ErrorKind :: Interrupted {
252
+ continue ;
253
+ }
254
+ return Err ( err. into ( ) ) ;
242
255
}
243
- self . mut_buffer ( ) . clear ( ) ;
244
- Ok ( n)
245
256
}
246
- Err ( e) => Err ( e. into ( ) ) ,
247
257
}
258
+ self . inc_written ( written) ;
259
+ self . mut_buffer ( ) . clear ( ) ;
260
+ Ok ( written)
261
+ }
262
+
263
+ fn inc_written ( & mut self , n : usize ) {
264
+ debug_assert ! ( n >= self . alignment. as_usize( ) ) ;
265
+ debug_assert_eq ! ( n, self . alignment. align_down( n) ) ;
266
+ self . written = self . align_down ( self . written ) + n;
248
267
}
249
268
250
269
fn read_direct ( & mut self , n : usize ) -> io:: Result < usize > {
@@ -273,7 +292,7 @@ impl<F: AsFd> DmaFile<F> {
273
292
}
274
293
275
294
pub fn length ( & self ) -> usize {
276
- self . length
295
+ self . written
277
296
}
278
297
}
279
298
@@ -344,7 +363,7 @@ impl AsyncDmaFile {
344
363
fd : file,
345
364
alignment,
346
365
buf : None ,
347
- length : 0 ,
366
+ written : 0 ,
348
367
} )
349
368
}
350
369
@@ -380,7 +399,7 @@ impl AsyncDmaFile {
380
399
fd : unsafe { BorrowedFd :: borrow_raw ( fd) } ,
381
400
alignment,
382
401
buf : Some ( buf) ,
383
- length : 0 ,
402
+ written : 0 ,
384
403
} ;
385
404
file. read_direct ( remain) . map ( |n| ( file. buf . unwrap ( ) , n) )
386
405
} )
@@ -411,12 +430,12 @@ impl SyncDmaFile {
411
430
412
431
fn create_fd ( path : impl rustix:: path:: Arg , dio : bool ) -> io:: Result < OwnedFd > {
413
432
let flags = if cfg ! ( target_os = "linux" ) && dio {
414
- OFlags :: EXCL | OFlags :: CREATE | OFlags :: TRUNC | OFlags :: DIRECT
433
+ OFlags :: EXCL | OFlags :: CREATE | OFlags :: TRUNC | OFlags :: RDWR | OFlags :: DIRECT
415
434
} else {
416
- OFlags :: EXCL | OFlags :: CREATE | OFlags :: TRUNC
435
+ OFlags :: EXCL | OFlags :: CREATE | OFlags :: TRUNC | OFlags :: RDWR
417
436
} ;
418
437
419
- rustix:: fs:: open ( path, flags, rustix:: fs:: Mode :: empty ( ) ) . map_err ( |e| e. into ( ) )
438
+ rustix:: fs:: open ( path, flags, rustix:: fs:: Mode :: from_raw_mode ( 0o666 ) ) . map_err ( |e| e. into ( ) )
420
439
}
421
440
422
441
fn open_dma ( fd : OwnedFd ) -> io:: Result < DmaFile < OwnedFd > > {
@@ -427,7 +446,7 @@ impl SyncDmaFile {
427
446
fd,
428
447
alignment,
429
448
buf : None ,
430
- length : 0 ,
449
+ written : 0 ,
431
450
} )
432
451
}
433
452
@@ -485,7 +504,7 @@ impl DmaWriteBuf {
485
504
fd : AsyncDmaFile :: create_fd ( path, dio) . await ?,
486
505
alignment : self . allocator . 0 ,
487
506
buf : None ,
488
- length : 0 ,
507
+ written : 0 ,
489
508
} ;
490
509
491
510
let file_length = self . size ( ) ;
@@ -572,8 +591,8 @@ impl DmaWriteBuf {
572
591
573
592
let len = data. len ( ) * self . chunk ;
574
593
575
- let bufs = data. iter ( ) . map ( |buf| IoSlice :: new ( buf) ) . collect :: < Vec < _ > > ( ) ;
576
- let written = rustix :: io :: writev ( & file. fd , & bufs ) ?;
594
+ let mut io_slices : Vec < _ > = data. iter ( ) . map ( |buf| IoSlice :: new ( buf) ) . collect ( ) ;
595
+ let written = writev_all ( & file. fd , & mut io_slices ) ?;
577
596
578
597
let last = self . data . pop ( ) ;
579
598
self . data . clear ( ) ;
@@ -584,7 +603,7 @@ impl DmaWriteBuf {
584
603
_ => ( ) ,
585
604
}
586
605
587
- file. length += written;
606
+ file. inc_written ( written) ;
588
607
589
608
if written != len {
590
609
Err ( io:: Error :: other ( "short write" ) )
@@ -618,26 +637,108 @@ impl DmaWriteBuf {
618
637
None => unreachable ! ( ) ,
619
638
} ;
620
639
let len = self . data . len ( ) * self . chunk - diff;
621
- let bufs = self
622
- . data
623
- . iter ( )
624
- . map ( |buf| IoSlice :: new ( buf) )
625
- . collect :: < Vec < _ > > ( ) ;
626
640
627
- let written = rustix:: io:: writev ( & file. fd , & bufs) ?;
641
+ let mut io_slices: Vec < _ > = self . data . iter ( ) . map ( |buf| IoSlice :: new ( buf) ) . collect ( ) ;
642
+ let written = writev_all ( & file. fd , & mut io_slices) ?;
628
643
if written != len {
629
644
return Err ( io:: Error :: other ( "short write" ) ) ;
630
645
}
631
646
632
647
if to_truncate == 0 {
633
- file. length += written;
648
+ file. inc_written ( written) ;
634
649
return Ok ( written) ;
635
650
}
636
651
637
- file. length -= to_truncate;
638
- file. truncate ( file. length ) ?;
652
+ file. written -= to_truncate;
653
+ file. truncate ( file. written ) ?;
639
654
Ok ( written - to_truncate)
640
655
}
656
+
657
+ pub fn flush ( & mut self , file : & mut SyncDmaFile ) -> io:: Result < ( ) > {
658
+ debug_assert_eq ! ( self . allocator. 0 , file. alignment) ;
659
+
660
+ if self . data . is_empty ( ) {
661
+ return Ok ( ( ) ) ;
662
+ }
663
+
664
+ let last = self
665
+ . data
666
+ . pop_if ( |last| file. align_up ( last. len ( ) ) > last. len ( ) ) ;
667
+
668
+ let last = if let Some ( mut last) = last {
669
+ if self . data . is_empty ( ) {
670
+ use std:: cmp:: Ordering :: * ;
671
+ match ( file. written - file. align_down ( file. written ) ) . cmp ( & last. len ( ) ) {
672
+ Equal => return Ok ( ( ) ) ,
673
+ Greater => unreachable ! ( ) ,
674
+ Less => { }
675
+ }
676
+ }
677
+ let len = last. len ( ) ;
678
+ let align_up = file. align_up ( len) ;
679
+ let pad = align_up - len;
680
+ debug_assert ! ( pad != 0 ) ;
681
+ unsafe { last. set_len ( align_up) } ;
682
+ Some ( ( last, len, pad) )
683
+ } else {
684
+ None
685
+ } ;
686
+
687
+ let mut slices: Vec < _ > = self
688
+ . data
689
+ . iter ( )
690
+ . map ( |buf| IoSlice :: new ( buf) )
691
+ . chain ( last. as_ref ( ) . map ( |last| IoSlice :: new ( & last. 0 ) ) )
692
+ . collect ( ) ;
693
+ let written = writev_all ( & file. fd , & mut slices[ ..] ) ?;
694
+ self . data . clear ( ) ;
695
+
696
+ file. inc_written ( written) ;
697
+
698
+ if let Some ( ( last, len, pad) ) = last. as_ref ( ) {
699
+ let len = * len;
700
+ let pad = * pad;
701
+ file. written -= pad;
702
+
703
+ file. truncate ( file. written ) ?;
704
+ let last_align = file. align_down ( file. written ) ;
705
+ rustix:: fs:: seek ( & file. fd , rustix:: fs:: SeekFrom :: Start ( last_align as _ ) )
706
+ . map_err ( io:: Error :: from) ?;
707
+
708
+ debug_assert_eq ! ( pad, file. align_up( file. written) - file. written) ;
709
+
710
+ self . write ( & last[ file. align_down ( len) ..( file. align_up ( len) - pad) ] ) ?;
711
+ }
712
+
713
+ Ok ( ( ) )
714
+ }
715
+ }
716
+
717
+ fn writev_all ( fd : impl AsFd , mut slices : & mut [ IoSlice < ' _ > ] ) -> io:: Result < usize > {
718
+ let mut written = 0 ;
719
+
720
+ while !slices. is_empty ( ) {
721
+ let n = match rustix:: io:: writev ( fd. as_fd ( ) , slices) {
722
+ Ok ( 0 ) => {
723
+ return Err ( io:: Error :: new (
724
+ io:: ErrorKind :: WriteZero ,
725
+ "writev returned zero bytes" ,
726
+ ) ) ;
727
+ }
728
+ Ok ( n) => n,
729
+ Err ( err) => {
730
+ if err. kind ( ) == io:: ErrorKind :: Interrupted {
731
+ continue ;
732
+ }
733
+ return Err ( err. into ( ) ) ;
734
+ }
735
+ } ;
736
+
737
+ written += n;
738
+ IoSlice :: advance_slices ( & mut slices, n) ;
739
+ }
740
+
741
+ Ok ( written)
641
742
}
642
743
643
744
impl io:: Write for DmaWriteBuf {
@@ -768,6 +869,7 @@ pub async fn dma_read_file_range(
768
869
769
870
#[ cfg( test) ]
770
871
mod tests {
872
+ use std:: io:: Read ;
771
873
use std:: io:: Write ;
772
874
773
875
use super :: * ;
@@ -928,4 +1030,66 @@ mod tests {
928
1030
let buf = got. to_vec ( ) ;
929
1031
println ! ( "{:?} {}" , buf. as_ptr( ) , buf. capacity( ) ) ;
930
1032
}
1033
+
1034
+ #[ test]
1035
+ fn test_write ( ) -> io:: Result < ( ) > {
1036
+ let filename = "test_file" ;
1037
+ let _ = std:: fs:: remove_file ( filename) ;
1038
+ let mut file = SyncDmaFile :: create ( filename, true ) ?;
1039
+
1040
+ let mut buf = DmaWriteBuf :: new ( file. alignment , file. alignment . as_usize ( ) * 2 ) ;
1041
+
1042
+ {
1043
+ buf. write ( b"1" ) ?;
1044
+ buf. flush ( & mut file) ?;
1045
+
1046
+ assert_eq ! ( file. written, 1 ) ;
1047
+
1048
+ let mut got = Vec :: new ( ) ;
1049
+ let mut read = std:: fs:: File :: open ( filename) ?;
1050
+ let n = read. read_to_end ( & mut got) ?;
1051
+ assert_eq ! ( n, 1 ) ;
1052
+
1053
+ assert_eq ! ( b"1" . as_slice( ) , got. as_slice( ) ) ;
1054
+ }
1055
+
1056
+ {
1057
+ buf. write ( b"2" ) ?;
1058
+ buf. write ( b"3" ) ?;
1059
+ buf. flush ( & mut file) ?;
1060
+
1061
+ assert_eq ! ( file. written, 3 ) ;
1062
+
1063
+ let mut got = Vec :: new ( ) ;
1064
+ let mut read = std:: fs:: File :: open ( filename) ?;
1065
+ let n = read. read_to_end ( & mut got) ?;
1066
+ assert_eq ! ( n, 3 ) ;
1067
+
1068
+ assert_eq ! ( b"123" . as_slice( ) , got. as_slice( ) ) ;
1069
+ }
1070
+
1071
+ {
1072
+ let data: Vec < _ > = b"123"
1073
+ . iter ( )
1074
+ . copied ( )
1075
+ . cycle ( )
1076
+ . take ( file. alignment . as_usize ( ) * 3 )
1077
+ . collect ( ) ;
1078
+
1079
+ buf. write ( & data) ?;
1080
+ buf. flush ( & mut file) ?;
1081
+
1082
+ assert_eq ! ( file. written, 3 + data. len( ) ) ;
1083
+
1084
+ let mut got = Vec :: new ( ) ;
1085
+ let mut read = std:: fs:: File :: open ( filename) ?;
1086
+ let n = read. read_to_end ( & mut got) ?;
1087
+ assert_eq ! ( n, 3 + data. len( ) ) ;
1088
+
1089
+ let want: Vec < _ > = [ & b"123" [ ..] , & data] . concat ( ) ;
1090
+ assert_eq ! ( want. as_slice( ) , got. as_slice( ) ) ;
1091
+ }
1092
+
1093
+ Ok ( ( ) )
1094
+ }
931
1095
}
0 commit comments