Skip to content

Commit caa12d1

Browse files
authored
Merge pull request #16387 from phlrain/pick_matmul_shape
Merge pull request #16347 from phlrain/fix_matmul_check
2 parents e17ce37 + d37bd5c commit caa12d1

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

paddle/fluid/operators/matmul_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel {
290290
context->Attrs().Get<bool>("transpose_Y"));
291291

292292
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
293-
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
294-
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
293+
if (context->IsRuntime()) {
294+
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
295+
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
296+
}
295297
std::vector<int64_t> dim_out;
296298
if (mat_dim_x.batch_size_ != 0) {
297299
dim_out = framework::vectorize(dim_x);

python/paddle/fluid/layers/nn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4732,6 +4732,9 @@ def __check_input(x, y):
47324732

47334733
if len(y_shape) > 2:
47344734
for i, dim_x in enumerate(x_shape[:-2]):
4735+
# don't check neg shape
4736+
if dim_x < 0 or y_shape[i] < 0:
4737+
continue
47354738
if dim_x != y_shape[i]:
47364739
raise ValueError("Invalid inputs for matmul.")
47374740

0 commit comments

Comments
 (0)