@@ -13,6 +13,36 @@ using namespace winrt::Windows::Graphics::DirectX;
1313using namespace winrt ::Windows::Graphics::Imaging;
1414using namespace winrt ::Windows::Graphics::DirectX::Direct3D11;
1515
16+ template <TensorKind T> struct TensorKindToType { static_assert (true , " No TensorKind mapped for given type!" ); };
17+ template <> struct TensorKindToType <TensorKind::UInt8> { typedef uint8_t Type; };
18+ template <> struct TensorKindToType <TensorKind::Int8> { typedef uint8_t Type; };
19+ template <> struct TensorKindToType <TensorKind::UInt16> { typedef uint16_t Type; };
20+ template <> struct TensorKindToType <TensorKind::Int16> { typedef int16_t Type; };
21+ template <> struct TensorKindToType <TensorKind::UInt32> { typedef uint32_t Type; };
22+ template <> struct TensorKindToType <TensorKind::Int32> { typedef int32_t Type; };
23+ template <> struct TensorKindToType <TensorKind::UInt64> { typedef uint64_t Type; };
24+ template <> struct TensorKindToType <TensorKind::Int64> { typedef int64_t Type; };
25+ template <> struct TensorKindToType <TensorKind::Boolean> { typedef boolean Type; };
26+ template <> struct TensorKindToType <TensorKind::Double> { typedef double Type; };
27+ template <> struct TensorKindToType <TensorKind::Float> { typedef float Type; };
28+ template <> struct TensorKindToType <TensorKind::Float16> { typedef float Type; };
29+ template <> struct TensorKindToType <TensorKind::String> { typedef winrt::hstring Type; };
30+
31+ template <TensorKind T> struct TensorKindToValue { static_assert (true , " No TensorKind mapped for given type!" ); };
32+ template <> struct TensorKindToValue <TensorKind::UInt8> { typedef TensorUInt8Bit Type; };
33+ template <> struct TensorKindToValue <TensorKind::Int8> { typedef TensorInt8Bit Type; };
34+ template <> struct TensorKindToValue <TensorKind::UInt16> { typedef TensorUInt16Bit Type; };
35+ template <> struct TensorKindToValue <TensorKind::Int16> { typedef TensorInt16Bit Type; };
36+ template <> struct TensorKindToValue <TensorKind::UInt32> { typedef TensorUInt32Bit Type; };
37+ template <> struct TensorKindToValue <TensorKind::Int32> { typedef TensorInt32Bit Type; };
38+ template <> struct TensorKindToValue <TensorKind::UInt64> { typedef TensorUInt64Bit Type; };
39+ template <> struct TensorKindToValue <TensorKind::Int64> { typedef TensorInt64Bit Type; };
40+ template <> struct TensorKindToValue <TensorKind::Boolean> { typedef TensorBoolean Type; };
41+ template <> struct TensorKindToValue <TensorKind::Double> { typedef TensorDouble Type; };
42+ template <> struct TensorKindToValue <TensorKind::Float> { typedef TensorFloat Type; };
43+ template <> struct TensorKindToValue <TensorKind::Float16> { typedef TensorFloat16Bit Type; };
44+ template <> struct TensorKindToValue <TensorKind::String> { typedef TensorString Type; };
45+
1646namespace BindingUtilities
1747{
1848 static unsigned int seed = 0 ;
@@ -175,6 +205,41 @@ namespace BindingUtilities
175205 return elementStrings;
176206 }
177207
208+ template <TensorKind T>
209+ static ITensor CreateTensor (
210+ const CommandLineArgs& args,
211+ std::vector<std::string>& tensorStringInput,
212+ TensorFeatureDescriptor& tensorDescriptor)
213+ {
214+ using TensorValue = typename TensorKindToValue<T>::Type;
215+ using DataType = typename TensorKindToType<T>::Type;
216+
217+ if (!args.CsvPath ().empty ())
218+ {
219+ ModelBinding<DataType> binding (tensorDescriptor);
220+ WriteDataToBinding<DataType>(tensorStringInput, binding);
221+ return TensorValue::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
222+ }
223+ else if (args.IsGarbageInput ())
224+ {
225+ auto tensorValue = TensorValue::Create (tensorDescriptor.Shape ());
226+
227+ com_ptr<ITensorNative> spTensorValueNative;
228+ tensorValue.as (spTensorValueNative);
229+
230+ BYTE* actualData;
231+ uint32_t actualSizeInBytes;
232+ spTensorValueNative->GetBuffer (&actualData, &actualSizeInBytes);
233+
234+ return tensorValue;
235+ }
236+ else
237+ {
238+ // Creating Tensors for Input Images haven't been added yet.
239+ throw hresult_not_implemented (L" Creating Tensors for Input Images haven't been implemented yet!" );
240+ }
241+ }
242+
178243 // Binds tensor floats, ints, doubles from CSV data.
179244 ITensor CreateBindableTensor (const ILearningModelFeatureDescriptor& description, const CommandLineArgs& args)
180245 {
@@ -188,6 +253,10 @@ namespace BindingUtilities
188253 }
189254
190255 std::vector<std::string> elementStrings;
256+ if (!args.CsvPath ().empty ())
257+ {
258+ elementStrings = ParseCSVElementStrings (args.CsvPath ());
259+ }
191260 switch (tensorDescriptor.TensorKind ())
192261 {
193262 case TensorKind::Undefined:
@@ -197,167 +266,57 @@ namespace BindingUtilities
197266 }
198267 case TensorKind::Float:
199268 {
200- ModelBinding<float > binding (description);
201- if (args.IsGarbageInput ())
202- {
203- memset (binding.GetData (), 0 , sizeof (float ) * binding.GetDataBufferSize ());
204- }
205- else
206- {
207- elementStrings = ParseCSVElementStrings (args.CsvPath ());
208- WriteDataToBinding<float >(elementStrings, binding);
209- }
210- return TensorFloat::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
269+ return CreateTensor<TensorKind::Float>(args, elementStrings, tensorDescriptor);
211270 }
212271 break ;
213272 case TensorKind::Float16:
214273 {
215- ModelBinding<float > binding (description);
216- if (args.IsGarbageInput ())
217- {
218- memset (binding.GetData (), 0 , sizeof (float ) * binding.GetDataBufferSize ());
219- }
220- else
221- {
222- elementStrings = ParseCSVElementStrings (args.CsvPath ());
223- WriteDataToBinding<float >(elementStrings, binding);
224- }
225- return TensorFloat16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
274+ return CreateTensor<TensorKind::Float16>(args, elementStrings, tensorDescriptor);
226275 }
227276 break ;
228277 case TensorKind::Double:
229278 {
230- ModelBinding<double > binding (description);
231- if (args.IsGarbageInput ())
232- {
233- memset (binding.GetData (), 0 , sizeof (double ) * binding.GetDataBufferSize ());
234- }
235- else
236- {
237- elementStrings = ParseCSVElementStrings (args.CsvPath ());
238- WriteDataToBinding<double >(elementStrings, binding);
239- }
240- return TensorDouble::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
279+ return CreateTensor<TensorKind::Double>(args, elementStrings, tensorDescriptor);
241280 }
242281 break ;
243282 case TensorKind::Int8:
244283 {
245- ModelBinding<uint8_t > binding (description);
246- if (args.IsGarbageInput ())
247- {
248- memset (binding.GetData (), 0 , sizeof (uint8_t ) * binding.GetDataBufferSize ());
249- }
250- else
251- {
252- elementStrings = ParseCSVElementStrings (args.CsvPath ());
253- WriteDataToBinding<uint8_t >(elementStrings, binding);
254- }
255- return TensorInt8Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
284+ return CreateTensor<TensorKind::Int8>(args, elementStrings, tensorDescriptor);
256285 }
257286 break ;
258287 case TensorKind::UInt8:
259288 {
260- ModelBinding<uint8_t > binding (description);
261- if (args.IsGarbageInput ())
262- {
263- memset (binding.GetData (), 0 , sizeof (uint8_t ) * binding.GetDataBufferSize ());
264- }
265- else
266- {
267- elementStrings = ParseCSVElementStrings (args.CsvPath ());
268- WriteDataToBinding<uint8_t >(elementStrings, binding);
269- }
270- return TensorUInt8Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
289+ return CreateTensor<TensorKind::UInt8>(args, elementStrings, tensorDescriptor);
271290 }
272291 break ;
273292 case TensorKind::Int16:
274293 {
275- ModelBinding<int16_t > binding (description);
276- if (args.IsGarbageInput ())
277- {
278- memset (binding.GetData (), 0 , sizeof (int16_t ) * binding.GetDataBufferSize ());
279- }
280- else
281- {
282- elementStrings = ParseCSVElementStrings (args.CsvPath ());
283- WriteDataToBinding<int16_t >(elementStrings, binding);
284- }
285- return TensorInt16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
294+ return CreateTensor<TensorKind::Int16>(args, elementStrings, tensorDescriptor);
286295 }
287296 break ;
288297 case TensorKind::UInt16:
289298 {
290- ModelBinding<uint16_t > binding (description);
291- if (args.IsGarbageInput ())
292- {
293- memset (binding.GetData (), 0 , sizeof (uint16_t ) * binding.GetDataBufferSize ());
294- }
295- else
296- {
297- elementStrings = ParseCSVElementStrings (args.CsvPath ());
298- WriteDataToBinding<uint16_t >(elementStrings, binding);
299- }
300- return TensorUInt16Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
299+ return CreateTensor<TensorKind::UInt16>(args, elementStrings, tensorDescriptor);
301300 }
302301 break ;
303302 case TensorKind::Int32:
304303 {
305- ModelBinding<int32_t > binding (description);
306- if (args.IsGarbageInput ())
307- {
308- memset (binding.GetData (), 0 , sizeof (int32_t ) * binding.GetDataBufferSize ());
309- }
310- else
311- {
312- elementStrings = ParseCSVElementStrings (args.CsvPath ());
313- WriteDataToBinding<int32_t >(elementStrings, binding);
314- }
315- return TensorInt32Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
304+ return CreateTensor<TensorKind::Int32>(args, elementStrings, tensorDescriptor);
316305 }
317306 break ;
318307 case TensorKind::UInt32:
319308 {
320- ModelBinding<uint32_t > binding (description);
321- if (args.IsGarbageInput ())
322- {
323- memset (binding.GetData (), 0 , sizeof (uint32_t ) * binding.GetDataBufferSize ());
324- }
325- else
326- {
327- elementStrings = ParseCSVElementStrings (args.CsvPath ());
328- WriteDataToBinding<uint32_t >(elementStrings, binding);
329- }
330- return TensorUInt32Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
309+ return CreateTensor<TensorKind::UInt32>(args, elementStrings, tensorDescriptor);
331310 }
332311 break ;
333312 case TensorKind::Int64:
334313 {
335- ModelBinding<int64_t > binding (description);
336- if (args.IsGarbageInput ())
337- {
338- memset (binding.GetData (), 0 , sizeof (int64_t ) * binding.GetDataBufferSize ());
339- }
340- else
341- {
342- elementStrings = ParseCSVElementStrings (args.CsvPath ());
343- WriteDataToBinding<int64_t >(elementStrings, binding);
344- }
345- return TensorInt64Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
314+ return CreateTensor<TensorKind::Int64>(args, elementStrings, tensorDescriptor);
346315 }
347316 break ;
348317 case TensorKind::UInt64:
349318 {
350- ModelBinding<uint64_t > binding (description);
351- if (args.IsGarbageInput ())
352- {
353- memset (binding.GetData (), 0 , sizeof (uint64_t ) * binding.GetDataBufferSize ());
354- }
355- else
356- {
357- elementStrings = ParseCSVElementStrings (args.CsvPath ());
358- WriteDataToBinding<uint64_t >(elementStrings, binding);
359- }
360- return TensorUInt64Bit::CreateFromArray (binding.GetShapeBuffer (), binding.GetDataBuffer ());
319+ return CreateTensor<TensorKind::UInt64>(args, elementStrings, tensorDescriptor);
361320 }
362321 break ;
363322 }
0 commit comments