@@ -12,18 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/fluid/operators/sampling_id_op.h"
15
+ #include < algorithm>
16
+ #include < iostream>
17
+ #include < iterator>
18
+ #include < random>
19
+ #include < sstream>
20
+ #include < vector>
21
+ #include " paddle/fluid/framework/op_registry.h"
16
22
17
23
namespace paddle {
18
24
namespace operators {
19
25
20
26
using Tensor = framework::Tensor;
21
27
28
+ template <typename DeviceContext, typename T>
29
+ class SamplingIdKernel : public framework ::OpKernel<T> {
30
+ public:
31
+ void Compute (const framework::ExecutionContext& context) const override {
32
+ const Tensor* input = context.Input <Tensor>(" X" );
33
+ const int batch_size = static_cast <int >(input->dims ()[0 ]);
34
+ const int width = static_cast <int >(input->dims ()[1 ]);
35
+
36
+ std::vector<T> ins_vector;
37
+ framework::TensorToVector (*input, context.device_context (), &ins_vector);
38
+
39
+ std::vector<T> ids (batch_size);
40
+ for (size_t i = 0 ; i < batch_size; ++i) {
41
+ double r = getRandReal ();
42
+ int idx = width - 1 ;
43
+ for (int j = 0 ; j < width; ++j) {
44
+ if ((r -= ins_vector[i * width + j]) < 0 ) {
45
+ idx = j;
46
+ break ;
47
+ }
48
+ }
49
+ ids[i] = ins_vector[i * width + idx];
50
+ }
51
+
52
+ std::vector<int64_t > out_dim;
53
+ out_dim.push_back (static_cast <int64_t >(batch_size));
54
+
55
+ Tensor* output = context.Output <Tensor>(" Out" );
56
+ output->Resize (framework::make_ddim (out_dim));
57
+ output->mutable_data <T>(context.GetPlace ());
58
+ framework::TensorFromVector (ids, context.device_context (), output);
59
+ }
60
+
61
+ private:
62
+ double getRandReal () const {
63
+ std::random_device
64
+ rd; // Will be used to obtain a seed for the random number engine
65
+ std::mt19937 gen (rd ()); // Standard mersenne_twister_engine seeded with
66
+ // rd()
67
+ std::uniform_real_distribution<> dis (1.0 , 2.0 );
68
+ return dis (gen);
69
+ }
70
+ };
71
+
22
72
class SamplingIdOp : public framework ::OperatorWithKernel {
23
73
public:
24
74
using framework::OperatorWithKernel::OperatorWithKernel;
25
75
26
- void InferShape (framework::InferShapeContext * ctx) const override {
76
+ void InferShape (framework::InferShapeContext* ctx) const override {
27
77
PADDLE_ENFORCE (ctx->HasInput (" X" ),
28
78
" Input(X) of SamplingIdOp should not be null." );
29
79
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
0 commit comments