@@ -893,8 +893,18 @@ class LLM
893893
894894 auto &output = layer.layer .get_output (_attr.prefill_grpid , " output" );
895895 axcl_MemcpyAsync (embed_tmp.data (), (void *)output.phyAddr , embed_tmp.size () * sizeof (unsigned short ), AXCL_MEMCPY_DEVICE_TO_HOST, stream, layer.layer .get_devid ());
896+ if (m < _attr.axmodel_num - 1 )
897+ {
898+ if (llama_layers[m + 1 ].layer .get_devid () != layer.layer .get_devid ())
899+ {
900+ axcl_SynchronizeStream (stream, layer.layer .get_devid ());
901+ }
902+ }
903+ else if (m == _attr.axmodel_num - 1 )
904+ {
905+ axcl_SynchronizeStream (stream, layer.layer .get_devid ());
906+ }
896907
897- axcl_SynchronizeStream (stream, layer.layer .get_devid ());
898908 // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32());
899909 }
900910 if (p == (prefill_split_num - 1 ))
@@ -1001,7 +1011,6 @@ class LLM
10011011 {
10021012 axcl_MemcpyAsync ((void *)llama_post.get_input (" input" ).phyAddr ,
10031013 (void *)layer.layer .get_output (decode_grpid, " output" ).phyAddr , llama_post.get_input (" input" ).nSize , AXCL_MEMCPY_DEVICE_TO_DEVICE, stream, llama_post.get_devid ());
1004- axcl_SynchronizeStream (stream, layer.layer .get_devid ());
10051014 }
10061015 else
10071016 {
@@ -1018,7 +1027,6 @@ class LLM
10181027 {
10191028 axcl_MemcpyAsync ((void *)llama_layers[m + 1 ].layer .get_input (decode_grpid, " input" ).phyAddr ,
10201029 (void *)layer.layer .get_output (decode_grpid, " output" ).phyAddr , layer.layer .get_input (decode_grpid, " input" ).nSize , AXCL_MEMCPY_DEVICE_TO_DEVICE, stream, layer.layer .get_devid ());
1021- axcl_SynchronizeStream (stream, layer.layer .get_devid ());
10221030 }
10231031 else
10241032 {
0 commit comments