Skip to content

Commit 0c51398

Browse files
pmbrown1055ryanlai2
authored andcommitted
Refactor FP16 tensor creation to prepare for client tensorization and preprocessing of images. (#221)
1 parent e8f0ecf commit 0c51398

File tree

1 file changed

+153
-45
lines changed

1 file changed

+153
-45
lines changed

Tools/WinMLRunner/src/BindingUtilities.h

Lines changed: 153 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,124 @@
66
#include "d3dx12.h"
77
using namespace winrt::Windows::Media;
88
using namespace winrt::Windows::Storage;
9+
using namespace winrt::Windows::Storage::Streams;
910
using namespace winrt::Windows::AI::MachineLearning;
1011
using namespace winrt::Windows::Foundation::Collections;
1112
using namespace winrt::Windows::Graphics::DirectX;
1213
using namespace winrt::Windows::Graphics::Imaging;
1314
using namespace winrt::Windows::Graphics::DirectX::Direct3D11;
15+
using namespace DirectX::PackedVector;
1416

15-
template <TensorKind T> struct TensorKindToType
17+
template <TensorKind T> struct TensorKindToArithmeticType
1618
{
1719
static_assert(true, "No TensorKind mapped for given type!");
1820
};
19-
template <> struct TensorKindToType<TensorKind::UInt8>
21+
template <> struct TensorKindToArithmeticType<TensorKind::UInt8>
2022
{
2123
typedef uint8_t Type;
2224
};
23-
template <> struct TensorKindToType<TensorKind::Int8>
25+
template <> struct TensorKindToArithmeticType<TensorKind::Int8>
2426
{
2527
typedef uint8_t Type;
2628
};
27-
template <> struct TensorKindToType<TensorKind::UInt16>
29+
template <> struct TensorKindToArithmeticType<TensorKind::UInt16>
2830
{
2931
typedef uint16_t Type;
3032
};
31-
template <> struct TensorKindToType<TensorKind::Int16>
33+
template <> struct TensorKindToArithmeticType<TensorKind::Int16>
3234
{
3335
typedef int16_t Type;
3436
};
35-
template <> struct TensorKindToType<TensorKind::UInt32>
37+
template <> struct TensorKindToArithmeticType<TensorKind::UInt32>
3638
{
3739
typedef uint32_t Type;
3840
};
39-
template <> struct TensorKindToType<TensorKind::Int32>
41+
template <> struct TensorKindToArithmeticType<TensorKind::Int32>
4042
{
4143
typedef int32_t Type;
4244
};
43-
template <> struct TensorKindToType<TensorKind::UInt64>
45+
template <> struct TensorKindToArithmeticType<TensorKind::UInt64>
4446
{
4547
typedef uint64_t Type;
4648
};
47-
template <> struct TensorKindToType<TensorKind::Int64>
49+
template <> struct TensorKindToArithmeticType<TensorKind::Int64>
4850
{
4951
typedef int64_t Type;
5052
};
51-
template <> struct TensorKindToType<TensorKind::Boolean>
53+
template <> struct TensorKindToArithmeticType<TensorKind::Boolean>
5254
{
5355
typedef boolean Type;
5456
};
55-
template <> struct TensorKindToType<TensorKind::Double>
57+
template <> struct TensorKindToArithmeticType<TensorKind::Double>
5658
{
5759
typedef double Type;
5860
};
59-
template <> struct TensorKindToType<TensorKind::Float>
61+
template <> struct TensorKindToArithmeticType<TensorKind::Float>
6062
{
6163
typedef float Type;
6264
};
63-
template <> struct TensorKindToType<TensorKind::Float16>
65+
template <> struct TensorKindToArithmeticType<TensorKind::Float16>
66+
{
67+
typedef float Type;
68+
};
69+
template <> struct TensorKindToArithmeticType<TensorKind::String>
70+
{
71+
typedef winrt::hstring Type;
72+
};
73+
74+
template <TensorKind T> struct TensorKindToPointerType
75+
{
76+
static_assert(true, "No TensorKind mapped for given type!");
77+
};
78+
template <> struct TensorKindToPointerType<TensorKind::UInt8>
79+
{
80+
typedef uint8_t Type;
81+
};
82+
template <> struct TensorKindToPointerType<TensorKind::Int8>
83+
{
84+
typedef uint8_t Type;
85+
};
86+
template <> struct TensorKindToPointerType<TensorKind::UInt16>
87+
{
88+
typedef uint16_t Type;
89+
};
90+
template <> struct TensorKindToPointerType<TensorKind::Int16>
91+
{
92+
typedef int16_t Type;
93+
};
94+
template <> struct TensorKindToPointerType<TensorKind::UInt32>
95+
{
96+
typedef uint32_t Type;
97+
};
98+
template <> struct TensorKindToPointerType<TensorKind::Int32>
99+
{
100+
typedef int32_t Type;
101+
};
102+
template <> struct TensorKindToPointerType<TensorKind::UInt64>
103+
{
104+
typedef uint64_t Type;
105+
};
106+
template <> struct TensorKindToPointerType<TensorKind::Int64>
107+
{
108+
typedef int64_t Type;
109+
};
110+
template <> struct TensorKindToPointerType<TensorKind::Boolean>
111+
{
112+
typedef boolean Type;
113+
};
114+
template <> struct TensorKindToPointerType<TensorKind::Double>
115+
{
116+
typedef double Type;
117+
};
118+
template <> struct TensorKindToPointerType<TensorKind::Float>
119+
{
120+
typedef float Type;
121+
};
122+
template <> struct TensorKindToPointerType<TensorKind::Float16>
64123
{
65124
typedef HALF Type;
66125
};
67-
template <> struct TensorKindToType<TensorKind::String>
126+
template <> struct TensorKindToPointerType<TensorKind::String>
68127
{
69128
typedef winrt::hstring Type;
70129
};
@@ -126,6 +185,63 @@ template <> struct TensorKindToValue<TensorKind::String>
126185
typedef TensorString Type;
127186
};
128187

188+
template <TensorKind T, typename PointerType, typename ArithmeticType > PointerType ConvertArithmeticTypeToPointerType(ArithmeticType value)
189+
{
190+
static_assert(true, "No TensorKind mapped for given type!");
191+
};
192+
template <> uint8_t ConvertArithmeticTypeToPointerType<TensorKind::UInt8>(uint8_t value)
193+
{
194+
return static_cast<uint8_t>(value);
195+
};
196+
template <> uint8_t ConvertArithmeticTypeToPointerType<TensorKind::Int8>(uint8_t value)
197+
{
198+
return static_cast<uint8_t>(value);
199+
};
200+
template <> uint16_t ConvertArithmeticTypeToPointerType<TensorKind::UInt16>(uint16_t value)
201+
{
202+
return static_cast<uint16_t>(value);
203+
};
204+
template <> int16_t ConvertArithmeticTypeToPointerType<TensorKind::Int16>(int16_t value)
205+
{
206+
return static_cast<int16_t>(value);
207+
};
208+
template <> uint32_t ConvertArithmeticTypeToPointerType<TensorKind::UInt32>(uint32_t value)
209+
{
210+
return static_cast<uint32_t>(value);
211+
};
212+
template <> int32_t ConvertArithmeticTypeToPointerType<TensorKind::Int32>(int32_t value)
213+
{
214+
return static_cast<int32_t>(value);
215+
};
216+
template <> uint64_t ConvertArithmeticTypeToPointerType<TensorKind::UInt64>(uint64_t value)
217+
{
218+
return static_cast<uint64_t>(value);
219+
};
220+
template <> int64_t ConvertArithmeticTypeToPointerType<TensorKind::Int64>(int64_t value)
221+
{
222+
return static_cast<int64_t>(value);
223+
};
224+
template <> boolean ConvertArithmeticTypeToPointerType<TensorKind::Boolean>(boolean value)
225+
{
226+
return static_cast<boolean>(value);
227+
};
228+
template <> double ConvertArithmeticTypeToPointerType<TensorKind::Double>(double value)
229+
{
230+
return static_cast<double>(value);
231+
};
232+
template <> float ConvertArithmeticTypeToPointerType<TensorKind::Float>(float value)
233+
{
234+
return static_cast<float>(value);
235+
};
236+
template <> HALF ConvertArithmeticTypeToPointerType<TensorKind::Float16>(float value)
237+
{
238+
return XMConvertFloatToHalf(value);
239+
};
240+
template <> winrt::hstring ConvertArithmeticTypeToPointerType<TensorKind::String>(winrt::hstring value)
241+
{
242+
return static_cast<winrt::hstring>(value);
243+
};
244+
129245
void GetHeightAndWidthFromLearningModelFeatureDescriptor(const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
130246
uint64_t& width, uint64_t& height)
131247
{
@@ -286,31 +402,6 @@ namespace BindingUtilities
286402
return elementStrings;
287403
}
288404

289-
template <typename T>
290-
void WriteDataToBinding(const std::vector<std::string>& elementStrings, T* bindingMemory,
291-
uint32_t bindingMemorySize)
292-
{
293-
if (bindingMemorySize / sizeof(T) != elementStrings.size())
294-
{
295-
throw hresult_invalid_argument(L"CSV Input is size/shape is different from what model expects");
296-
}
297-
T* data = bindingMemory;
298-
for (const auto& elementString : elementStrings)
299-
{
300-
float value;
301-
std::stringstream(elementString) >> value;
302-
if (!std::is_same<T, HALF>::value)
303-
{
304-
*data = static_cast<T>(value);
305-
}
306-
else
307-
{
308-
*reinterpret_cast<HALF*>(data) = XMConvertFloatToHalf(value);
309-
}
310-
data++;
311-
}
312-
}
313-
314405
std::vector<std::string> ParseCSVElementStrings(const std::wstring& csvFilePath)
315406
{
316407
std::ifstream fileStream;
@@ -330,7 +421,9 @@ namespace BindingUtilities
330421
const IVectorView<int64_t>& tensorShape, const InputBindingType inputBindingType)
331422
{
332423
using TensorValue = typename TensorKindToValue<T>::Type;
333-
using DataType = typename TensorKindToType<T>::Type;
424+
using ArithmeticType = typename TensorKindToArithmeticType<T>::Type;
425+
using PointerType = typename TensorKindToPointerType<T>::Type;
426+
334427
std::vector<int64_t> vecShape = {};
335428
for (UINT dim = 0; dim < tensorShape.Size(); dim++)
336429
{
@@ -353,19 +446,34 @@ namespace BindingUtilities
353446
}
354447
}
355448
}
449+
450+
// Map the incoming Tensor as a TensorNative to get the actual data buffer.
356451
auto tensorValue = TensorValue::Create(vecShape);
357452

358453
com_ptr<ITensorNative> spTensorValueNative;
359454
tensorValue.as(spTensorValueNative);
360455

361-
BYTE* actualData;
456+
PointerType* actualData;
362457
uint32_t actualSizeInBytes;
363-
spTensorValueNative->GetBuffer(&actualData,
364-
&actualSizeInBytes); // Need to GetBuffer to have CPU memory backing tensorValue
458+
spTensorValueNative->GetBuffer(reinterpret_cast<BYTE**>(&actualData),
459+
&actualSizeInBytes);
365460

366461
if (args.IsCSVInput())
367462
{
368-
WriteDataToBinding<DataType>(tensorStringInput, reinterpret_cast<DataType*>(actualData), actualSizeInBytes);
463+
if (tensorStringInput.size() != actualSizeInBytes / sizeof(PointerType))
464+
{
465+
throw hresult_invalid_argument(L"CSV input size/shape is different from what model expects");
466+
}
467+
468+
// Write the elementStrings into the iTensorNative
469+
PointerType* dataPtr = actualData;
470+
for (const auto &tensorString : tensorStringInput)
471+
{
472+
ArithmeticType value;
473+
std::stringstream(tensorString) >> value;
474+
*dataPtr = ConvertArithmeticTypeToPointerType<T,PointerType,ArithmeticType>(value);
475+
dataPtr++;
476+
}
369477
}
370478
else if (args.IsImageInput())
371479
{
@@ -629,7 +737,7 @@ namespace BindingUtilities
629737
{
630738
if (desc.Kind() == LearningModelFeatureKind::Tensor)
631739
{
632-
std::wstring name = desc.Name().c_str();
740+
std::wstring name(desc.Name());
633741
if (args.IsSaveTensor() && args.SaveTensorMode() == L"First" && iterationNum > 0)
634742
{
635743
return;

0 commit comments

Comments
 (0)