14
14
15
15
#pragma once
16
16
17
+ #ifdef __NVCC__
18
+ #include < thrust/device_ptr.h>
19
+ #include < thrust/functional.h>
20
+ #include < thrust/reduce.h>
21
+ #else
22
+ #include < algorithm>
23
+ #endif
24
+
17
25
#include " paddle/fluid/framework/op_registry.h"
18
26
#include " paddle/fluid/platform/for_range.h"
19
27
@@ -26,73 +34,83 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
26
34
27
35
// Validates that Input(X) and Output(Y) are present and derives the shape of
// Output(Y): the dims of Input(X) followed by one trailing dimension.
//
// When Attr(maxlen) is positive the trailing dimension is maxlen. Otherwise
// the length is only known once the data of X is available (the kernel
// computes it and resizes Y), so the trailing dimension is recorded as -1.
// This still gives Y a rank and its leading dims instead of leaving the
// output shape unset.
void InferShape(framework::InferShapeContext *ctx) const override {
  PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
  PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");

  const int maxlen = ctx->Attrs().Get<int>("maxlen");
  auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
  // -1 marks "unknown until run time"; the kernel resizes Y in that case.
  dim.push_back(maxlen > 0 ? maxlen : -1);
  ctx->SetOutputDim("Y", framework::make_ddim(dim));
}
36
46
};
37
47
38
48
class SequenceMaskOpMaker : public framework ::OpProtoAndCheckerMaker {
39
49
public:
40
50
void Make () override {
41
- AddInput (" X" , " The input of sequence_mask op." );
51
+ AddInput (" X" , " The input tensor of sequence_mask op." );
42
52
AddOutput (" Y" , " The output mask of sequence_mask op." );
43
- AddAttr<int >(" max_len" , " The maximum length of the sequence." )
44
- .GreaterThan (1 );
53
+ AddAttr<int >(" maxlen" ,
54
+ " The maximum length of the sequence. If maxlen < 0, maxlen "
55
+ " = max(Input(X))." )
56
+ .SetDefault (-1 )
57
+ .AddCustomChecker ([](int &v) {
58
+ PADDLE_ENFORCE (v < 0 || v >= 1 ,
59
+ " Attr(maxlen) must be less than 0 or larger than 1" );
60
+ });
45
61
AddAttr<int >(" out_dtype" , " Output data type" );
46
62
AddComment (R"DOC(
47
63
SequenceMask Operator
48
64
49
- This operator outputs a Mask according to Input(X) and Attr(max_len ).
65
+ This operator outputs a Mask according to Input(X) and Attr(maxlen ).
50
66
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
51
- Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len ], where:
67
+ Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen ], where:
52
68
53
69
Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
70
+
71
+ If maxlen < 0, maxlen = max(X)
54
72
)DOC" );
55
73
}
56
74
};
57
75
58
76
// Element-wise functor that fills a flat mask buffer.
//
// The output buffer is addressed by a flat index y_idx. The index is split
// into a position in the input (y_idx / maxlen) and an offset inside the
// mask row (y_idx % maxlen); the output element is 1 when the offset is
// smaller than the input value at that position and 0 otherwise.
template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
  HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen)
      : x_(x), y_(y), maxlen_(maxlen) {}

  HOSTDEVICE void operator()(int y_idx) const {
    const int row = y_idx / maxlen_;  // which input element this cell maps to
    const int col = y_idx % maxlen_;  // offset inside that element's mask row
    const bool inside = col < x_[row];
    y_[y_idx] = static_cast<Ty>(inside ? 1 : 0);
  }

 private:
  const Tx *x_;  // input values, one per mask row
  Ty *y_;        // flat output mask buffer
  int maxlen_;   // length of each mask row
};
74
92
75
93
template <typename DeviceContext, typename Tx>
76
94
struct SequenceMaskFunctor {
77
95
using Tensor = framework::LoDTensor;
78
96
79
97
SequenceMaskFunctor (const DeviceContext &ctx, const Tx *x, Tensor *y,
80
- int limits, int max_len )
81
- : ctx_(ctx), x_(x), y_(y), limits_(limits), max_len_(max_len ) {}
98
+ int limits, int maxlen )
99
+ : ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen ) {}
82
100
83
101
template <typename Ty>
84
102
void operator ()() const {
85
103
auto *y_data = y_->mutable_data <Ty>(ctx_.GetPlace ());
86
104
platform::ForRange<DeviceContext> for_range (ctx_, limits_);
87
- for_range (SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, max_len_ ));
105
+ for_range (SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_ ));
88
106
}
89
107
90
108
private:
91
109
const DeviceContext &ctx_;
92
110
const Tx *x_;
93
111
Tensor *y_;
94
112
int limits_;
95
- int max_len_ ;
113
+ int maxlen_ ;
96
114
};
97
115
98
116
template <typename DeviceContext, typename Tx>
@@ -103,13 +121,32 @@ class SequenceMaskKernel : public framework::OpKernel<Tx> {
103
121
// Computes Output(Y) from Input(X).
//
// If Attr(maxlen) is negative, the mask length is taken from the data: the
// maximum value in X, never less than 0. Y is then resized to the dims of X
// plus that trailing length. Finally the mask is written with the element
// type selected by Attr(out_dtype).
void Compute(const framework::ExecutionContext &ctx) const override {
  auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Y");
  auto maxlen = ctx.Attr<int>("maxlen");

  auto *x_data = x->data<Tx>();
  auto x_numel = x->numel();
  if (maxlen < 0) {
#ifdef __NVCC__
    VLOG(10)
        << "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
    // The reduction starts from 0, so an empty X or an all-negative X
    // yields a length of 0.
    maxlen = static_cast<int>(
        thrust::reduce(thrust::device_pointer_cast(x_data),
                       thrust::device_pointer_cast(x_data) + x_numel,
                       static_cast<Tx>(0), thrust::maximum<Tx>()));
#else
    // Mirror the reduction above: start from 0 so the length is never
    // negative, and only dereference std::max_element when X is non-empty
    // (it returns the end iterator for an empty range).
    Tx longest = static_cast<Tx>(0);
    if (x_numel > 0) {
      longest = std::max(longest, *std::max_element(x_data, x_data + x_numel));
    }
    maxlen = static_cast<int>(longest);
#endif
    // The output shape depends on the data, so it is fixed here.
    auto y_dim = framework::vectorize2int(x->dims());
    y_dim.push_back(maxlen);
    y->Resize(framework::make_ddim(y_dim));
  }

  auto out_dtype = static_cast<framework::proto::VarType::Type>(
      ctx.Attr<int>("out_dtype"));
  auto &dev_ctx = ctx.template device_context<DeviceContext>();
  framework::VisitDataType(out_dtype,
                           SequenceMaskFunctor<DeviceContext, Tx>(
                               dev_ctx, x_data, y, x_numel * maxlen, maxlen));
}
114
151
};
115
152
0 commit comments