@@ -1629,6 +1629,22 @@ class vk_perf_logger {
16291629 timings[name].push_back(time);
16301630 return;
16311631 }
1632+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
1633+ const ggml_tensor * dst = node;
1634+ const ggml_tensor * q = node->src[0];
1635+ const ggml_tensor * k = node->src[1];
1636+ const ggml_tensor * v = node->src[2];
1637+ const ggml_tensor * m = node->src[3];
1638+ std::stringstream name;
1639+ name << ggml_op_name(node->op) <<
1640+ " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
1641+ " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
1642+ " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
1643+ " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
1644+ " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
1645+ timings[name.str()].push_back(time);
1646+ return;
1647+ }
16321648 timings[ggml_op_name(node->op)].push_back(time);
16331649 }
16341650 private:
0 commit comments