Skip to content

Commit a339e8b

Browse files
[oneDNN] Restrictions on matmul broadcast optimiztion (#59744)
1 parent 0c6b490 commit a339e8b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

paddle/phi/kernels/fusion/onednn/fused_matmul_kernel.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,8 @@ class FusedMatmulOneDNNHandler
180180
auto residual_data_tz = vectorize(residual_data->dims());
181181
auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
182182
dnnl::memory::desc residual_data_md;
183-
if (residual_data_tz.size() == 4 && residual_data_tz[0] == 1 &&
183+
if (out_ddims.size() > 0 && out_ddims[0] > 1 &&
184+
residual_data_tz.size() == 4 && residual_data_tz[0] == 1 &&
184185
residual_data_tz[1] > 1 && residual_data_tz[2] > 1 &&
185186
residual_data_tz[3] > 1) {
186187
chosen_memory_format = funcs::OneDNNMemoryFormat::nchw;
@@ -320,7 +321,8 @@ void ExecuteFusedMatmul(const OneDNNContext &dev_ctx,
320321
if (residual_data) {
321322
auto residual_data_vec = vectorize(residual_data->dims());
322323
std::shared_ptr<dnnl::memory> residual_data_memory_p;
323-
if (residual_data_vec.size() == 4 && residual_data_vec[0] == 1 &&
324+
if (std::max((x_dims)[0], (y_dims)[0]) > 1 &&
325+
residual_data_vec.size() == 4 && residual_data_vec[0] == 1 &&
324326
residual_data_vec[1] > 1 && residual_data_vec[2] > 1 &&
325327
residual_data_vec[3] > 1) {
326328
residual_data_memory_p = handler.AcquireSrcMemoryStride(residual_data);

0 commit comments

Comments
 (0)