Commit b4f28cc

Merge pull request #11632 from JiayiFeng/some_small_fixes: "Some small fixes"
2 parents f0cf70e + e1a46bb
8 files changed: +53 additions, -44 deletions
paddle/fluid/operators/assign_value_op.cc

Lines changed: 2 additions & 1 deletion

@@ -70,6 +70,7 @@ AssignValue operator
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker);
+REGISTER_OPERATOR(assign_value, ops::AssignValueOp, ops::AssignValueOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(assign_value, ops::AssignValueKernel<int>,
                        ops::AssignValueKernel<float>);
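
The newly registered paddle::framework::EmptyGradOpMaker explicitly declares that assign_value has no gradient op. A hedged Python-side sketch of where this matters, assuming (consistent with the tensor.py hunk later in this commit) that fluid.layers.assign emits an assign_value op for numpy inputs:

    import numpy as np
    import paddle.fluid as fluid

    # assign() on a numpy array emits an assign_value op under the hood.
    const = fluid.layers.assign(np.ones((2, 2), dtype='float32'))
    loss = fluid.layers.mean(const)
    # With EmptyGradOpMaker registered, backward construction treats
    # assign_value as gradient-free instead of looking for a grad op.
    fluid.backward.append_backward(loss)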

paddle/fluid/operators/random_crop_op.cc

Lines changed: 5 additions & 3 deletions

@@ -37,6 +37,11 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("SeedOut", "The random seed after random cropping.")
         .AsIntermediate();
     AddAttr<std::vector<int>>("shape", "The shape of a cropped instance.");
+    AddAttr<int>("startup_seed",
+                 "If the input 'Seed' is not initialized, the 'startup_seed' "
+                 "will be used to replace it. Even so, the seed after random "
+                 "crop will still be outputted to 'SeedOut'.")
+        .SetDefault(0);
     AddComment(R"DOC(
 This operator takes a batch of instances and does random cropping on each instance.
 It means that cropping positions differ on each instance, which is determined
@@ -49,8 +54,6 @@ class RandomCropOpMaker : public framework::OpProtoAndCheckerMaker {
 class RandomCropOpInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {
-    auto seed_dim = ctx->GetInputDim("Seed");
-    PADDLE_ENFORCE(seed_dim.size() == 1 && seed_dim[0] == 1);
     auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
     auto x_dim = ctx->GetInputDim("X");
     PADDLE_ENFORCE_GT(x_dim.size(), static_cast<int64_t>(shape.size()));
@@ -62,7 +65,6 @@ class RandomCropOpInferShape : public framework::InferShapeBase {
       out_dim[x_i] = shape[shape_i];
     }
     ctx->SetOutputDim("Out", framework::make_ddim(out_dim));
-    ctx->SetOutputDim("SeedOut", framework::make_ddim({1}));
   }
 };
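
From the Python side the new attribute is invisible to users who pass an int seed; a minimal usage sketch, mirroring the docstring example in the nn.py hunk further down:

    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[3, 256, 256], dtype='float32')
    # An int seed is now carried as the 'startup_seed' attribute; the 'Seed'
    # variable starts uninitialized, and the dim checks removed from
    # InferShape above happen at run time instead.
    cropped = fluid.layers.random_crop(img, shape=[3, 224, 224], seed=1)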

paddle/fluid/operators/random_crop_op.h

Lines changed: 15 additions & 9 deletions

@@ -142,16 +142,22 @@ template <typename DeviceContext, typename T>
 class RandomCropKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& ctx) const {
-    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
     int64_t seed = 0;
-    if (platform::is_cpu_place(seed_tensor.place())) {
-      seed = *seed_tensor.data<int64_t>();
+    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
+    if (seed_tensor.IsInitialized()) {
+      if (platform::is_cpu_place(seed_tensor.place())) {
+        seed = *seed_tensor.data<int64_t>();
+      } else {
+        LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
+                        "your program";
+        framework::LoDTensor cpu_seed;
+        framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
+        seed = *cpu_seed.data<int64_t>();
+      }
     } else {
-      LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
-                      "your program";
-      framework::LoDTensor cpu_seed;
-      framework::TensorCopySync(seed_tensor, platform::CPUPlace(), &cpu_seed);
-      seed = *cpu_seed.data<int64_t>();
+      VLOG(5) << "WARNING: The input 'Seed' is not initialized, use attribute "
+                 "'startup_seed' instead.";
+      seed = ctx.Attr<int>("startup_seed");
     }
     auto shape = ctx.Attr<std::vector<int>>("shape");
     auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
@@ -171,7 +177,7 @@ class RandomCropKernel : public framework::OpKernel<T> {
     engine.discard(functor.prod_batchsize_dims_ *
                    (functor.rank_ - functor.num_batchsize_dims_));
     *ctx.Output<framework::LoDTensor>("SeedOut")->mutable_data<int64_t>(
-        platform::CPUPlace()) = engine();
+        framework::make_ddim({1}), platform::CPUPlace()) = engine();
   }
 };
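
The three-way branch above (CPU seed, GPU seed, uninitialized seed) reduces to a simple rule; a runnable stand-in with plain Python values, where None models an uninitialized 'Seed' tensor and the CPU-vs-GPU copying is elided:

    def resolve_seed(seed_tensor, startup_seed=0):
        # Models seed_tensor.IsInitialized() in the kernel above.
        if seed_tensor is not None:
            return int(seed_tensor)
        # First run: fall back to the new 'startup_seed' attribute.
        return startup_seed

    assert resolve_seed(None, startup_seed=7) == 7   # first batch
    assert resolve_seed(42, startup_seed=7) == 42    # later batches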

paddle/fluid/operators/reader/create_custom_reader_op.cc

Lines changed: 6 additions & 4 deletions

@@ -39,6 +39,7 @@ class CustomReader : public framework::DecoratedReader {
   const framework::ProgramDesc program_;
   int sub_block_id_;
   framework::Executor exe_;
+  framework::Scope scope_;
 
   std::vector<std::string> source_var_names_;
   std::vector<std::string> sink_var_names_;
@@ -158,23 +159,24 @@ void CustomReader::ReadNext(std::vector<framework::LoDTensor>* out) {
   // The scope for CustomReader's sub-block should be independent and shouldn't
   // be any other computation scope's child. Otherwise, data preprocessing and
   // computation cannot be concurrent.
-  framework::Scope scope;
+  framework::Scope* exe_scope = &scope_.NewScope();
   // 1. Copy LoDTensors from the underlying reader's output to source variables.
   for (size_t i = 0; i < source_var_names_.size(); ++i) {
-    framework::Variable* var = scope.Var(source_var_names_[i]);
+    framework::Variable* var = exe_scope->Var(source_var_names_[i]);
     framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
     tensor->ShareDataWith(underlying_outs[i]);
     tensor->set_lod(underlying_outs[i].lod());
   }
   // 2. Run the sub-block.
-  exe_.Run(program_, &scope, sub_block_id_, false, true);
+  exe_.Run(program_, exe_scope, sub_block_id_, false, true);
   // 3. Copy LoDTensors from sink variables to out.
   out->resize(sink_var_names_.size());
   for (size_t i = 0; i < sink_var_names_.size(); ++i) {
-    const auto& tensor = detail::Ref(scope.FindVar(sink_var_names_[i]))
+    const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i]))
                              .Get<framework::LoDTensor>();
     framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
   }
+  scope_.DeleteScope(exe_scope);
 }
 
 }  // namespace reader
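
The change replaces a function-local scope with a per-batch child of a reader-owned scope, so the scope's lifetime is controlled explicitly rather than by the stack. A runnable toy stand-in of that discipline (the Scope class below is illustrative, not Paddle's):

    class Scope(object):
        def __init__(self):
            self.vars, self.children = {}, []

        def new_scope(self):
            child = Scope()
            self.children.append(child)
            return child

        def delete_scope(self, child):
            self.children.remove(child)

    parent = Scope()                    # the new scope_ member, lives with the reader
    for batch in range(3):
        exe_scope = parent.new_scope()  # fresh child per ReadNext call
        exe_scope.vars['src'] = batch   # source vars, sub-block run, sink copy
        parent.delete_scope(exe_scope)  # released before the next batch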

paddle/fluid/operators/reader/create_double_buffer_reader_op.cc

Lines changed: 2 additions & 2 deletions

@@ -23,13 +23,13 @@ namespace reader {
 
 // 'Double buffer' means we shall maintain two batches of input data at the same
 // time. So kCacheSize should be at least 2.
-static constexpr size_t kCacheSize = 3;
+static constexpr size_t kCacheSize = 5;
 // There will be two batches out of the channel during training:
 // 1. the one waiting to be sent to the channel
 // 2. the one just received from the channel, which is also being used by
 //    subsequent operators.
 // So the channel size should be kCacheSize - 2.
-static constexpr size_t kChannelSize = 1;  // kCacheSize - 2
+static constexpr size_t kChannelSize = 3;  // kCacheSize - 2
 
 class DoubleBufferReader : public framework::DecoratedReader {
  public:
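
The two constants stay tied by the invariant stated in the comment: two batches always live outside the channel, so only kCacheSize - 2 fit inside it. In numbers:

    # Sizing rule from the comment above.
    kCacheSize = 5
    kChannelSize = kCacheSize - 2   # one batch waiting to send, one just received
    assert kChannelSize == 3        # matches the new constant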

python/paddle/fluid/layers/io.py

Lines changed: 7 additions & 4 deletions

@@ -110,7 +110,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 class ListenAndServ(object):
     """
     **ListenAndServ Layer**
-
+
     ListenAndServ is used to create an RPC server that binds and listens
     on a specific TCP port; this server runs the sub-block when it
     receives variables from clients.
@@ -212,7 +212,7 @@ def Send(endpoints, send_vars, sync=True):
             of send_vars to send
         send_vars (list): variables to send to server
         sync (bool): whether to wait for the request to finish
-
+
     """
     assert (type(send_vars) == list)
 
@@ -469,10 +469,13 @@ def open_files(filenames,
         lod_levels(list): List of ints declaring the data lod_level.
         dtypes(list): List of strs declaring the data type.
         thread_num(int): The maximal concurrent prefetch thread number.
-        buffer_size(int): The size of prefetch buffer.
+        buffer_size(int|None): The size of the prefetch buffer. If set to None,
+            the buffer size will be thread_num * 3.
+            Default: None
         pass_num(int): Number of passes to run.
         for_parallel(Bool): Set it as True if you are going to run
            subsequent operators in parallel.
+           Default: True
 
     Returns:
         Variable: A Reader Variable via which we can get file data.
@@ -492,7 +495,7 @@ def open_files(filenames,
         image, label = fluid.layers.io.read_file(reader)
     """
     if buffer_size is None:
-        buffer_size = thread_num
+        buffer_size = thread_num * 3
     if isinstance(filenames, basestring):
         filenames = [filenames]
     dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
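
A hedged usage sketch of the new default (parameter names follow the docstring above; the file path is a placeholder):

    import paddle.fluid as fluid

    # With buffer_size omitted, the prefetch buffer now defaults to
    # thread_num * 3 (here 12) rather than thread_num.
    reader = fluid.layers.open_files(
        filenames=['./data.recordio'],
        shapes=[(-1, 3, 224, 224), (-1, 1)],
        lod_levels=[0, 0],
        dtypes=['float32', 'int64'],
        thread_num=4)
    image, label = fluid.layers.io.read_file(reader)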

python/paddle/fluid/layers/nn.py

Lines changed: 10 additions & 17 deletions

@@ -23,6 +23,7 @@
 from tensor import concat
 import utils
 import random
+from .. import unique_name
 
 __all__ = [
     'fc',
@@ -4896,34 +4897,26 @@ def random_crop(x, shape, seed=None):
     >>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224])
     """
     helper = LayerHelper("random_crop", **locals())
-    dtype = helper.input_dtype()
+    dtype = x.dtype
     out = helper.create_tmp_variable(dtype)
     if seed is None:
         seed = random.randint(-65536, 65535)
-
+    op_attrs = {"shape": shape}
     if isinstance(seed, int):
-        seed_value = seed
-        seed = helper.create_tmp_variable(dtype="int64")
-        helper.append_op(
-            type="fill_constant",
-            inputs={},
-            outputs={"Out": seed},
-            attrs={
-                "dtype": seed.dtype,
-                "shape": [1],
-                "value": float(seed_value),
-                "force_cpu": True
-            })
+        op_attrs["startup_seed"] = seed
+        seed = helper.create_variable(
+            name=unique_name.generate("random_crop_seed"),
+            dtype="int64",
+            persistable=True)
     elif not isinstance(seed, Variable):
         raise ValueError("'seed' must be a Variable or an int.")
-    seed_out = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
         type="random_crop",
         inputs={"X": x,
                 "Seed": seed},
         outputs={"Out": out,
-                 "SeedOut": seed_out},
-        attrs={"shape": shape})
+                 "SeedOut": seed},
        attrs=op_attrs)
     return out
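
The practical effect of the rewrite: an int seed no longer spawns a fill_constant op that re-initializes the seed on every run; instead, one persistable int64 variable serves as both the 'Seed' input and the 'SeedOut' output, so the updated seed survives into the next iteration. A quick check, assuming the fluid API of this release:

    import paddle.fluid as fluid

    img = fluid.layers.data(name='img', shape=[3, 256, 256], dtype='float32')
    cropped = fluid.layers.random_crop(img, shape=[3, 224, 224], seed=1)

    ops = [op.type for op in fluid.default_main_program().global_block().ops]
    assert 'random_crop' in ops and 'fill_constant' not in ops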

python/paddle/fluid/layers/tensor.py

Lines changed: 6 additions & 4 deletions

@@ -155,7 +155,7 @@ def cast(x, dtype):
 
     Examples:
         .. code-block:: python
-
+
             data = fluid.layers.data(name='x', shape=[13], dtype='float32')
             result = fluid.layers.cast(x=data, dtype='float64')
     """
@@ -188,7 +188,7 @@ def concat(input, axis=0, name=None):
 
     Examples:
         .. code-block:: python
-
+
             out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
     """
     helper = LayerHelper('concat', **locals())
@@ -238,15 +238,15 @@ def sums(input, out=None):
     return out
 
 
-def assign(input, output):
+def assign(input, output=None):
     """
     **Assign**
 
     This function copies the *input* Variable to the *output* Variable.
 
     Args:
         input(Variable|numpy.ndarray): The source variable
-        output(Variable): The destination variable
+        output(Variable|None): The destination variable
 
     Returns:
         Variable: The destination variable that was supplied as the *output*.
@@ -259,6 +259,8 @@ def assign(input, output):
        fluid.layers.assign(hidden, out)
     """
     helper = LayerHelper('assign', **locals())
+    if output is None:
+        output = helper.create_tmp_variable(dtype=input.dtype)
     if isinstance(input, Variable):
         helper.append_op(
             type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
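
With output now optional, the common case needs a single call; a short usage sketch:

    import numpy as np
    import paddle.fluid as fluid

    # Omitting 'output' creates and returns a temporary of the input's dtype.
    result = fluid.layers.assign(np.array([[1.0, 2.0]], dtype='float32'))

    # The original explicit-destination form still works.
    out = fluid.layers.create_tensor(dtype='float32')
    fluid.layers.assign(result, output=out)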
