Skip to content

Commit f873d3a

Browse files
author
lilong12
authored
bug fix shard_index (#37042) (#37421)
1 parent 4dc426f commit f873d3a

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

paddle/fluid/operators/shard_index_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class ShardIndexOp : public framework::OperatorWithKernel {
3131
"but the value given is %d.",
3232
x_dims.size()));
3333
if (ctx->IsRuntime() || x_dims[x_dims.size() - 1] > 0) {
34-
PADDLE_ENFORCE_GE(x_dims[x_dims.size() - 1], 1U,
34+
PADDLE_ENFORCE_EQ(x_dims[x_dims.size() - 1], 1U,
3535
platform::errors::InvalidArgument(
3636
"The last dimension of Input(X) should be 1, "
3737
"but the value given is %d.",

python/paddle/fluid/layers/nn.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input,
1491414914
@deprecated(since="2.0.0", update_to="paddle.shard_index")
1491514915
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
1491614916
"""
14917-
Recompute the `input` indices according to the offset of the
14918-
shard. The length of the indices is evenly divided into N shards, and if
14919-
the `shard_id` matches the shard with the input index inside, the index is
14920-
recomputed on the basis of the shard offset, elsewise it is set to
14921-
`ignore_value`. The detail is as follows:
14917+
Reset the values of `input` according to the shard it beloning to.
14918+
Every value in `input` must be a non-negative integer, and
14919+
the parameter `index_num` represents the integer above the maximum
14920+
value of `input`. Thus, all values in `input` must be in the range
14921+
[0, index_num) and each value can be regarded as the offset to the beginning
14922+
of the range. The range is further split into multiple shards. Specifically,
14923+
we first compute the `shard_size` according to the following formula,
14924+
which represents the number of integers each shard can hold. So for the
14925+
i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
1492214926
::
1492314927

1492414928
shard_size = (index_num + nshards - 1) // nshards
14925-
y = x % shard_size if x // shard_size == shard_id else ignore_value
1492614929

14927-
NOTE: If the length of indices cannot be evely divided by the shard number,
14928-
the size of the last shard will be less than the calculated `shard_size`
14930+
For each value `v` in `input`, we reset it to a new value according to the
14931+
following formula:
14932+
::
14933+
14934+
v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value
14935+
14936+
That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
14937+
if it in the range. Otherwise, we reset it to be `ignore_value`.
1492914938

1493014939
Args:
14931-
input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1.
14932-
index_num (int): An integer defining the range of the index.
14940+
input (Tensor): Input tensor with data type int64 or int32. It's last dimension must be 1.
14941+
index_num (int): An integer represents the integer above the maximum value of `input`.
1493314942
nshards (int): The number of shards.
1493414943
shard_id (int): The index of the current shard.
1493514944
ignore_value (int): An integer value out of sharded index range.
1493614945

1493714946
Returns:
14938-
Tensor: The sharded index of input.
14947+
Tensor.
1493914948

1494014949
Examples:
1494114950
.. code-block:: python

0 commit comments

Comments
 (0)