|
| 1 | +// Importing required namespaces for working with images, machine learning models, and tensors |
| 2 | +using System.Drawing; |
| 3 | +using System.Drawing.Imaging; |
| 4 | +using Microsoft.ML.OnnxRuntime; // For ONNX model inference |
| 5 | +using Microsoft.ML.OnnxRuntime.Tensors; // For working with tensors |
| 6 | + |
namespace RmbgSharp
{
    /// <summary>
    /// Removes the background from an image using an RMBG-style ONNX segmentation model.
    /// The model is fed a normalized 1x3x1024x1024 float tensor named "pixel_values"
    /// and is expected to return a single-channel foreground matte of the same spatial size.
    /// Owns an ONNX Runtime <see cref="InferenceSession"/>; dispose this object to release it.
    /// </summary>
    public class BackgroundRemover : IDisposable
    {
        // Model input edge length. RMBG-1.4-style models expect a square 1024x1024 input.
        private const int ModelInputSize = 1024;

        // ONNX Runtime session used for inference. Owned and disposed by this instance.
        private readonly InferenceSession _inferenceSession;
        private bool _disposed;

        /// <summary>
        /// Creates a remover backed by the ONNX model at <paramref name="modelPath"/>.
        /// </summary>
        /// <param name="modelPath">Path to the ONNX model file.</param>
        /// <param name="useFP16">Enables extended graph optimizations (helpful for FP16 models).</param>
        /// <param name="useGPU">Attempts to enable the DirectML execution provider; silently falls back to CPU.</param>
        public BackgroundRemover(string modelPath, bool useFP16, bool useGPU)
        {
            SessionOptions sessionOptions = new SessionOptions
            {
                // Keep ONNX Runtime quiet except for fatal errors.
                LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL
            };

            if (useFP16)
            {
                sessionOptions.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_EXTENDED;
            }

            if (useGPU)
            {
                try
                {
                    sessionOptions.AppendExecutionProvider_DML();
                }
                catch
                {
                    // Best effort by design: no DirectML-capable device/runtime means we
                    // simply stay on the CPU execution provider.
                }
            }

            _inferenceSession = new InferenceSession(modelPath, sessionOptions);
        }

        /// <summary>
        /// Removes the background of <paramref name="bitmap"/> and saves the result to
        /// <paramref name="outputPath"/>. The intermediate result bitmap is disposed here.
        /// </summary>
        public void RemoveBackground(Bitmap bitmap, string outputPath)
        {
            using Bitmap result = RemoveBackground(bitmap);
            result.Save(outputPath);
        }

        /// <summary>
        /// Loads an image from <paramref name="inputPath"/>, removes its background, and
        /// saves the result to <paramref name="outputPath"/>.
        /// </summary>
        public void RemoveBackground(string inputPath, string outputPath)
        {
            // Image.FromFile keeps the file handle open until the image is disposed,
            // so release it deterministically once processing finishes.
            using Bitmap input = (Bitmap)Image.FromFile(inputPath);
            RemoveBackground(input, outputPath);
        }

        /// <summary>
        /// Runs the segmentation model on <paramref name="bitmap"/> and returns a new
        /// 32bpp ARGB bitmap whose alpha channel encodes the predicted foreground mask.
        /// The caller owns (and should dispose) the returned bitmap.
        /// </summary>
        public Bitmap RemoveBackground(Bitmap bitmap)
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(nameof(BackgroundRemover));
            }

            float[,,] image = LoadAndPreprocessImage(bitmap, out int width, out int height);

            // NCHW layout: batch, channels, height, width. (The original code passed
            // { 1, 3, width, height } while indexing [0, c, y, x]; that only worked
            // because width == height.)
            var inputTensor = new DenseTensor<float>(new[] { 1, 3, height, width });
            for (int c = 0; c < 3; c++)
            {
                for (int y = 0; y < height; y++)
                {
                    for (int x = 0; x < width; x++)
                    {
                        inputTensor[0, c, y, x] = image[c, y, x];
                    }
                }
            }

            var inputs = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor("pixel_values", inputTensor)
            };

            // The run result wraps native buffers; dispose it once the mask is copied out.
            float[] resultArray;
            using (IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = _inferenceSession.Run(inputs))
            {
                resultArray = results.First().AsTensor<float>().ToArray();
            }

            // Reshape the flat model output into a [height, width] mask (row-major).
            float[,] mask = new float[height, width];
            int idx = 0;
            for (int y = 0; y < height; y++)
            {
                for (int x = 0; x < width; x++)
                {
                    mask[y, x] = resultArray[idx++];
                }
            }

            return ApplyMaskToImage(bitmap, mask);
        }

        /// <summary>
        /// Resizes <paramref name="bitmap"/> to the model input size and normalizes pixel
        /// values with ImageNet mean/std. Returns data in [channel, y, x] order, RGB.
        /// </summary>
        private static float[,,] LoadAndPreprocessImage(Bitmap bitmap, out int width, out int height)
        {
            width = ModelInputSize;
            height = ModelInputSize;

            // The resized copy is a temporary; dispose it when preprocessing is done.
            using Bitmap resizedImage = new Bitmap(bitmap, new Size(width, height));

            float[,,] data = new float[3, height, width];

            BitmapData bmpData = resizedImage.LockBits(
                new Rectangle(0, 0, width, height), ImageLockMode.ReadOnly, PixelFormat.Format32bppArgb);
            try
            {
                int bytesPerPixel = Image.GetPixelFormatSize(bmpData.PixelFormat) / 8; // 4 for 32bppArgb
                int stride = bmpData.Stride;

                unsafe
                {
                    byte* dataPtr = (byte*)bmpData.Scan0;

                    for (int y = 0; y < height; y++)
                    {
                        for (int x = 0; x < width; x++)
                        {
                            int offset = y * stride + x * bytesPerPixel;

                            // GDI+ stores 32bppArgb pixels as B, G, R, A in memory.
                            byte b = dataPtr[offset];
                            byte g = dataPtr[offset + 1];
                            byte r = dataPtr[offset + 2];

                            // ImageNet normalization: (value/255 - mean) / std per channel.
                            data[0, y, x] = (r / 255f - 0.485f) / 0.229f;
                            data[1, y, x] = (g / 255f - 0.456f) / 0.224f;
                            data[2, y, x] = (b / 255f - 0.406f) / 0.225f;
                        }
                    }
                }
            }
            finally
            {
                // Always unlock, even if the copy loop throws, or the bitmap stays locked.
                resizedImage.UnlockBits(bmpData);
            }

            return data;
        }

        /// <summary>
        /// Copies the RGB channels of <paramref name="bitmap"/> into a new 32bpp ARGB bitmap
        /// whose alpha channel is taken from <paramref name="mask"/> (resized to match).
        /// </summary>
        private static Bitmap ApplyMaskToImage(Bitmap bitmap, float[,] mask)
        {
            int width = bitmap.Width;
            int height = bitmap.Height;

            Bitmap resultImage = new Bitmap(width, height, PixelFormat.Format32bppArgb);
            float[,] resizedMask = ResizeMask(mask, width, height);

            BitmapData bmpData = bitmap.LockBits(
                new Rectangle(0, 0, width, height), ImageLockMode.ReadOnly, PixelFormat.Format32bppArgb);
            BitmapData resultData = resultImage.LockBits(
                new Rectangle(0, 0, width, height), ImageLockMode.WriteOnly, PixelFormat.Format32bppArgb);
            try
            {
                int bytesPerPixel = Image.GetPixelFormatSize(bmpData.PixelFormat) / 8;
                int srcStride = bmpData.Stride;
                // Each locked bitmap has its own stride; they are not guaranteed to be
                // equal, so writes must use the destination's stride.
                int dstStride = resultData.Stride;

                unsafe
                {
                    byte* src = (byte*)bmpData.Scan0;
                    byte* dst = (byte*)resultData.Scan0;

                    for (int y = 0; y < height; y++)
                    {
                        for (int x = 0; x < width; x++)
                        {
                            int srcOffset = y * srcStride + x * bytesPerPixel;
                            int dstOffset = y * dstStride + x * bytesPerPixel;

                            // Clamp so model outputs outside [0, 1] cannot wrap the byte cast.
                            float maskValue = Math.Clamp(resizedMask[y, x], 0f, 1f);
                            byte alpha = (byte)(maskValue * 255f);

                            dst[dstOffset] = src[srcOffset];         // B
                            dst[dstOffset + 1] = src[srcOffset + 1]; // G
                            dst[dstOffset + 2] = src[srcOffset + 2]; // R
                            dst[dstOffset + 3] = alpha;              // A from mask
                        }
                    }
                }
            }
            finally
            {
                bitmap.UnlockBits(bmpData);
                resultImage.UnlockBits(resultData);
            }

            return resultImage;
        }

        /// <summary>
        /// Bilinearly resamples <paramref name="mask"/> to targetWidth x targetHeight.
        /// </summary>
        private static float[,] ResizeMask(float[,] mask, int targetWidth, int targetHeight)
        {
            int maskHeight = mask.GetLength(0);
            int maskWidth = mask.GetLength(1);

            float[,] resizedMask = new float[targetHeight, targetWidth];

            for (int y = 0; y < targetHeight; y++)
            {
                for (int x = 0; x < targetWidth; x++)
                {
                    // Map the target pixel back into the source mask's coordinate space.
                    float origY = (float)y / targetHeight * maskHeight;
                    float origX = (float)x / targetWidth * maskWidth;

                    // Clamp the base indices as well as the neighbors: float rounding can
                    // push origY/origX to exactly the source extent, which would otherwise
                    // index out of bounds on the last row/column.
                    int y0 = Math.Min((int)origY, maskHeight - 1);
                    int x0 = Math.Min((int)origX, maskWidth - 1);
                    int y1 = Math.Min(y0 + 1, maskHeight - 1);
                    int x1 = Math.Min(x0 + 1, maskWidth - 1);

                    float yFrac = origY - y0;
                    float xFrac = origX - x0;

                    // Standard bilinear blend of the four nearest source samples.
                    resizedMask[y, x] =
                        (1 - xFrac) * (1 - yFrac) * mask[y0, x0] +
                        xFrac * (1 - yFrac) * mask[y0, x1] +
                        (1 - xFrac) * yFrac * mask[y1, x0] +
                        xFrac * yFrac * mask[y1, x1];
                }
            }

            return resizedMask;
        }

        /// <summary>Releases the ONNX Runtime session. Safe to call multiple times.</summary>
        public void Dispose()
        {
            if (_disposed)
            {
                return;
            }

            _disposed = true;
            _inferenceSession.Dispose();
            GC.SuppressFinalize(this);
        }
    }
}