@@ -58,7 +58,7 @@ public async Task UnloadAsync(CancellationToken cancellationToken = default)
5858 public async Task < ImageTensor > RunAsync ( BackgroundImageOptions options , IProgress < RunProgress > progressCallback = null , CancellationToken cancellationToken = default )
5959 {
6060 var timestamp = RunProgress . GetTimestamp ( ) ;
61- var resultTensor = await ExtractBackgroundInternalAsync ( options . Mode , options . Image , cancellationToken ) ;
61+ var resultTensor = await ExtractBackgroundInternalAsync ( options . Mode , options . IsTransparentSupported , options . Image , cancellationToken ) ;
6262 progressCallback ? . Report ( new RunProgress ( timestamp ) ) ;
6363 return resultTensor ;
6464 }
@@ -78,7 +78,7 @@ public void Dispose()
7878 /// </summary>
7979 /// <param name="imageInput">The image tensor.</param>
8080 /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
81- private async Task < ImageTensor > ExtractBackgroundInternalAsync ( BackgroundMode backgroundMode , ImageTensor imageInput , CancellationToken cancellationToken = default )
81+ private async Task < ImageTensor > ExtractBackgroundInternalAsync ( BackgroundMode backgroundMode , bool isTranparentSupported , ImageTensor imageInput , CancellationToken cancellationToken = default )
8282 {
8383 var metadata = await _model . LoadAsync ( cancellationToken : cancellationToken ) ;
8484 cancellationToken . ThrowIfCancellationRequested ( ) ;
@@ -104,16 +104,31 @@ private async Task<ImageTensor> ExtractBackgroundInternalAsync(BackgroundMode ba
104104
105105 // Normalize
106106 outputTensor . Normalize ( _model . OutputNormalization ) ;
107- if ( backgroundMode == BackgroundMode . MaskBackground || backgroundMode == BackgroundMode . RemoveForeground )
108- outputTensor . Invert ( ) ;
109107
110- // Output Image
111- var outputImage = backgroundMode == BackgroundMode . RemoveBackground || backgroundMode == BackgroundMode . RemoveForeground
112- ? inputTensor . CloneAs ( )
113- : new ImageTensor ( inputTensor . Height , inputTensor . Width , - 1 ) ;
114-
115- // Set Alpha
116- outputImage . UpdateAlphaChannel ( outputTensor . Span ) ;
108+ // Process Image
109+ var outputImage = default ( ImageTensor ) ;
110+ if ( backgroundMode == BackgroundMode . MaskForeground || backgroundMode == BackgroundMode . MaskBackground )
111+ {
112+ if ( backgroundMode == BackgroundMode . MaskBackground )
113+ outputTensor . Invert ( ) ;
114+ outputImage = new ImageTensor ( inputTensor . Height , inputTensor . Width , - 1 ) ;
115+ }
116+ else if ( backgroundMode == BackgroundMode . RemoveBackground || backgroundMode == BackgroundMode . RemoveForeground )
117+ {
118+ if ( backgroundMode == BackgroundMode . RemoveForeground )
119+ outputTensor . Invert ( ) ;
120+ outputImage = inputTensor . CloneAs ( ) ;
121+ }
122+
123+ // Set Alpha Channel
124+ if ( isTranparentSupported )
125+ {
126+ outputImage . UpdateAlphaChannel ( outputTensor . Span ) ;
127+ }
128+ else
129+ {
130+ outputImage . FlattenAlphaChannel ( outputTensor . Span ) ;
131+ }
117132
118133 // Resize Output
119134 if ( outputImage . Width != imageInput . Width || outputImage . Height != imageInput . Height )
@@ -132,7 +147,7 @@ private async Task<ImageTensor> ExtractBackgroundInternalAsync(BackgroundMode ba
132147 /// <returns>BackgroundPipeline.</returns>
133148 public static BackgroundPipeline Create ( ExtractorConfig configuration )
134149 {
135- return new BackgroundPipeline ( ExtractorModel . Create ( configuration ) ) ;
150+ return new BackgroundPipeline ( ExtractorModel . Create ( configuration ) ) ;
136151 }
137152 }
138153}
0 commit comments