Skip to content

Commit e06f443

Browse files
author
lilong12
authored
add the framework support for distfc (#21197) (#21463)
* add the framework support for distfc and ut, test=develop * fix the implementation of shard_index_op, test=develop
1 parent 9c63b7c commit e06f443

File tree

6 files changed

+74
-7
lines changed

6 files changed

+74
-7
lines changed

paddle/fluid/operators/shard_index_op.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ __global__ void ShardIndexInner(const T* in_data, T* out_data,
2626
const int64_t numel, const int index_num,
2727
const int nshards, const int shard_id,
2828
const int ignore_value) {
29-
int shard_size = index_num / nshards;
29+
int shard_size = (index_num + nshards - 1) / nshards;
3030
int idx = blockIdx.x * blockDim.x + threadIdx.x;
3131
if (idx < numel) {
3232
assert(in_data[idx] >= 0 && in_data[idx] < index_num);

paddle/fluid/operators/shard_index_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ShardIndexCPUKernel : public framework::OpKernel<T> {
3434
PADDLE_ENFORCE(shard_id >= 0 && shard_id < nshards,
3535
"shard_id(%d) is not in range [0, %d)", shard_id, nshards);
3636

37-
int shard_size = index_num / nshards;
37+
int shard_size = (index_num + nshards - 1) / nshards;
3838

3939
out->Resize(in->dims());
4040
out->set_lod(in->lod());

python/paddle/fluid/framework.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,35 @@ def _set_error_clip(self, error_clip):
12501250
"""
12511251
self.error_clip = error_clip
12521252

1253+
def _set_info(self, key, value):
1254+
"""
1255+
Set key-value information for this variable.
1256+
1257+
Args:
1258+
key(str): Key for this information.
1259+
value(object): The value associated to the key.
1260+
1261+
Returns:
1262+
None
1263+
"""
1264+
if not hasattr(self, "_info"):
1265+
self._info = {}
1266+
self._info[key] = value
1267+
1268+
def _get_info(self, key):
1269+
"""
1270+
Get the information of this variable corresponding to key.
1271+
1272+
Args:
1273+
key(str): Key for this information.
1274+
1275+
Returns:
1276+
object
1277+
"""
1278+
if hasattr(self, "_info") and key in self._info:
1279+
return self._info[key]
1280+
return None
1281+
12531282
def _slice_indices(self, slice, length):
12541283
"""
12551284
Reference implementation for the slice.indices method.

python/paddle/fluid/layers/nn.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17655,10 +17655,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
1765517655
"""
1765617656
op_type = 'shard_index'
1765717657
helper = LayerHelper(op_type, **locals())
17658-
if index_num % nshards != 0:
17659-
raise ValueError(
17660-
'The index_num(%d) cannot be evenly divided by nshards(%d)' %
17661-
(index_num, nshards))
1766217658
if shard_id < 0 or shard_id >= nshards:
1766317659
raise ValueError('The shard_id(%d) should be in [0, %d)' %
1766417660
(shard_id, nshards))

python/paddle/fluid/tests/unittests/test_shard_index_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value):
3131
x = [np.random.randint(0, index_num - 1) for i in range(N)]
3232
x = np.array(x).astype('int32').reshape([N, 1])
3333

34-
shard_size = index_num // nshards
34+
shard_size = (index_num + nshards - 1) // nshards
3535
out = np.zeros(shape=x.shape).astype('int32')
3636
for i in range(N):
3737
if x[i] // shard_size == shard_id:
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
TestCases for Dataset,
16+
including create, config, run, etc.
17+
"""
18+
19+
from __future__ import print_function
20+
import paddle.fluid as fluid
21+
import numpy as np
22+
import os
23+
import shutil
24+
import unittest
25+
26+
27+
class TestVarInfo(unittest.TestCase):
28+
""" TestCases for Dataset. """
29+
30+
def test_var_info(self):
31+
""" Testcase for get and set info for variable. """
32+
value = np.random.randn(1)
33+
var = fluid.layers.create_global_var([1], value, "float32")
34+
var._set_info("name", "test")
35+
ret = var._get_info("name")
36+
assert ret == "test"
37+
ret = var._get_info("not_exist")
38+
assert ret == None
39+
40+
41+
if __name__ == '__main__':
42+
unittest.main()

0 commit comments

Comments
 (0)