@@ -13,35 +13,119 @@ using namespace winrt::Windows::Graphics::DirectX;
1313using namespace winrt ::Windows::Graphics::Imaging;
1414using namespace winrt ::Windows::Graphics::DirectX::Direct3D11;
1515
16- template <TensorKind T> struct TensorKindToType { static_assert (true , " No TensorKind mapped for given type!" ); };
17- template <> struct TensorKindToType <TensorKind::UInt8> { typedef uint8_t Type; };
18- template <> struct TensorKindToType <TensorKind::Int8> { typedef uint8_t Type; };
19- template <> struct TensorKindToType <TensorKind::UInt16> { typedef uint16_t Type; };
20- template <> struct TensorKindToType <TensorKind::Int16> { typedef int16_t Type; };
21- template <> struct TensorKindToType <TensorKind::UInt32> { typedef uint32_t Type; };
22- template <> struct TensorKindToType <TensorKind::Int32> { typedef int32_t Type; };
23- template <> struct TensorKindToType <TensorKind::UInt64> { typedef uint64_t Type; };
24- template <> struct TensorKindToType <TensorKind::Int64> { typedef int64_t Type; };
25- template <> struct TensorKindToType <TensorKind::Boolean> { typedef boolean Type; };
26- template <> struct TensorKindToType <TensorKind::Double> { typedef double Type; };
27- template <> struct TensorKindToType <TensorKind::Float> { typedef float Type; };
28- template <> struct TensorKindToType <TensorKind::Float16> { typedef float Type; };
29- template <> struct TensorKindToType <TensorKind::String> { typedef winrt::hstring Type; };
30-
31- template <TensorKind T> struct TensorKindToValue { static_assert (true , " No TensorKind mapped for given type!" ); };
32- template <> struct TensorKindToValue <TensorKind::UInt8> { typedef TensorUInt8Bit Type; };
33- template <> struct TensorKindToValue <TensorKind::Int8> { typedef TensorInt8Bit Type; };
34- template <> struct TensorKindToValue <TensorKind::UInt16> { typedef TensorUInt16Bit Type; };
35- template <> struct TensorKindToValue <TensorKind::Int16> { typedef TensorInt16Bit Type; };
36- template <> struct TensorKindToValue <TensorKind::UInt32> { typedef TensorUInt32Bit Type; };
37- template <> struct TensorKindToValue <TensorKind::Int32> { typedef TensorInt32Bit Type; };
38- template <> struct TensorKindToValue <TensorKind::UInt64> { typedef TensorUInt64Bit Type; };
39- template <> struct TensorKindToValue <TensorKind::Int64> { typedef TensorInt64Bit Type; };
40- template <> struct TensorKindToValue <TensorKind::Boolean> { typedef TensorBoolean Type; };
41- template <> struct TensorKindToValue <TensorKind::Double> { typedef TensorDouble Type; };
42- template <> struct TensorKindToValue <TensorKind::Float> { typedef TensorFloat Type; };
43- template <> struct TensorKindToValue <TensorKind::Float16> { typedef TensorFloat16Bit Type; };
44- template <> struct TensorKindToValue <TensorKind::String> { typedef TensorString Type; };
16+ template <TensorKind T> struct TensorKindToType
17+ {
18+ static_assert (true , " No TensorKind mapped for given type!" );
19+ };
20+ template <> struct TensorKindToType <TensorKind::UInt8>
21+ {
22+ typedef uint8_t Type;
23+ };
24+ template <> struct TensorKindToType <TensorKind::Int8>
25+ {
26+ typedef uint8_t Type;
27+ };
28+ template <> struct TensorKindToType <TensorKind::UInt16>
29+ {
30+ typedef uint16_t Type;
31+ };
32+ template <> struct TensorKindToType <TensorKind::Int16>
33+ {
34+ typedef int16_t Type;
35+ };
36+ template <> struct TensorKindToType <TensorKind::UInt32>
37+ {
38+ typedef uint32_t Type;
39+ };
40+ template <> struct TensorKindToType <TensorKind::Int32>
41+ {
42+ typedef int32_t Type;
43+ };
44+ template <> struct TensorKindToType <TensorKind::UInt64>
45+ {
46+ typedef uint64_t Type;
47+ };
48+ template <> struct TensorKindToType <TensorKind::Int64>
49+ {
50+ typedef int64_t Type;
51+ };
52+ template <> struct TensorKindToType <TensorKind::Boolean>
53+ {
54+ typedef boolean Type;
55+ };
56+ template <> struct TensorKindToType <TensorKind::Double>
57+ {
58+ typedef double Type;
59+ };
60+ template <> struct TensorKindToType <TensorKind::Float>
61+ {
62+ typedef float Type;
63+ };
64+ template <> struct TensorKindToType <TensorKind::Float16>
65+ {
66+ typedef float Type;
67+ };
68+ template <> struct TensorKindToType <TensorKind::String>
69+ {
70+ typedef winrt::hstring Type;
71+ };
72+
73+ template <TensorKind T> struct TensorKindToValue
74+ {
75+ static_assert (true , " No TensorKind mapped for given type!" );
76+ };
77+ template <> struct TensorKindToValue <TensorKind::UInt8>
78+ {
79+ typedef TensorUInt8Bit Type;
80+ };
81+ template <> struct TensorKindToValue <TensorKind::Int8>
82+ {
83+ typedef TensorInt8Bit Type;
84+ };
85+ template <> struct TensorKindToValue <TensorKind::UInt16>
86+ {
87+ typedef TensorUInt16Bit Type;
88+ };
89+ template <> struct TensorKindToValue <TensorKind::Int16>
90+ {
91+ typedef TensorInt16Bit Type;
92+ };
93+ template <> struct TensorKindToValue <TensorKind::UInt32>
94+ {
95+ typedef TensorUInt32Bit Type;
96+ };
97+ template <> struct TensorKindToValue <TensorKind::Int32>
98+ {
99+ typedef TensorInt32Bit Type;
100+ };
101+ template <> struct TensorKindToValue <TensorKind::UInt64>
102+ {
103+ typedef TensorUInt64Bit Type;
104+ };
105+ template <> struct TensorKindToValue <TensorKind::Int64>
106+ {
107+ typedef TensorInt64Bit Type;
108+ };
109+ template <> struct TensorKindToValue <TensorKind::Boolean>
110+ {
111+ typedef TensorBoolean Type;
112+ };
113+ template <> struct TensorKindToValue <TensorKind::Double>
114+ {
115+ typedef TensorDouble Type;
116+ };
117+ template <> struct TensorKindToValue <TensorKind::Float>
118+ {
119+ typedef TensorFloat Type;
120+ };
121+ template <> struct TensorKindToValue <TensorKind::Float16>
122+ {
123+ typedef TensorFloat16Bit Type;
124+ };
125+ template <> struct TensorKindToValue <TensorKind::String>
126+ {
127+ typedef TensorString Type;
128+ };
45129
46130namespace BindingUtilities
47131{
@@ -235,19 +319,21 @@ namespace BindingUtilities
235319 for (UINT dim = 0 ; dim < tensorDescriptorShape.Size (); dim++)
236320 {
237321 INT64 dimSize = tensorDescriptorShape.GetAt (dim);
238- if (dimSize > 0 ) // If the dimension is greater than 0, then it is known.
322+ if (dimSize > 0 ) // If the dimension is greater than 0, then it is known.
239323 {
240324 vecShape.push_back (dimSize);
241325 }
242- else // otherwise, make sure that the dimension is -1, representing free dimension. If not, then it's an invalid model.
326+ else // otherwise, make sure that the dimension is -1, representing free dimension. If not, then it's an
327+ // invalid model.
243328 {
244329 if (dimSize == -1 )
245330 {
246331 vecShape.push_back (1 );
247332 }
248333 else
249334 {
250- throw hresult_invalid_argument (L" Failed to create a tensor with an unknown dimension of: " + dimSize);
335+ throw hresult_invalid_argument (L" Failed to create a tensor with an unknown dimension of: " +
336+ dimSize);
251337 }
252338 }
253339 }
@@ -258,7 +344,8 @@ namespace BindingUtilities
258344
259345 BYTE* actualData;
260346 uint32_t actualSizeInBytes;
261- spTensorValueNative->GetBuffer (&actualData, &actualSizeInBytes); // Need to GetBuffer to have CPU memory backing tensorValue
347+ spTensorValueNative->GetBuffer (
348+ &actualData, &actualSizeInBytes); // Need to GetBuffer to have CPU memory backing tensorValue
262349 return tensorValue;
263350 }
264351 else
@@ -419,8 +506,8 @@ namespace BindingUtilities
419506 com_ptr<ITensorNative> itn = results.Lookup (desc.Name ()).as <ITensorNative>();
420507 HRESULT (itn->GetBuffer (reinterpret_cast <BYTE**>(&tensor), &uCapacity));
421508 int size = 0 ;
422- float maxValue = 0 ;
423- int maxIndex = 0 ;
509+ unsigned int topK = args. TopK () ;
510+ std::vector<std::pair< float , int >> maxKValues ;
424511 std::ofstream fout;
425512 if (args.IsSaveTensor ())
426513 {
@@ -445,12 +532,12 @@ namespace BindingUtilities
445532 break ;
446533 case TensorKind::Float16:
447534 {
448- output.ProcessTensorResult <HALF>(args, tensor, uCapacity, maxValue, maxIndex, fout );
535+ output.ProcessTensorResult <HALF>(args, tensor, uCapacity, maxKValues, fout, topK );
449536 }
450537 break ;
451538 case TensorKind::Float:
452539 {
453- output.ProcessTensorResult <float >(args, tensor, uCapacity, maxValue, maxIndex, fout );
540+ output.ProcessTensorResult <float >(args, tensor, uCapacity, maxKValues, fout, topK );
454541 }
455542 break ;
456543 case TensorKind::Int64:
@@ -472,16 +559,27 @@ namespace BindingUtilities
472559 if (args.IsSaveTensor ())
473560 {
474561 fout.close ();
475- std::string iterationResult =
476- " Index: " + std::to_string (maxIndex) + " ; Value: " + std::to_string (maxValue);
477- output.SaveResult (iterationNum, iterationResult, static_cast <int >(hash_data (tensor, uCapacity)));
562+ for (auto & pair : maxKValues)
563+ {
564+ auto maxValue = pair.first ;
565+ auto maxIndex = pair.second ;
566+ std::string iterationResult =
567+ " Index: " + std::to_string (maxIndex) + " ; Value: " + std::to_string (maxValue);
568+ output.SaveResult (iterationNum, iterationResult,
569+ static_cast <int >(hash_data (tensor, uCapacity)));
570+ }
478571 }
479572 if (!args.IsGarbageInput () && iterationNum == 0 )
480573 {
481- std::cout << " Outputting results.. " << std::endl;
574+ std::cout << " Outputting top " << args. TopK () << " values " << std::endl;
482575 std::cout << " Feature Name: " << name << std::endl;
483- std::wcout << " resultVector[" << maxIndex << " ] has the maximal value of " << maxValue
484- << std::endl;
576+ for (auto & pair : maxKValues)
577+ {
578+ auto maxValue = pair.first ;
579+ auto maxIndex = pair.second ;
580+ std::wcout << " index: " << maxIndex << " , value: " << maxValue
581+ << std::endl;
582+ }
485583 }
486584 }
487585 else if (desc.Kind () == LearningModelFeatureKind::Sequence)
0 commit comments