Commit 47b6e5f

[Yaml] add yaml for Uniform random and add unit test. (#41517) (#41619)
* gather op
* add mod
* [Yaml] final state for uniform and uniform_random
1 parent a0b0a32 commit 47b6e5f
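In summary, this change wires `uniform_random` into the new final-state (phi) operator path: the hand-written C++ `InferShape` moves into `phi::UniformRandomInferMeta`, `paddle.uniform` dispatches to `_C_ops.final_state_uniform_random` in eager mode, and a new YAML entry drives the code generation. A minimal usage sketch of the Python API the updated unit tests exercise (the shape and attribute values here are illustrative, not taken from the commit):

```python
import paddle

# paddle.uniform is the python_api registered by the updated unit tests.
# Argument values below are illustrative.
out = paddle.uniform(shape=[1000, 784], dtype='float32',
                     min=-1.0, max=1.0, seed=0)
print(out.shape)  # [1000, 784]
```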

File tree

6 files changed: +64 −70 lines changed

paddle/fluid/operators/uniform_random_op.cc
paddle/phi/infermeta/nullary.cc
paddle/phi/infermeta/nullary.h
python/paddle/fluid/tests/unittests/test_uniform_random_op.py
python/paddle/tensor/random.py
python/paddle/utils/code_gen/api.yaml


paddle/fluid/operators/uniform_random_op.cc

Lines changed: 7 additions & 69 deletions
@@ -16,9 +16,11 @@ limitations under the License. */
 #include <string>
 
 #include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/bfloat16.h"
+#include "paddle/phi/infermeta/nullary.h"
 
 namespace paddle {
 namespace operators {
@@ -122,74 +124,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandomOp");
-
-    PADDLE_ENFORCE_LT(
-        ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
-        platform::errors::InvalidArgument(
-            "The uniform_random's min must less then max. But received min = "
-            "%f great than or equal max = %f.",
-            ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
-                      platform::errors::InvalidArgument(
-                          "The uniform_random's diag_num must greater than or "
-                          "equal 0. But recevied diag_num (%d) < 0.",
-                          ctx->Attrs().Get<int>("diag_num")));
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
-                      platform::errors::InvalidArgument(
-                          "The uniform_random's diag_step must greater than or "
-                          "equal 0. But recevied diag_step (%d) < 0.",
-                          ctx->Attrs().Get<int>("diag_step")));
-
-    if (ctx->HasInputs("ShapeTensorList")) {
-      // top prority shape
-      auto inputs_name = ctx->Inputs("ShapeTensorList");
-      PADDLE_ENFORCE_GT(inputs_name.size(), 0,
-                        platform::errors::InvalidArgument(
-                            "Input(ShapeTensorList)'size of "
-                            "Op(uniform_random) can't be zero."
-                            "Please check the Attr(shape)'s size of"
-                            "Op(fluid.layers.uniform_random).)"));
-      auto out_dims = std::vector<int>(inputs_name.size(), -1);
-      ctx->SetOutputDim("Out", phi::make_ddim(out_dims));
-
-      return;
-    }
-    auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
-    if (ctx->HasInput("ShapeTensor") && shape.empty()) {
-      auto shape_dims = ctx->GetInputDim("ShapeTensor");
-      PADDLE_ENFORCE_EQ(
-          shape_dims.size(), 1,
-          platform::errors::InvalidArgument(
-              "ShapeError: Input(ShapeTensor)' dimension size of "
-              "Op(uniform_random) must be 1."
-              "But received ShapeTensor's dimensions = %d, shape = [%s]",
-              shape_dims.size(), shape_dims));
-      int num_ele = 1;
-      for (int i = 0; i < shape_dims.size(); ++i) {
-        num_ele *= shape_dims[i];
-      }
-      auto vec_dims = std::vector<int64_t>(num_ele, -1);
-      auto out_dims = phi::make_ddim(vec_dims);
-      ctx->SetOutputDim("Out", out_dims);
-      return;
-    }
-
-    PADDLE_ENFORCE_EQ(shape.empty(), false,
-                      platform::errors::InvalidArgument(
-                          "if there is no Input(ShapeTensorList) and no "
-                          "Input(ShapeTensor),the "
-                          "attr(shape) information must "
-                          "be set by Attr(shape)."));
-    std::vector<int64_t> tensor_shape;
-    tensor_shape.reserve(shape.size());
-    for (auto dim : shape) {
-      tensor_shape.push_back(static_cast<int64_t>(dim));
-    }
-    ctx->SetOutputDim("Out", phi::make_ddim(tensor_shape));
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -274,12 +208,16 @@ class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
 }  // namespace operators
 }  // namespace paddle
 
+DECLARE_INFER_SHAPE_FUNCTOR(uniform_random, UniformRandomInferShapeFunctor,
+                            PD_INFER_META(phi::UniformRandomInferMeta));
+
 REGISTER_OPERATOR(
     uniform_random, paddle::operators::UniformRandomOp,
     paddle::operators::UniformRandomOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    paddle::operators::UniformRandomOpVarTypeInference);
+    paddle::operators::UniformRandomOpVarTypeInference,
+    UniformRandomInferShapeFunctor);
 
 REGISTER_OP_CPU_KERNEL(
     uniform_random_batch_size_like,
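For orientation, the deleted `InferShape` resolved the output shape with a fixed priority: `ShapeTensorList` first, then `ShapeTensor`, then the `shape` attribute. A rough Python sketch of that priority (the helper name and signature are hypothetical, for illustration only):

```python
import numpy as np

def resolve_out_dims(shape_tensor_list=None, shape_tensor=None, shape_attr=()):
    # Hypothetical mirror of the removed C++ InferShape; -1 marks a
    # dimension whose size is unknown until the shape tensors are read
    # at runtime.
    if shape_tensor_list:                      # top-priority shape source
        return [-1] * len(shape_tensor_list)   # rank known, dims deferred
    if shape_tensor is not None and not shape_attr:
        num_ele = int(np.prod(shape_tensor.shape))  # 1-D tensor of dims
        return [-1] * num_ele
    assert shape_attr, "Attr(shape) must be set when no shape tensors exist"
    return [int(dim) for dim in shape_attr]
```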

paddle/phi/infermeta/nullary.cc

Lines changed: 12 additions & 0 deletions
@@ -63,6 +63,18 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
   out->set_dtype(dtype);
 }
 
+void UniformRandomInferMeta(const IntArray& shape,
+                            DataType dtype,
+                            float min,
+                            float max,
+                            int seed,
+                            MetaTensor* out) {
+  auto out_dims = phi::make_ddim(shape.GetData());
+  out->set_dims(out_dims);
+  out->set_dtype(dtype);
+  out->set_layout(DataLayout::NCHW);
+}
+
 void RandintInferMeta(
     int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out) {
   PADDLE_ENFORCE_NOT_NULL(
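The new `UniformRandomInferMeta` only fixes the output's metadata; the sampling attributes do not influence it. A hedged Python analogue of what it computes (the `MetaTensor` stand-in is hypothetical, not Paddle API):

```python
from dataclasses import dataclass, field

@dataclass
class MetaTensor:
    # Hypothetical stand-in for phi::MetaTensor: metadata only, no buffer.
    dims: list = field(default_factory=list)
    dtype: str = 'float32'
    layout: str = 'NCHW'

def uniform_random_infer_meta(shape, dtype, min, max, seed, out):
    # min, max and seed are accepted but unused, mirroring the C++ code.
    out.dims = list(shape)
    out.dtype = dtype
    out.layout = 'NCHW'
```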

paddle/phi/infermeta/nullary.h

Lines changed: 7 additions & 0 deletions
@@ -65,4 +65,11 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
                                       DataType dtype,
                                       MetaTensor* out);
 
+void UniformRandomInferMeta(const IntArray& shape,
+                            DataType dtype,
+                            float min,
+                            float max,
+                            int seed,
+                            MetaTensor* out);
+
 }  // namespace phi

python/paddle/fluid/tests/unittests/test_uniform_random_op.py

Lines changed: 18 additions & 0 deletions
@@ -26,6 +26,7 @@
 from paddle.fluid.op import Operator
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import _test_eager_guard
 
 
 def output_hist(out):
@@ -52,6 +53,7 @@ def output_hist_diag(out):
 class TestUniformRandomOp_attr_tensorlist(OpTest):
     def setUp(self):
         self.op_type = "uniform_random"
+        self.python_api = paddle.uniform
         self.new_shape = (1000, 784)
         shape_tensor = []
         for index, ele in enumerate(self.new_shape):
@@ -84,6 +86,7 @@ def init_attrs(self):
 class TestUniformRandomOp_attr_tensorlist_int32(OpTest):
     def setUp(self):
         self.op_type = "uniform_random"
+        self.python_api = paddle.uniform
         self.new_shape = (1000, 784)
         shape_tensor = []
         for index, ele in enumerate(self.new_shape):
@@ -110,6 +113,7 @@ def verify_output(self, outs):
 class TestUniformRandomOp_attr_tensor(OpTest):
     def setUp(self):
         self.op_type = "uniform_random"
+        self.python_api = paddle.uniform
         self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int64")}
         self.init_attrs()
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -131,6 +135,7 @@ def verify_output(self, outs):
 class TestUniformRandomOp_attr_tensor_int32(OpTest):
     def setUp(self):
         self.op_type = "uniform_random"
+        self.python_api = paddle.uniform
         self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")}
         self.init_attrs()
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -152,6 +157,7 @@ def verify_output(self, outs):
 class TestUniformRandomOp(OpTest):
     def setUp(self):
         self.op_type = "uniform_random"
+        self.python_api = paddle.uniform
         self.inputs = {}
         self.init_attrs()
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -174,6 +180,18 @@ def verify_output(self, outs):
             np.allclose(
                 hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
 
+    def test_check_api(self):
+        places = self._get_places()
+        for place in places:
+            with fluid.dygraph.base.guard(place=place):
+                out = self.python_api(self.attrs['shape'], 'float32',
+                                      self.attrs['min'], self.attrs['max'],
+                                      self.attrs['seed'])
+
+    def test_check_api_eager(self):
+        with _test_eager_guard():
+            self.test_check_api()
+
 
 class TestUniformRandomOpError(unittest.TestCase):
     def test_errors(self):
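The `verify_output` context shown above asserts that the sample histogram stays within 0.01 of the flat uniform density. A standalone sketch of that check (the `output_hist` body and the min/max range are assumptions — only the call sites appear in this hunk):

```python
import numpy as np

def output_hist(out, lo=-5.0, hi=10.0, bins=10):
    # Assumed helper body: normalized histogram plus the flat probability
    # each bin should carry if the samples are uniform on [lo, hi].
    hist, _ = np.histogram(out, bins=bins, range=(lo, hi))
    hist = hist.astype('float32') / float(out.size)
    prob = np.full(bins, 1.0 / bins, dtype='float32')
    return hist, prob

samples = np.random.uniform(-5.0, 10.0, size=(1000, 784))
hist, prob = output_hist(samples)
assert np.allclose(hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)
```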

python/paddle/tensor/random.py

Lines changed: 8 additions & 1 deletion
@@ -548,7 +548,14 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
     if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
 
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        shape = utils.convert_shape_to_list(shape)
+        return _C_ops.final_state_uniform_random(shape, dtype,
+                                                 float(min),
+                                                 float(max), seed,
+                                                 _current_expected_place())
+
+    if _in_legacy_dygraph():
         shape = utils.convert_shape_to_list(shape)
         return _C_ops.uniform_random('shape', shape, 'min',
                                      float(min), 'max',
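With this branch in place, `in_dygraph_mode()` (the new eager mode) takes the final-state kernel, `_in_legacy_dygraph()` keeps the old C++ op, and static graph falls through unchanged. The new unit test drives the eager path roughly like this (argument values are illustrative):

```python
import paddle
from paddle.fluid.framework import _test_eager_guard

# Under the eager guard, paddle.uniform should hit the
# _C_ops.final_state_uniform_random branch added above.
with _test_eager_guard():
    out = paddle.uniform([2, 3], 'float32', min=-1.0, max=1.0, seed=0)
```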

python/paddle/utils/code_gen/api.yaml

Lines changed: 12 additions & 0 deletions
@@ -1975,6 +1975,18 @@
     func : unfold
   backward : unfold_grad
 
+- api : uniform_random
+  args : (IntArray shape, DataType dtype, float min, float max, int seed, Place place={})
+  output : Tensor(out)
+  infer_meta :
+    func : UniformRandomInferMeta
+    param: [shape, dtype, min, max, seed]
+  kernel :
+    func : uniform_random
+    param: [shape, dtype, min, max, seed]
+    data_type : dtype
+    backend : place
+
 # The `axis` argument of Python API paddle.unique is not vector
 - api : unique
   args : (Tensor x, bool return_index, bool return_inverse, bool return_counts, int[] axis, DataType dtype=DataType::INT64)
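Reading the new entry: both `infer_meta` and `kernel` take `param: [shape, dtype, min, max, seed]`, so the `place` argument is excluded from shape inference and the kernel signature. Instead, `data_type : dtype` selects the kernel's data type from the `dtype` argument and `backend : place` selects the device from `place` — which is why the generated `final_state_uniform_random` call in `random.py` above can pass `_current_expected_place()` as its last argument.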
