File tree Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Expand file tree Collapse file tree 1 file changed +5
-2
lines changed Original file line number Diff line number Diff line change @@ -28,12 +28,15 @@ class ArgMaxXPUKernel : public framework::OpKernel<T> {
28
28
auto * out = ctx.Output <framework::LoDTensor>(" Out" );
29
29
auto dtype = ctx.Attr <int >(" dtype" );
30
30
PADDLE_ENFORCE_EQ (
31
- (dtype < 0 || dtype == 3 ), true ,
31
+ (dtype < 0 || dtype == 2 || dtype == 3 ), true ,
32
32
platform::errors::InvalidArgument (
33
- " The attribute of dtype in xpu argmin/argmax must be [%s], but "
33
+ " The attribute of dtype in xpu argmin/argmax must be [%s] or [%s], "
34
+ " but "
34
35
" received [%s]" ,
35
36
paddle::framework::DataTypeToString (
36
37
framework::proto::VarType::INT64),
38
+ paddle::framework::DataTypeToString (
39
+ framework::proto::VarType::INT32),
37
40
paddle::framework::DataTypeToString (
38
41
static_cast <framework::proto::VarType::Type>(dtype))));
39
42
You can’t perform that action at this time.
0 commit comments