File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
paddle/phi/kernels/fusion/onednn Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -180,7 +180,8 @@ class FusedMatmulOneDNNHandler
180
180
auto residual_data_tz = vectorize (residual_data->dims ());
181
181
auto chosen_memory_format = funcs::OneDNNMemoryFormat::any;
182
182
dnnl::memory::desc residual_data_md;
183
- if (residual_data_tz.size () == 4 && residual_data_tz[0 ] == 1 &&
183
+ if (out_ddims.size () > 0 && out_ddims[0 ] > 1 &&
184
+ residual_data_tz.size () == 4 && residual_data_tz[0 ] == 1 &&
184
185
residual_data_tz[1 ] > 1 && residual_data_tz[2 ] > 1 &&
185
186
residual_data_tz[3 ] > 1 ) {
186
187
chosen_memory_format = funcs::OneDNNMemoryFormat::nchw;
@@ -320,7 +321,8 @@ void ExecuteFusedMatmul(const OneDNNContext &dev_ctx,
320
321
if (residual_data) {
321
322
auto residual_data_vec = vectorize (residual_data->dims ());
322
323
std::shared_ptr<dnnl::memory> residual_data_memory_p;
323
- if (residual_data_vec.size () == 4 && residual_data_vec[0 ] == 1 &&
324
+ if (std::max ((x_dims)[0 ], (y_dims)[0 ]) > 1 &&
325
+ residual_data_vec.size () == 4 && residual_data_vec[0 ] == 1 &&
324
326
residual_data_vec[1 ] > 1 && residual_data_vec[2 ] > 1 &&
325
327
residual_data_vec[3 ] > 1 ) {
326
328
residual_data_memory_p = handler.AcquireSrcMemoryStride (residual_data);
You can’t perform that action at this time.
0 commit comments