22// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33
44using DocumentFormat . OpenXml . Features ;
5+ using DocumentFormat . OpenXml . Framework ;
56using System ;
67using System . Collections . Generic ;
78using System . Diagnostics ;
89using System . Diagnostics . CodeAnalysis ;
910using System . IO ;
10- using System . IO . Packaging ;
11+ using System . Threading ;
1112
1213namespace DocumentFormat . OpenXml . Packaging
1314{
@@ -236,9 +237,7 @@ public IEnumerable<OpenXmlPart> GetParentParts()
236237 /// <returns>The content stream of the part. </returns>
237238 public Stream GetStream ( FileMode mode )
238239 {
239- ThrowIfObjectDisposed ( ) ;
240-
241- return PackagePart . GetStream ( mode , Features . GetRequired < IPackageFeature > ( ) . Package . FileOpenAccess ) ;
240+ return GetStream ( mode , Features . GetRequired < IPackageFeature > ( ) . Package . FileOpenAccess ) ;
242241 }
243242
244243 /// <summary>
@@ -251,7 +250,20 @@ public Stream GetStream(FileMode mode, FileAccess access)
251250 {
252251 ThrowIfObjectDisposed ( ) ;
253252
254- return PackagePart . GetStream ( mode , access ) ;
253+ var stream = PackagePart . GetStream ( mode , access ) ;
254+
255+ if ( mode is FileMode . Create || stream . Length == 0 )
256+ {
257+ UnloadRootElement ( ) ;
258+ return new UnloadingRootElementStream ( this , stream ) ;
259+ }
260+
261+ if ( stream . CanWrite )
262+ {
263+ return new UnloadingRootElementStream ( this , stream ) ;
264+ }
265+
266+ return stream ;
255267 }
256268
257269 /// <summary>
@@ -605,5 +617,76 @@ internal sealed override OpenXmlPart ThisOpenXmlPart
605617 internal MarkupCompatibilityProcessSettings ? MCSettings { get ; set ; }
606618
607619 #endregion
620+
621+ /// <summary>
622+ /// A <see cref="Stream"/> used by <see cref="GetStream(FileMode, FileAccess)" /> to unload the root if updated.
623+ /// </summary>
624+ private sealed class UnloadingRootElementStream : DelegatingStream
625+ {
626+ private readonly OpenXmlPart _part ;
627+
628+ private bool _hasWritten ;
629+
630+ public UnloadingRootElementStream ( OpenXmlPart part , Stream innerStream )
631+ : base ( innerStream )
632+ {
633+ _part = part ;
634+ }
635+
636+ protected override void Dispose ( bool disposing )
637+ {
638+ if ( disposing && _hasWritten )
639+ {
640+ _part . UnloadRootElement ( ) ;
641+ }
642+
643+ base . Dispose ( disposing ) ;
644+ }
645+
646+ public override void Write ( byte [ ] buffer , int offset , int count )
647+ {
648+ NotifyOfWrite ( ) ;
649+ base . Write ( buffer , offset , count ) ;
650+ }
651+
652+ #if NET46_OR_GREATER || NET || NETSTANDARD
653+ public override System . Threading . Tasks . Task WriteAsync ( byte [ ] buffer , int offset , int count , CancellationToken cancellationToken )
654+ {
655+ NotifyOfWrite ( ) ;
656+ return base . WriteAsync ( buffer , offset , count , cancellationToken ) ;
657+ }
658+ #endif
659+
660+ #if NET6_0_OR_GREATER
661+ public override void Write ( ReadOnlySpan < byte > buffer )
662+ {
663+ NotifyOfWrite ( ) ;
664+ base . Write ( buffer ) ;
665+ }
666+
667+ public override System . Threading . Tasks . ValueTask WriteAsync ( ReadOnlyMemory < byte > buffer , CancellationToken cancellationToken = default )
668+ {
669+ NotifyOfWrite ( ) ;
670+ return base . WriteAsync ( buffer , cancellationToken ) ;
671+ }
672+ #endif
673+
674+ public override void WriteByte ( byte value )
675+ {
676+ NotifyOfWrite ( ) ;
677+ base . WriteByte ( value ) ;
678+ }
679+
680+ public override IAsyncResult BeginWrite ( byte [ ] buffer , int offset , int count , AsyncCallback ? callback , object ? state )
681+ {
682+ NotifyOfWrite ( ) ;
683+ return base . BeginWrite ( buffer , offset , count , callback , state ) ;
684+ }
685+
686+ private void NotifyOfWrite ( )
687+ {
688+ _hasWritten = true ;
689+ }
690+ }
608691 }
609692}
0 commit comments