diff --git a/backends/npu/passes/chatglm.py b/backends/npu/passes/chatglm.py index fe2819992db..0b8c270d09f 100644 --- a/backends/npu/passes/chatglm.py +++ b/backends/npu/passes/chatglm.py @@ -43,9 +43,9 @@ def pattern( x=hidden_state, ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder, @@ -183,9 +183,9 @@ def pattern( x=embedding.Output("Out"), ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder, @@ -324,9 +324,9 @@ def pattern( x=hidden_state, ) qkv_proj = ir.PassDesc.OP.matmul_v2(X=input_norm.Output("out"), Y=qkv_weight) + qkv = ir.PassDesc.OP.elementwise_add(X=qkv_proj.Output("Out"), Y=qkv_bias) blha = ir.PassDesc.OP.block_multihead_attention( - qkv=qkv_proj.Output("Out"), - qkv_bias=qkv_bias, + qkv=qkv.Output("Out"), key_cache=key_cache, value_cache=value_cache, seq_lens_encoder=seq_lens_encoder,