@@ -75,56 +75,141 @@ namespace WinMLRunnerTest
7575 }
7676 }
7777
78- bool CompareTensorsProvidedEpsilonAndRelativeTolerance (
79- const std::wstring &expectedOutputTensorFile,
80- const std::wstring &actualOutputTensorFile,
81- float relativeTolerance,
82- float epsilon)
78+ bool
79+ CompareTensorsProvidedEpsilonAndRelativeTolerance (const std::vector<std::pair<int , float >>& expectedOutputTensors,
80+ const std::vector<std::pair<int , float >>& actualOutputTensors,
81+ float relativeTolerance, float epsilon)
8382 {
84- bool check = false ;
85- std::ifstream expectedFileStream;
86- std::ifstream actualFileStream;
87- expectedFileStream.open (expectedOutputTensorFile);
88- actualFileStream.open (actualOutputTensorFile);
89- std::string actualValue;
90- std::string expectedValue;
91- if (expectedFileStream.fail () || actualFileStream.fail ())
83+ if (expectedOutputTensors.size () != 0 && actualOutputTensors.size () != 0 &&
84+ expectedOutputTensors.size () != actualOutputTensors.size ())
9285 {
93- return false ;
86+ Assert::Fail (
87+ L" One of the output tensors is empty or expected and Actual Output tensors are different sizes\n " );
88+ }
89+ bool doesActualMatchExpected = true ;
90+ for (int i = 0 ; i < expectedOutputTensors.size (); i++)
91+ {
92+ float actualValueNum = actualOutputTensors[i].second ;
93+ float expectedValueNum = expectedOutputTensors[i].second ;
94+ if (std::abs (actualValueNum - expectedValueNum) > 0.001 &&
95+ std::abs (actualValueNum - expectedValueNum) >
96+ relativeTolerance * std::abs (expectedValueNum) + epsilon) // Check if the values are too different.
97+ {
98+ printf (" Expected and Actual tensor value is too different at Index: %d. Expected: %f, Actual: %f\n " , i,
99+ expectedValueNum, actualValueNum);
100+ doesActualMatchExpected = false ;
101+ }
102+ }
103+ return doesActualMatchExpected;
104+ }
105+
106+ void PopulateTensorLists (const std::wstring& tensorFile, std::vector<std::pair<int , float >>& tensorList)
107+ {
108+ std::ifstream tensorFileStream;
109+ tensorFileStream.open (tensorFile);
110+ std::string index;
111+ std::string value;
112+ if (tensorFileStream.fail ())
113+ {
114+ Assert::Fail (L" Failed to open tensor files\n " );
94115 }
95116 bool isFirstRow = true ;
96- while (!(expectedFileStream .eof () || actualFileStream. eof () ))
117+ while (!tensorFileStream .eof ())
97118 {
98- std::getline (expectedFileStream, expectedValue, ' ,' );
99- std::getline (expectedFileStream, expectedValue, ' \n ' );
100- std::getline (actualFileStream, actualValue, ' ,' );
101- std::getline (actualFileStream, actualValue, ' \n ' );
119+ std::getline (tensorFileStream, index, ' ,' );
120+ std::getline (tensorFileStream, value, ' \n ' );
102121 if (isFirstRow)
103122 {
104123 isFirstRow = false ;
105124 continue ;
106125 }
107- float actualValueNum = (actualValue == " " ) ? 0 : std::stof (actualValue);
108- float expectedValueNum = (expectedValue == " " ) ? 0 : std::stof (expectedValue);
109- if (std::abs (actualValueNum - expectedValueNum) >
110- relativeTolerance * std::abs (expectedValueNum) + epsilon) // Check if the values are too different.
126+ if (value != " " && index != " " )
111127 {
112- return false ;
128+ tensorList. push_back ( std::make_pair ( std::stoi (index), std::stof (value))) ;
113129 }
114130 }
115- return true ;
131+ }
132+
133+ // This method sorts the expected output tensors and actual output tensors from largest tensor value to smallest
134+ // tensor value. It takes the percentage decrease between the highest tensor value to the next highest tensor value
135+ // for both sorted tensor lists and so on. Using relative tolerance and epsilon, we can compare between expected
136+ // percentage decrease with actual percentage decrease.
137+ bool CompareTensorValuesRelative (std::vector<std::pair<int , float >>& expectedOutputTensors,
138+ std::vector<std::pair<int , float >>& actualOutputTensors,
139+ const float relativeTolerance, const float epsilon,
140+ const float smallestValueToCompare)
141+ {
142+ if (expectedOutputTensors.size () != 0 && actualOutputTensors.size () != 0 &&
143+ expectedOutputTensors.size () != actualOutputTensors.size ())
144+ {
145+ Assert::Fail (
146+ L" One of the output tensors is empty or expected and Actual Output tensors are different sizes\n " );
147+ }
148+ // Sort expected and actual output tensors from highest to lowest. NOTE: This will modify the original
149+ // parameters.
150+ std::sort (expectedOutputTensors.begin (), expectedOutputTensors.end (),
151+ [](auto & left, auto & right) { return left.second > right.second ; });
152+ std::sort (actualOutputTensors.begin (), actualOutputTensors.end (),
153+ [](auto & left, auto & right) { return left.second > right.second ; });
154+
155+ bool currentValueIsLargeEnough = true ;
156+ bool doesActualMatchExpected = true ;
157+ int currentIndex = 0 ;
158+ while (currentValueIsLargeEnough && currentIndex < expectedOutputTensors.size ())
159+ {
160+ // Compare expected vs actual prediction index
161+ if (expectedOutputTensors[currentIndex].first != actualOutputTensors[currentIndex].first )
162+ {
163+ printf (" Top Expected Index:%d and Actual Index:%d don't match!" ,
164+ expectedOutputTensors[currentIndex].first , actualOutputTensors[currentIndex].first );
165+ doesActualMatchExpected = false ;
166+ }
167+ else if (currentIndex > 0 )
168+ {
169+ float expectedTensorRatio =
170+ (expectedOutputTensors[currentIndex].second - expectedOutputTensors[currentIndex - 1 ].second ) /
171+ expectedOutputTensors[currentIndex - 1 ].second ;
172+ float actualTensorRatio =
173+ (actualOutputTensors[currentIndex].second - actualOutputTensors[currentIndex - 1 ].second ) /
174+ actualOutputTensors[currentIndex - 1 ].second ;
175+ // Compare the percentage difference between top values
176+ if (std::abs (expectedTensorRatio - actualTensorRatio) >
177+ relativeTolerance * std::abs (expectedTensorRatio) + epsilon)
178+ {
179+ printf (" Actual ratio difference of top values between index %d and index %d don't match expected "
180+ " ratio difference" ,
181+ currentIndex - 1 , currentIndex);
182+ doesActualMatchExpected = false ;
183+ }
184+ }
185+ currentValueIsLargeEnough = expectedOutputTensors[++currentIndex].second > smallestValueToCompare;
186+ }
187+ return doesActualMatchExpected;
116188 }
117189
118190 bool CompareTensors (const std::wstring& expectedOutputTensorFile, const std::wstring& actualOutputTensorFile)
119191 {
120- return CompareTensorsProvidedEpsilonAndRelativeTolerance (expectedOutputTensorFile, actualOutputTensorFile,
121- 0 .003f , 0 );
192+ std::vector<std::pair<int , float >> expectedOutputTensors;
193+ std::vector<std::pair<int , float >> actualOutputTensors;
194+ PopulateTensorLists (expectedOutputTensorFile, expectedOutputTensors);
195+ PopulateTensorLists (actualOutputTensorFile, actualOutputTensors);
196+ return CompareTensorsProvidedEpsilonAndRelativeTolerance (expectedOutputTensors, actualOutputTensors, 0 .003f , 0 );
122197 }
123198
124199 bool CompareTensorsFP16 (const std::wstring& expectedOutputTensorFile, const std::wstring& actualOutputTensorFile)
125200 {
126- return CompareTensorsProvidedEpsilonAndRelativeTolerance (expectedOutputTensorFile, actualOutputTensorFile,
127- 0 .06f , 0 );
201+ std::vector<std::pair<int , float >> expectedOutputTensors;
202+ std::vector<std::pair<int , float >> actualOutputTensors;
203+ PopulateTensorLists (expectedOutputTensorFile, expectedOutputTensors);
204+ PopulateTensorLists (actualOutputTensorFile, actualOutputTensors);
205+ bool compareAllTensorsResult =
206+ CompareTensorsProvidedEpsilonAndRelativeTolerance (expectedOutputTensors, actualOutputTensors, 0 .06f , 0 );
207+ if (!compareAllTensorsResult) // fall back to more forgiving comparison that compares order of top indexes
208+ {
209+ // After calling CompareTensorValuesRelative, the tensor lists will be sorted from largest to smallest
210+ return CompareTensorValuesRelative (expectedOutputTensors, actualOutputTensors, 0 .1f , 0 .05f , 0 .001f );
211+ }
212+ return true ;
128213 }
129214
130215 TEST_CLASS (GarbageInputTest){ public : TEST_CLASS_INITIALIZE (SetupClass){
@@ -777,4 +862,4 @@ namespace WinMLRunnerTest
777862 }
778863 */
779864 };
780- }
865+ }
0 commit comments