@@ -14914,28 +14914,37 @@ def deformable_roi_pooling(input,
14914
14914
@deprecated(since="2.0.0", update_to="paddle.shard_index")
14915
14915
def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
14916
14916
"""
14917
- Recompute the `input` indices according to the offset of the
14918
- shard. The length of the indices is evenly divided into N shards, and if
14919
- the `shard_id` matches the shard with the input index inside, the index is
14920
- recomputed on the basis of the shard offset, elsewise it is set to
14921
- `ignore_value`. The detail is as follows:
14917
+ Reset the values of `input` according to the shard it beloning to.
14918
+ Every value in `input` must be a non-negative integer, and
14919
+ the parameter `index_num` represents the integer above the maximum
14920
+ value of `input`. Thus, all values in `input` must be in the range
14921
+ [0, index_num) and each value can be regarded as the offset to the beginning
14922
+ of the range. The range is further split into multiple shards. Specifically,
14923
+ we first compute the `shard_size` according to the following formula,
14924
+ which represents the number of integers each shard can hold. So for the
14925
+ i'th shard, it can hold values in the range [i*shard_size, (i+1)*shard_size).
14922
14926
::
14923
14927
14924
14928
shard_size = (index_num + nshards - 1) // nshards
14925
- y = x % shard_size if x // shard_size == shard_id else ignore_value
14926
14929
14927
- NOTE: If the length of indices cannot be evely divided by the shard number,
14928
- the size of the last shard will be less than the calculated `shard_size`
14930
+ For each value `v` in `input`, we reset it to a new value according to the
14931
+ following formula:
14932
+ ::
14933
+
14934
+ v = v - shard_id * shard_size if shard_id * shard_size <= v < (shard_id+1) * shard_size else ignore_value
14935
+
14936
+ That is, the value `v` is set to the new offset within the range represented by the shard `shard_id`
14937
+ if it in the range. Otherwise, we reset it to be `ignore_value`.
14929
14938
14930
14939
Args:
14931
- input (Tensor): Input indices with data type int64 or int32. It's last dimension must be 1.
14932
- index_num (int): An integer defining the range of the index .
14940
+ input (Tensor): Input tensor with data type int64 or int32. It's last dimension must be 1.
14941
+ index_num (int): An integer represents the integer above the maximum value of `input` .
14933
14942
nshards (int): The number of shards.
14934
14943
shard_id (int): The index of the current shard.
14935
14944
ignore_value (int): An integer value out of sharded index range.
14936
14945
14937
14946
Returns:
14938
- Tensor: The sharded index of input .
14947
+ Tensor.
14939
14948
14940
14949
Examples:
14941
14950
.. code-block:: python
0 commit comments