66#include " d3dx12.h"
77using namespace winrt ::Windows::Media;
88using namespace winrt ::Windows::Storage;
9+ using namespace winrt ::Windows::Storage::Streams;
910using namespace winrt ::Windows::AI::MachineLearning;
1011using namespace winrt ::Windows::Foundation::Collections;
1112using namespace winrt ::Windows::Graphics::DirectX;
1213using namespace winrt ::Windows::Graphics::Imaging;
1314using namespace winrt ::Windows::Graphics::DirectX::Direct3D11;
15+ using namespace DirectX ::PackedVector;
1416
15- template <TensorKind T> struct TensorKindToType
17+ template <TensorKind T> struct TensorKindToArithmeticType
1618{
1719 static_assert (true , " No TensorKind mapped for given type!" );
1820};
19- template <> struct TensorKindToType <TensorKind::UInt8>
21+ template <> struct TensorKindToArithmeticType <TensorKind::UInt8>
2022{
2123 typedef uint8_t Type;
2224};
23- template <> struct TensorKindToType <TensorKind::Int8>
25+ template <> struct TensorKindToArithmeticType <TensorKind::Int8>
2426{
2527 typedef uint8_t Type;
2628};
27- template <> struct TensorKindToType <TensorKind::UInt16>
29+ template <> struct TensorKindToArithmeticType <TensorKind::UInt16>
2830{
2931 typedef uint16_t Type;
3032};
31- template <> struct TensorKindToType <TensorKind::Int16>
33+ template <> struct TensorKindToArithmeticType <TensorKind::Int16>
3234{
3335 typedef int16_t Type;
3436};
35- template <> struct TensorKindToType <TensorKind::UInt32>
37+ template <> struct TensorKindToArithmeticType <TensorKind::UInt32>
3638{
3739 typedef uint32_t Type;
3840};
39- template <> struct TensorKindToType <TensorKind::Int32>
41+ template <> struct TensorKindToArithmeticType <TensorKind::Int32>
4042{
4143 typedef int32_t Type;
4244};
43- template <> struct TensorKindToType <TensorKind::UInt64>
45+ template <> struct TensorKindToArithmeticType <TensorKind::UInt64>
4446{
4547 typedef uint64_t Type;
4648};
47- template <> struct TensorKindToType <TensorKind::Int64>
49+ template <> struct TensorKindToArithmeticType <TensorKind::Int64>
4850{
4951 typedef int64_t Type;
5052};
51- template <> struct TensorKindToType <TensorKind::Boolean>
53+ template <> struct TensorKindToArithmeticType <TensorKind::Boolean>
5254{
5355 typedef boolean Type;
5456};
55- template <> struct TensorKindToType <TensorKind::Double>
57+ template <> struct TensorKindToArithmeticType <TensorKind::Double>
5658{
5759 typedef double Type;
5860};
59- template <> struct TensorKindToType <TensorKind::Float>
61+ template <> struct TensorKindToArithmeticType <TensorKind::Float>
6062{
6163 typedef float Type;
6264};
63- template <> struct TensorKindToType <TensorKind::Float16>
65+ template <> struct TensorKindToArithmeticType <TensorKind::Float16>
66+ {
67+ typedef float Type;
68+ };
69+ template <> struct TensorKindToArithmeticType <TensorKind::String>
70+ {
71+ typedef winrt::hstring Type;
72+ };
73+
74+ template <TensorKind T> struct TensorKindToPointerType
75+ {
76+ static_assert (true , " No TensorKind mapped for given type!" );
77+ };
78+ template <> struct TensorKindToPointerType <TensorKind::UInt8>
79+ {
80+ typedef uint8_t Type;
81+ };
82+ template <> struct TensorKindToPointerType <TensorKind::Int8>
83+ {
84+ typedef uint8_t Type;
85+ };
86+ template <> struct TensorKindToPointerType <TensorKind::UInt16>
87+ {
88+ typedef uint16_t Type;
89+ };
90+ template <> struct TensorKindToPointerType <TensorKind::Int16>
91+ {
92+ typedef int16_t Type;
93+ };
94+ template <> struct TensorKindToPointerType <TensorKind::UInt32>
95+ {
96+ typedef uint32_t Type;
97+ };
98+ template <> struct TensorKindToPointerType <TensorKind::Int32>
99+ {
100+ typedef int32_t Type;
101+ };
102+ template <> struct TensorKindToPointerType <TensorKind::UInt64>
103+ {
104+ typedef uint64_t Type;
105+ };
106+ template <> struct TensorKindToPointerType <TensorKind::Int64>
107+ {
108+ typedef int64_t Type;
109+ };
110+ template <> struct TensorKindToPointerType <TensorKind::Boolean>
111+ {
112+ typedef boolean Type;
113+ };
114+ template <> struct TensorKindToPointerType <TensorKind::Double>
115+ {
116+ typedef double Type;
117+ };
118+ template <> struct TensorKindToPointerType <TensorKind::Float>
119+ {
120+ typedef float Type;
121+ };
122+ template <> struct TensorKindToPointerType <TensorKind::Float16>
64123{
65124 typedef HALF Type;
66125};
67- template <> struct TensorKindToType <TensorKind::String>
126+ template <> struct TensorKindToPointerType <TensorKind::String>
68127{
69128 typedef winrt::hstring Type;
70129};
@@ -126,6 +185,63 @@ template <> struct TensorKindToValue<TensorKind::String>
126185 typedef TensorString Type;
127186};
128187
188+ template <TensorKind T, typename PointerType, typename ArithmeticType > PointerType ConvertArithmeticTypeToPointerType (ArithmeticType value)
189+ {
190+ static_assert (true , " No TensorKind mapped for given type!" );
191+ };
192+ template <> uint8_t ConvertArithmeticTypeToPointerType<TensorKind::UInt8>(uint8_t value)
193+ {
194+ return static_cast <uint8_t >(value);
195+ };
196+ template <> uint8_t ConvertArithmeticTypeToPointerType<TensorKind::Int8>(uint8_t value)
197+ {
198+ return static_cast <uint8_t >(value);
199+ };
200+ template <> uint16_t ConvertArithmeticTypeToPointerType<TensorKind::UInt16>(uint16_t value)
201+ {
202+ return static_cast <uint16_t >(value);
203+ };
204+ template <> int16_t ConvertArithmeticTypeToPointerType<TensorKind::Int16>(int16_t value)
205+ {
206+ return static_cast <int16_t >(value);
207+ };
208+ template <> uint32_t ConvertArithmeticTypeToPointerType<TensorKind::UInt32>(uint32_t value)
209+ {
210+ return static_cast <uint32_t >(value);
211+ };
212+ template <> int32_t ConvertArithmeticTypeToPointerType<TensorKind::Int32>(int32_t value)
213+ {
214+ return static_cast <int32_t >(value);
215+ };
216+ template <> uint64_t ConvertArithmeticTypeToPointerType<TensorKind::UInt64>(uint64_t value)
217+ {
218+ return static_cast <uint64_t >(value);
219+ };
220+ template <> int64_t ConvertArithmeticTypeToPointerType<TensorKind::Int64>(int64_t value)
221+ {
222+ return static_cast <int64_t >(value);
223+ };
224+ template <> boolean ConvertArithmeticTypeToPointerType<TensorKind::Boolean>(boolean value)
225+ {
226+ return static_cast <boolean>(value);
227+ };
228+ template <> double ConvertArithmeticTypeToPointerType<TensorKind::Double>(double value)
229+ {
230+ return static_cast <double >(value);
231+ };
232+ template <> float ConvertArithmeticTypeToPointerType<TensorKind::Float>(float value)
233+ {
234+ return static_cast <float >(value);
235+ };
236+ template <> HALF ConvertArithmeticTypeToPointerType<TensorKind::Float16>(float value)
237+ {
238+ return XMConvertFloatToHalf (value);
239+ };
240+ template <> winrt::hstring ConvertArithmeticTypeToPointerType<TensorKind::String>(winrt::hstring value)
241+ {
242+ return static_cast <winrt::hstring>(value);
243+ };
244+
129245void GetHeightAndWidthFromLearningModelFeatureDescriptor (const ILearningModelFeatureDescriptor& modelFeatureDescriptor,
130246 uint64_t & width, uint64_t & height)
131247{
@@ -286,31 +402,6 @@ namespace BindingUtilities
286402 return elementStrings;
287403 }
288404
289- template <typename T>
290- void WriteDataToBinding (const std::vector<std::string>& elementStrings, T* bindingMemory,
291- uint32_t bindingMemorySize)
292- {
293- if (bindingMemorySize / sizeof (T) != elementStrings.size ())
294- {
295- throw hresult_invalid_argument (L" CSV Input is size/shape is different from what model expects" );
296- }
297- T* data = bindingMemory;
298- for (const auto & elementString : elementStrings)
299- {
300- float value;
301- std::stringstream (elementString) >> value;
302- if (!std::is_same<T, HALF>::value)
303- {
304- *data = static_cast <T>(value);
305- }
306- else
307- {
308- *reinterpret_cast <HALF*>(data) = XMConvertFloatToHalf (value);
309- }
310- data++;
311- }
312- }
313-
314405 std::vector<std::string> ParseCSVElementStrings (const std::wstring& csvFilePath)
315406 {
316407 std::ifstream fileStream;
@@ -330,7 +421,9 @@ namespace BindingUtilities
330421 const IVectorView<int64_t >& tensorShape, const InputBindingType inputBindingType)
331422 {
332423 using TensorValue = typename TensorKindToValue<T>::Type;
333- using DataType = typename TensorKindToType<T>::Type;
424+ using ArithmeticType = typename TensorKindToArithmeticType<T>::Type;
425+ using PointerType = typename TensorKindToPointerType<T>::Type;
426+
334427 std::vector<int64_t > vecShape = {};
335428 for (UINT dim = 0 ; dim < tensorShape.Size (); dim++)
336429 {
@@ -353,19 +446,34 @@ namespace BindingUtilities
353446 }
354447 }
355448 }
449+
450+ // Map the incoming Tensor as a TensorNative to get the actual data buffer.
356451 auto tensorValue = TensorValue::Create (vecShape);
357452
358453 com_ptr<ITensorNative> spTensorValueNative;
359454 tensorValue.as (spTensorValueNative);
360455
361- BYTE * actualData;
456+ PointerType * actualData;
362457 uint32_t actualSizeInBytes;
363- spTensorValueNative->GetBuffer (&actualData,
364- &actualSizeInBytes); // Need to GetBuffer to have CPU memory backing tensorValue
458+ spTensorValueNative->GetBuffer (reinterpret_cast <BYTE**>( &actualData) ,
459+ &actualSizeInBytes);
365460
366461 if (args.IsCSVInput ())
367462 {
368- WriteDataToBinding<DataType>(tensorStringInput, reinterpret_cast <DataType*>(actualData), actualSizeInBytes);
463+ if (tensorStringInput.size () != actualSizeInBytes / sizeof (PointerType))
464+ {
465+ throw hresult_invalid_argument (L" CSV input size/shape is different from what model expects" );
466+ }
467+
468+ // Write the elementStrings into the iTensorNative
469+ PointerType* dataPtr = actualData;
470+ for (const auto &tensorString : tensorStringInput)
471+ {
472+ ArithmeticType value;
473+ std::stringstream (tensorString) >> value;
474+ *dataPtr = ConvertArithmeticTypeToPointerType<T,PointerType,ArithmeticType>(value);
475+ dataPtr++;
476+ }
369477 }
370478 else if (args.IsImageInput ())
371479 {
@@ -629,7 +737,7 @@ namespace BindingUtilities
629737 {
630738 if (desc.Kind () == LearningModelFeatureKind::Tensor)
631739 {
632- std::wstring name = desc.Name (). c_str ( );
740+ std::wstring name ( desc.Name ());
633741 if (args.IsSaveTensor () && args.SaveTensorMode () == L" First" && iterationNum > 0 )
634742 {
635743 return ;
0 commit comments