Skip to content

Commit 62ac268

Browse files
committed
Improve tensor influence ranking
1 parent 490a8fe commit 62ac268

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

examples/imatrix/imatrix.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct Stats {
3535

3636
struct Tally {
3737
std::string tensor;
38-
float value = 0;
38+
double bias = 0;
3939
int count = 0;
4040
};
4141

@@ -370,19 +370,20 @@ bool IMatrixCollector::load_imatrix(const char * fname, std::vector<Tally> * tal
370370
}
371371

372372
// Recreate the state as expected by save_imatrix(), and correct for weighted sum.
373-
float total = 0;
373+
double total = 0;
374374
for (int i = 0; i < nval; i++) {
375375
e.values[i] += tmp[i];
376-
total += tmp[i];
377376
e.counts[i] += ncall;
377+
const double avg_sq = (1.0 * e.values[i]) / e.counts[i];
378+
total += avg_sq;
378379
}
379380
e.ncall += ncall;
380381

381382
if (tally) {
382383
tally->emplace_back();
383-
auto & [tensor, value, count] = (*tally)[i];
384+
auto & [tensor, bias, count] = (*tally)[i];
384385
tensor = name_as_vec.data();
385-
value = total;
386+
bias = total;
386387
count = nval;
387388
}
388389
}
@@ -647,25 +648,25 @@ int main(int argc, char ** argv) {
647648
LOG_ERR("Error: cannot compute statistics for %s\n\n", params.in_files[0].c_str());
648649
return 1;
649650
}
650-
float total = 0;
651+
double total = 0;
651652
for (const auto & tallie : tallies) {
652-
total += tallie.value / static_cast<float>(tallie.count);
653+
total += tallie.bias;
653654
}
654655

655656
struct tally_sort {
656657
bool operator()(const Tally& x, const Tally & y) const {
657-
return x.value / static_cast<float>(x.count) > y.value / static_cast<float>(y.count);
658+
return x.bias > y.bias;
658659
}
659660
};
660661
std::sort(tallies.begin(), tallies.end(), tally_sort());
661662

662663
LOG_INF("\nComputing statistics for %s (%d tensors)\n", params.in_files[0].c_str(), static_cast<int>(tallies.size()));
663-
LOG_INF("\n Layer\t Tensor\t μ(Importance Scores)\t Contribution\n");
664-
LOG_INF("================================================================================\n");
665-
for (const auto & [tensor, value, count] : tallies) {
664+
LOG_INF("\n Layer\t Tensor\t Total Bias\tAvg Bias\t Contribution\n");
665+
LOG_INF("===============================================================================================\n");
666+
for (const auto & [tensor, bias, count] : tallies) {
666667
std::string layer, name;
667668
process_tensor_name(tensor, layer, name);
668-
LOG_INF("%5s\t%30s\t%15.2f\t%20.4f %%\n", layer.c_str(), name.c_str(), value / count, 100.0f * (value / count / total));
669+
LOG_INF("%5s\t%30s\t%15.2f\t%15.4f\t%19.4f%%\n", layer.c_str(), name.c_str(), bias, bias / count, 100.0 * bias / total);
669670
}
670671
LOG_INF("\n");
671672
return 0;

0 commit comments

Comments
 (0)