@@ -224,27 +224,60 @@ static int compute_tensor_statistics(std::vector<tensor_statistics> & tstats, co
224224 return calc_mode;
225225}
226226
227- static void compute_cossim (std::vector<tensor_statistics> & tstats) {
227+ static void compute_layer_statistics (std::vector<tensor_statistics> & tstats) {
228228 static const std::regex pattern (R"( blk\.(\d+)\.)" );
229+
230+ auto build_avg = [](const Stats & s) -> std::vector<float > {
231+ if (s.counts .empty ()) return {};
232+ const size_t n_mat = s.counts .size ();
233+ const size_t len = !s.in_sum .empty () ? s.in_sum .size ()
234+ : s.in_sum2 .size ();
235+ if (len == 0 || len % n_mat != 0 ) return {};
236+ const size_t row = len / n_mat;
237+ std::vector<float > v;
238+ v.reserve (len);
239+ if (!s.in_sum .empty ()) {
240+ for (size_t m = 0 ; m < n_mat; ++m) {
241+ const float c = (float )s.counts [m];
242+ if (c <= 0 ) return {};
243+ const size_t off = m*row;
244+ for (size_t j = 0 ; j < row; ++j) v.push_back (s.in_sum [off+j]/c);
245+ }
246+ } else {
247+ for (size_t m = 0 ; m < n_mat; ++m) {
248+ const float c = (float )s.counts [m];
249+ if (c <= 0 ) return {};
250+ const size_t off = m*row;
251+ for (size_t j = 0 ; j < row; ++j) v.push_back (s.in_sum2 [off+j]/c);
252+ }
253+ }
254+ return v;
255+ };
256+ // compute the cosine similarity between the same tensors in consecutive layers
229257 for (auto & ts : tstats) {
258+ ts.cossim = 0 ;
259+
230260 if (std::smatch match; std::regex_search (ts.tensor , match, pattern)) {
231261 const int blk = std::stoi (match[1 ]);
262+ if (blk <= 0 ) continue ;
232263 std::string tname (ts.tensor );
233264 tname.replace (match.position (1 ), match.length (1 ), std::to_string (blk-1 ));
234265 auto prev = std::find_if (tstats.begin (), tstats.end (),
235266 [tname](const tensor_statistics & t) { return t.tensor == tname; });
236- if (prev != tstats.end ()) {
237- const float dot_product = std::inner_product (ts.stats .in_sum2 .begin (), ts.stats .in_sum2 .end (),
238- prev->stats .in_sum2 .begin (), 0 .0f );
239- const float magnitude = std::sqrt (std::inner_product (ts.stats .in_sum2 .begin (), ts.stats .in_sum2 .end (),
240- ts.stats .in_sum2 .begin (), 0 .0f ));
241- const float prev_magnitude = std::sqrt (std::inner_product (prev->stats .in_sum2 .begin (), prev->stats .in_sum2 .end (),
242- prev->stats .in_sum2 .begin (), 0 .0f ));
243- const float cos_sim = dot_product / (magnitude * prev_magnitude);
244- ts.cossim = cos_sim;
267+ if (prev == tstats.end ()) continue ;
268+ const auto curr_avg = build_avg (ts.stats );
269+ const auto prev_avg = build_avg (prev->stats );
270+ if (curr_avg.size () == prev_avg.size () && !curr_avg.empty ()) {
271+ float dot_prod = 0 .0f , vec1 = 0 .0f , vec2 = 0 .0f ;
272+ for (size_t i = 0 ; i < curr_avg.size (); ++i) {
273+ dot_prod += curr_avg[i]*prev_avg[i];
274+ vec1 += curr_avg[i]*curr_avg[i];
275+ vec2 += prev_avg[i]*prev_avg[i];
276+ }
277+ if (vec1 > 0 && vec2 > 0 ) ts.cossim = dot_prod / (std::sqrt (vec1)*std::sqrt (vec2));
245278 }
246- } else {
247- ts. cossim = 0 ;
279+ }
280+ }
248281 }
249282 }
250283}
0 commit comments