@@ -151,6 +151,30 @@ template <> winrt::hstring ConvertToPointerType<TensorKind::String>(winrt::hstri
151151 return static_cast <winrt::hstring>(value);
152152};
153153
154+ static ColorManagementMode GetColorManagementMode (const LearningModel& model)
155+
156+ {
157+ // Get model color space gamma
158+ hstring gammaSpace = L" " ;
159+ try
160+ {
161+ gammaSpace = model.Metadata ().Lookup (L" Image.ColorSpaceGamma" );
162+ }
163+ catch (...)
164+ {
165+ printf (" Model does not have color space gamma information. Will color manage to sRGB by default...\n " );
166+ }
167+ if (gammaSpace == L" " || _wcsicmp (gammaSpace.c_str (), L" SRGB" ) == 0 )
168+ {
169+ return ColorManagementMode::ColorManageToSRgb;
170+ }
171+ // Due diligence should be done to make sure that the input image is within the model's colorspace. There are
172+ // multiple non-sRGB color spaces.
173+ printf (" Model metadata indicates that color gamma space is : %ws. Will not manage color space to sRGB...\n " ,
174+ gammaSpace.c_str ());
175+ return ColorManagementMode::DoNotColorManage;
176+ }
177+
154178void GetHeightAndWidthFromLearningModelFeatureDescriptor (const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
155179 uint64_t & width, uint64_t & height)
156180{
@@ -214,24 +238,35 @@ namespace BindingUtilities
214238
215239 SoftwareBitmap LoadImageFile (const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
216240 const InputDataType inputDataType, const hstring& filePath,
217- const CommandLineArgs& args, uint32_t iterationNum)
241+ const CommandLineArgs& args, uint32_t iterationNum,
242+ ColorManagementMode colorManagementMode)
218243 {
219244 // We assume NCHW and NCDHW
220245 uint64_t width = 0 ;
221246 uint64_t height = 0 ;
222247 GetHeightAndWidthFromLearningModelFeatureDescriptor (modelFeatureDescriptor, width, height);
248+ IRandomAccessStream stream;
249+ BitmapDecoder decoder = NULL ;
223250 try
224251 {
225252 // open the file
226253 StorageFile file = StorageFile::GetFileFromPathAsync (filePath).get ();
227254 // get a stream on it
228- auto stream = file.OpenAsync (FileAccessMode::Read).get ();
255+ stream = file.OpenAsync (FileAccessMode::Read).get ();
229256 // Create the decoder from the stream
230- BitmapDecoder decoder = BitmapDecoder::CreateAsync (stream).get ();
231- BitmapPixelFormat format = inputDataType == InputDataType::Tensor
232- ? decoder.BitmapPixelFormat ()
233- : TypeHelper::GetBitmapPixelFormat (inputDataType);
234-
257+ decoder = BitmapDecoder::CreateAsync (stream).get ();
258+ }
259+ catch (hresult_error hr)
260+ {
261+ printf (" Failed to load the image file, make sure you are using fully qualified paths\r\n " );
262+ printf (" %ws\n " , hr.message ().c_str ());
263+ exit (hr.code ());
264+ }
265+ BitmapPixelFormat format = inputDataType == InputDataType::Tensor
266+ ? decoder.BitmapPixelFormat ()
267+ : TypeHelper::GetBitmapPixelFormat (inputDataType);
268+ try
269+ {
235270 // If input dimensions are different from tensor input, then scale / crop while reading
236271 if (args.IsAutoScale () && (decoder.PixelHeight () != height || decoder.PixelWidth () != width))
237272 {
@@ -247,22 +282,23 @@ namespace BindingUtilities
247282
248283 // get the bitmap
249284 return decoder
250- .GetSoftwareBitmapAsync (format, BitmapAlphaMode::Ignore, transform,
251- ExifOrientationMode::RespectExifOrientation,
252- ColorManagementMode::DoNotColorManage)
253- .get ();
285+ .GetSoftwareBitmapAsync (format, decoder.BitmapAlphaMode (), transform,
286+ ExifOrientationMode::RespectExifOrientation, colorManagementMode).get ();
254287 }
255288 else
256289 {
257290 // get the bitmap
258- return decoder.GetSoftwareBitmapAsync (format, BitmapAlphaMode::Ignore).get ();
291+ return decoder
292+ .GetSoftwareBitmapAsync (format, decoder.BitmapAlphaMode (), BitmapTransform (),
293+ ExifOrientationMode::RespectExifOrientation, colorManagementMode).get ();
259294 }
260295 }
261- catch (... )
296+ catch (hresult_error hr )
262297 {
263- std::wcout << L" BindingUtilities: could not open image file (" << std::wstring (filePath) << L" ), "
264- << L" make sure you are using fully qualified paths." << std::endl;
265- return nullptr ;
298+ printf (" Failed to create SoftwareBitmap! Please make sure that input image is within the model's "
299+ " colorspace.\n " );
300+ printf (" %ws\n " , hr.message ().c_str ());
301+ exit (hr.code ());
266302 }
267303 }
268304
@@ -364,7 +400,8 @@ namespace BindingUtilities
364400 template <TensorKind TKind, typename WriteType>
365401 static void GenerateRandomData (WriteType* data, uint32_t sizeInBytes, uint32_t maxValue)
366402 {
367- static std::independent_bits_engine<std::default_random_engine, sizeof (uint32_t ) * 8 , uint32_t > randomBitsEngine;
403+ static std::independent_bits_engine<std::default_random_engine, sizeof (uint32_t ) * 8 , uint32_t >
404+ randomBitsEngine;
368405 randomBitsEngine.seed (seed++);
369406
370407 WriteType* begin = data;
@@ -639,7 +676,8 @@ namespace BindingUtilities
639676 // Binds tensor floats, ints, doubles from CSV data.
640677 ITensor CreateBindableTensor (const ILearningModelFeatureDescriptor& description, const std::wstring& imagePath,
641678 const InputBindingType inputBindingType, const InputDataType inputDataType,
642- const CommandLineArgs& args, uint32_t iterationNum)
679+ const CommandLineArgs& args, uint32_t iterationNum,
680+ ColorManagementMode colorManagementMode)
643681 {
644682 InputBufferDesc inputBufferDesc = {};
645683
@@ -669,7 +707,8 @@ namespace BindingUtilities
669707 }
670708 else if (args.IsImageInput ())
671709 {
672- softwareBitmap = LoadImageFile (description, inputDataType, imagePath.c_str (), args, iterationNum);
710+ softwareBitmap =
711+ LoadImageFile (description, inputDataType, imagePath.c_str (), args, iterationNum, colorManagementMode);
673712
674713 // Get Pointers to the SoftwareBitmap data buffers
675714 const BitmapBuffer sbBitmapBuffer (softwareBitmap.LockBuffer (BitmapBufferAccessMode::Read));
@@ -781,11 +820,12 @@ namespace BindingUtilities
781820 ImageFeatureValue CreateBindableImage (const ILearningModelFeatureDescriptor& featureDescriptor,
782821 const std::wstring& imagePath, InputBindingType inputBindingType,
783822 InputDataType inputDataType, const IDirect3DDevice winrtDevice,
784- const CommandLineArgs& args, uint32_t iterationNum)
823+ const CommandLineArgs& args, uint32_t iterationNum,
824+ ColorManagementMode colorManagementMode)
785825 {
786- auto softwareBitmap =
787- imagePath. empty () ? GenerateGarbageImage (featureDescriptor, inputDataType)
788- : LoadImageFile (featureDescriptor, inputDataType, imagePath. c_str (), args, iterationNum);
826+ auto softwareBitmap = imagePath. empty () ? GenerateGarbageImage (featureDescriptor, inputDataType)
827+ : LoadImageFile (featureDescriptor, inputDataType, imagePath. c_str (),
828+ args, iterationNum, colorManagementMode );
789829 auto videoFrame = CreateVideoFrame (softwareBitmap, inputBindingType, inputDataType, winrtDevice);
790830 return ImageFeatureValue::CreateFromVideoFrame (videoFrame);
791831 }
0 commit comments