From a6d6fce0ae480ee5aab368c0c4d39ec468f78c4f Mon Sep 17 00:00:00 2001 From: huang <1421037099@qq.com> Date: Fri, 27 Sep 2024 15:33:16 +0800 Subject: [PATCH] [NPU] fix bug in chatglm pass --- backends/npu/passes/chatglm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/npu/passes/chatglm.py b/backends/npu/passes/chatglm.py index fe2819992db..0b8c270d09f 100644 --- a/backends/npu/passes/chatglm.py +++ b/backends/npu/passes/chatglm.py @@ -43,9 +43,9 @@ def pattern( x=hidden_state, ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder, @@ -183,9 +183,9 @@ def pattern( x=embedding.Output("Out"), ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder, @@ -324,9 +324,9 @@ def pattern( x=hidden_state, ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder,