Skip to content

Commit 749bc24

Browse files
authored
cherry-pick #36021 fix unique/unstack zero tensor (#36163)
* fix unique unstack dim 0 * fix unique_op format
1 parent 1db28fd commit 749bc24

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

paddle/fluid/operators/unique_op.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,10 @@ class UniqueKernel : public framework::OpKernel<T> {
403403
bool return_index = context.Attr<bool>("return_index");
404404
bool return_inverse = context.Attr<bool>("return_inverse");
405405
bool return_counts = context.Attr<bool>("return_counts");
406-
406+
if (x->numel() == 0) {
407+
out->mutable_data<T>(context.GetPlace());
408+
return;
409+
}
407410
if (axis_vec.empty()) {
408411
framework::VisitDataTypeTiny(
409412
data_type,

paddle/fluid/operators/unstack_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ class UnStackKernel : public framework::OpKernel<T> {
149149
dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
150150
}
151151
auto dy_data = dy->data<T>();
152-
152+
if (dy->numel() == 0) return;
153153
int pre = 1;
154154
for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
155155
int total_num = dy->numel();

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10315,6 +10315,8 @@ def unstack(x, axis=0, num=None):
1031510315
if in_dygraph_mode():
1031610316
if num == None:
1031710317
num = x.shape[axis]
10318+
if num == 0:
10319+
return []
1031810320
return _C_ops.unstack(x, num, 'axis', int(axis), 'num', num)
1031910321

1032010322
helper = LayerHelper('unstack', **locals())

0 commit comments

Comments
 (0)