@@ -39,32 +39,34 @@ class GetPlacesOp : public framework::OperatorBase {
39
39
: OperatorBase(type, inputs, outputs, attrs) {}
40
40
void Run (const framework::Scope &scope,
41
41
const platform::Place &place) const override {
42
- std::string device_type = Attr<std::string>(" device_type" );
42
+ bool is_gpu;
43
+ if (Attr<std::string>(" device_type" ) == " AUTO" ) {
44
+ is_gpu = platform::is_gpu_place (place);
45
+ } else {
46
+ is_gpu = Attr<std::string>(" device_type" ) == " CUDA" ;
47
+ }
43
48
auto device_count = static_cast <size_t >(Attr<int >(" device_count" ));
44
49
if (device_count == 0 ) {
45
- if (device_type == " CUDA" ) {
46
- device_count = CUDADevCount ();
47
- } else if (device_type == " CPU" ) {
48
- device_count = std::thread::hardware_concurrency ();
49
- }
50
+ device_count =
51
+ is_gpu ? CUDADevCount () : std::thread::hardware_concurrency ();
50
52
}
51
53
PADDLE_ENFORCE_NE (device_count, 0 , " Cannot indicate %s device count" ,
52
- device_type );
54
+ is_gpu ? " GPU " : " CPU " );
53
55
54
56
auto out_var_name = Output (" Out" );
55
57
auto &places =
56
58
*(detail::Ref (scope.FindVar (out_var_name),
57
59
" Output variable %s cannot be found" , out_var_name)
58
60
.GetMutable <platform::PlaceList>());
59
61
places.reserve (device_count);
60
- if (device_type == " CUDA " ) {
62
+ if (is_gpu ) {
61
63
PADDLE_ENFORCE_LE (device_count, CUDADevCount (),
62
64
" Only %d CUDA devices found, cannot set to %d" ,
63
65
CUDADevCount (), device_count);
64
66
for (size_t i = 0 ; i < device_count; ++i) {
65
- places.emplace_back (platform::CUDAPlace (i ));
67
+ places.emplace_back (platform::CUDAPlace (static_cast < int >(i) ));
66
68
}
67
- } else if (device_type == " CPU " ) {
69
+ } else {
68
70
for (size_t i = 0 ; i < device_count; ++i) {
69
71
places.emplace_back (platform::CPUPlace ());
70
72
}
@@ -77,10 +79,10 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
77
79
GetPlacesOpProtoMaker (OpProto *proto, OpAttrChecker *op_checker)
78
80
: OpProtoAndCheckerMaker(proto, op_checker) {
79
81
AddOutput (" Out" , " vector of Place" );
80
- AddAttr<int >(" device_count" , " device count" ).SetDefault (1 );
81
- AddAttr<std::string>(" device_type" ,
82
- R"( device type must be in [ "CPU", "CUDA"] )" )
83
- .InEnum ({ " CPU " , " CUDA " } );
82
+ AddAttr<int >(" device_count" , " device count" ).SetDefault (0 );
83
+ AddAttr<std::string>(" device_type" , " device type " )
84
+ . InEnum ({ " CUDA " , " CPU" , " AUTO " } )
85
+ .SetDefault ( " AUTO " );
84
86
AddComment (R"DOC(
85
87
Returns a list of places based on flags. The list will be used for parallel
86
88
execution.
0 commit comments