Skip to content

Commit 9744a4a

Browse files
committed
Determine calculation mode
1 parent 78ddb47 commit 9744a4a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

tools/imatrix/imatrix.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,19 @@ static void process_tensor_name(const std::string & input, std::string & layer,
127127
}
128128
}
129129

130-
static void compute_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) {
130+
static int compute_tensor_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) {
131131
if (e.in_sum2.size() % e.counts.size() != 0) {
132132
LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.in_sum2.size());
133-
return;
133+
return -1;;
134134
}
135135
if (e.counts.empty()) {
136136
LOG_ERR("%s: there are no activations for tensor %s. The imatrix may be suboptimal\n", __func__, name.c_str());
137-
return;
137+
return -1;
138138
}
139139

140140
const int n_mat = e.counts.size();
141141
const int row_size = e.in_sum2.size() / n_mat;
142+
const int calc_mode = e.in_sum.empty() ? 2 : 1;
142143

143144
std::vector<float> activations;
144145

@@ -1104,13 +1105,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
11041105

11051106
static bool show_statistics(const common_params & params) {
11061107
std::vector<tensor_statistics> ts;
1108+
int tensor_calc_mode = 0;
1109+
11071110
if (params.in_files.empty() || params.in_files.size() > 1) {
11081111
LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n");
11091112
return false;
11101113
}
11111114
if (g_collector.load_imatrix(params.in_files[0].c_str())) {
11121115
for (const auto & [name, stats] :g_collector.get_mstats()) {
1113-
compute_statistics(ts, name, stats);
1116+
tensor_calc_mode =compute_tensor_statistics(ts, name, stats);
11141117
}
11151118
} else {
11161119
LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str());

0 commit comments

Comments
 (0)