Skip to content

Commit 8ca3e52

Browse files
author
Ryan Lai
authored
Update WinMLRunner to be explicit with colorspace management of GetSoftwareBitmapAsync (#283)
* Argument for tensor garbage data with a range of values. * Add GarbageDataMaxValue to limit generated garbage data to values between 0 and the specified maximum. * If an image is passed in, check model's metadata to see if color space needs management * PR comment about try catch tighter * small comment change
1 parent f410c47 commit 8ca3e52

File tree

2 files changed

+70
-26
lines changed

2 files changed

+70
-26
lines changed

Tools/WinMLRunner/src/BindingUtilities.h

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,30 @@ template <> winrt::hstring ConvertToPointerType<TensorKind::String>(winrt::hstri
151151
return static_cast<winrt::hstring>(value);
152152
};
153153

154+
static ColorManagementMode GetColorManagementMode(const LearningModel& model)
155+
156+
{
157+
// Get model color space gamma
158+
hstring gammaSpace = L"";
159+
try
160+
{
161+
gammaSpace = model.Metadata().Lookup(L"Image.ColorSpaceGamma");
162+
}
163+
catch (...)
164+
{
165+
printf(" Model does not have color space gamma information. Will color manage to sRGB by default...\n");
166+
}
167+
if (gammaSpace == L"" || _wcsicmp(gammaSpace.c_str(), L"SRGB") == 0)
168+
{
169+
return ColorManagementMode::ColorManageToSRgb;
170+
}
171+
// Due diligence should be done to make sure that the input image is within the model's colorspace. There are
172+
// multiple non-sRGB color spaces.
173+
printf(" Model metadata indicates that color gamma space is : %ws. Will not manage color space to sRGB...\n",
174+
gammaSpace.c_str());
175+
return ColorManagementMode::DoNotColorManage;
176+
}
177+
154178
void GetHeightAndWidthFromLearningModelFeatureDescriptor(const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
155179
uint64_t& width, uint64_t& height)
156180
{
@@ -214,24 +238,35 @@ namespace BindingUtilities
214238

215239
SoftwareBitmap LoadImageFile(const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
216240
const InputDataType inputDataType, const hstring& filePath,
217-
const CommandLineArgs& args, uint32_t iterationNum)
241+
const CommandLineArgs& args, uint32_t iterationNum,
242+
ColorManagementMode colorManagementMode)
218243
{
219244
// We assume NCHW and NCDHW
220245
uint64_t width = 0;
221246
uint64_t height = 0;
222247
GetHeightAndWidthFromLearningModelFeatureDescriptor(modelFeatureDescriptor, width, height);
248+
IRandomAccessStream stream;
249+
BitmapDecoder decoder = NULL;
223250
try
224251
{
225252
// open the file
226253
StorageFile file = StorageFile::GetFileFromPathAsync(filePath).get();
227254
// get a stream on it
228-
auto stream = file.OpenAsync(FileAccessMode::Read).get();
255+
stream = file.OpenAsync(FileAccessMode::Read).get();
229256
// Create the decoder from the stream
230-
BitmapDecoder decoder = BitmapDecoder::CreateAsync(stream).get();
231-
BitmapPixelFormat format = inputDataType == InputDataType::Tensor
232-
? decoder.BitmapPixelFormat()
233-
: TypeHelper::GetBitmapPixelFormat(inputDataType);
234-
257+
decoder = BitmapDecoder::CreateAsync(stream).get();
258+
}
259+
catch (hresult_error hr)
260+
{
261+
printf(" Failed to load the image file, make sure you are using fully qualified paths\r\n");
262+
printf(" %ws\n", hr.message().c_str());
263+
exit(hr.code());
264+
}
265+
BitmapPixelFormat format = inputDataType == InputDataType::Tensor
266+
? decoder.BitmapPixelFormat()
267+
: TypeHelper::GetBitmapPixelFormat(inputDataType);
268+
try
269+
{
235270
// If input dimensions are different from tensor input, then scale / crop while reading
236271
if (args.IsAutoScale() && (decoder.PixelHeight() != height || decoder.PixelWidth() != width))
237272
{
@@ -247,22 +282,23 @@ namespace BindingUtilities
247282

248283
// get the bitmap
249284
return decoder
250-
.GetSoftwareBitmapAsync(format, BitmapAlphaMode::Ignore, transform,
251-
ExifOrientationMode::RespectExifOrientation,
252-
ColorManagementMode::DoNotColorManage)
253-
.get();
285+
.GetSoftwareBitmapAsync(format, decoder.BitmapAlphaMode(), transform,
286+
ExifOrientationMode::RespectExifOrientation, colorManagementMode).get();
254287
}
255288
else
256289
{
257290
// get the bitmap
258-
return decoder.GetSoftwareBitmapAsync(format, BitmapAlphaMode::Ignore).get();
291+
return decoder
292+
.GetSoftwareBitmapAsync(format, decoder.BitmapAlphaMode(), BitmapTransform(),
293+
ExifOrientationMode::RespectExifOrientation, colorManagementMode).get();
259294
}
260295
}
261-
catch (...)
296+
catch (hresult_error hr)
262297
{
263-
std::wcout << L"BindingUtilities: could not open image file (" << std::wstring(filePath) << L"), "
264-
<< L"make sure you are using fully qualified paths." << std::endl;
265-
return nullptr;
298+
printf(" Failed to create SoftwareBitmap! Please make sure that input image is within the model's "
299+
"colorspace.\n");
300+
printf(" %ws\n", hr.message().c_str());
301+
exit(hr.code());
266302
}
267303
}
268304

@@ -364,7 +400,8 @@ namespace BindingUtilities
364400
template <TensorKind TKind, typename WriteType>
365401
static void GenerateRandomData(WriteType* data, uint32_t sizeInBytes, uint32_t maxValue)
366402
{
367-
static std::independent_bits_engine<std::default_random_engine, sizeof(uint32_t) * 8, uint32_t> randomBitsEngine;
403+
static std::independent_bits_engine<std::default_random_engine, sizeof(uint32_t) * 8, uint32_t>
404+
randomBitsEngine;
368405
randomBitsEngine.seed(seed++);
369406

370407
WriteType* begin = data;
@@ -639,7 +676,8 @@ namespace BindingUtilities
639676
// Binds tensor floats, ints, doubles from CSV data.
640677
ITensor CreateBindableTensor(const ILearningModelFeatureDescriptor& description, const std::wstring& imagePath,
641678
const InputBindingType inputBindingType, const InputDataType inputDataType,
642-
const CommandLineArgs& args, uint32_t iterationNum)
679+
const CommandLineArgs& args, uint32_t iterationNum,
680+
ColorManagementMode colorManagementMode)
643681
{
644682
InputBufferDesc inputBufferDesc = {};
645683

@@ -669,7 +707,8 @@ namespace BindingUtilities
669707
}
670708
else if (args.IsImageInput())
671709
{
672-
softwareBitmap = LoadImageFile(description, inputDataType, imagePath.c_str(), args, iterationNum);
710+
softwareBitmap =
711+
LoadImageFile(description, inputDataType, imagePath.c_str(), args, iterationNum, colorManagementMode);
673712

674713
// Get Pointers to the SoftwareBitmap data buffers
675714
const BitmapBuffer sbBitmapBuffer(softwareBitmap.LockBuffer(BitmapBufferAccessMode::Read));
@@ -781,11 +820,12 @@ namespace BindingUtilities
781820
ImageFeatureValue CreateBindableImage(const ILearningModelFeatureDescriptor& featureDescriptor,
782821
const std::wstring& imagePath, InputBindingType inputBindingType,
783822
InputDataType inputDataType, const IDirect3DDevice winrtDevice,
784-
const CommandLineArgs& args, uint32_t iterationNum)
823+
const CommandLineArgs& args, uint32_t iterationNum,
824+
ColorManagementMode colorManagementMode)
785825
{
786-
auto softwareBitmap =
787-
imagePath.empty() ? GenerateGarbageImage(featureDescriptor, inputDataType)
788-
: LoadImageFile(featureDescriptor, inputDataType, imagePath.c_str(), args, iterationNum);
826+
auto softwareBitmap = imagePath.empty() ? GenerateGarbageImage(featureDescriptor, inputDataType)
827+
: LoadImageFile(featureDescriptor, inputDataType, imagePath.c_str(),
828+
args, iterationNum, colorManagementMode);
789829
auto videoFrame = CreateVideoFrame(softwareBitmap, inputBindingType, inputDataType, winrtDevice);
790830
return ImageFeatureValue::CreateFromVideoFrame(videoFrame);
791831
}

Tools/WinMLRunner/src/Run.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,23 @@ std::vector<ILearningModelFeatureValue> GenerateInputFeatures(const LearningMode
2323
for (uint32_t inputNum = 0; inputNum < model.InputFeatures().Size(); inputNum++)
2424
{
2525
auto&& description = model.InputFeatures().GetAt(inputNum);
26-
26+
ColorManagementMode colorManagementMode = ColorManagementMode::DoNotColorManage;
27+
if (args.IsImageInput())
28+
{
29+
colorManagementMode = GetColorManagementMode(model);
30+
}
2731
if (inputDataType == InputDataType::Tensor)
2832
{
2933
// If CSV data is provided, then every input will contain the same CSV data
3034
auto tensorFeature = BindingUtilities::CreateBindableTensor(description, imagePath, inputBindingType, inputDataType,
31-
args, iterationNum);
35+
args, iterationNum, colorManagementMode);
3236
inputFeatures.push_back(tensorFeature);
3337
}
3438
else
3539
{
3640
auto imageFeature = BindingUtilities::CreateBindableImage(
3741
description, imagePath, inputBindingType, inputDataType, device.LearningModelDevice.Direct3D11Device(),
38-
args, iterationNum);
42+
args, iterationNum, colorManagementMode);
3943
inputFeatures.push_back(imageFeature);
4044
}
4145
}

0 commit comments

Comments
 (0)