Skip to content

Commit 6cff3c9

Browse files
authored
Merge pull request #7381 from reyoung/feature/refine_get_places_op
Polish GetPlacesOp
2 parents a320276 + e5e206e commit 6cff3c9

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

paddle/operators/get_places_op.cc

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,32 +39,34 @@ class GetPlacesOp : public framework::OperatorBase {
3939
: OperatorBase(type, inputs, outputs, attrs) {}
4040
void Run(const framework::Scope &scope,
4141
const platform::Place &place) const override {
42-
std::string device_type = Attr<std::string>("device_type");
42+
bool is_gpu;
43+
if (Attr<std::string>("device_type") == "AUTO") {
44+
is_gpu = platform::is_gpu_place(place);
45+
} else {
46+
is_gpu = Attr<std::string>("device_type") == "CUDA";
47+
}
4348
auto device_count = static_cast<size_t>(Attr<int>("device_count"));
4449
if (device_count == 0) {
45-
if (device_type == "CUDA") {
46-
device_count = CUDADevCount();
47-
} else if (device_type == "CPU") {
48-
device_count = std::thread::hardware_concurrency();
49-
}
50+
device_count =
51+
is_gpu ? CUDADevCount() : std::thread::hardware_concurrency();
5052
}
5153
PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count",
52-
device_type);
54+
is_gpu ? "GPU" : "CPU");
5355

5456
auto out_var_name = Output("Out");
5557
auto &places =
5658
*(detail::Ref(scope.FindVar(out_var_name),
5759
"Output variable %s cannot be found", out_var_name)
5860
.GetMutable<platform::PlaceList>());
5961
places.reserve(device_count);
60-
if (device_type == "CUDA") {
62+
if (is_gpu) {
6163
PADDLE_ENFORCE_LE(device_count, CUDADevCount(),
6264
"Only %d CUDA devices found, cannot set to %d",
6365
CUDADevCount(), device_count);
6466
for (size_t i = 0; i < device_count; ++i) {
65-
places.emplace_back(platform::CUDAPlace(i));
67+
places.emplace_back(platform::CUDAPlace(static_cast<int>(i)));
6668
}
67-
} else if (device_type == "CPU") {
69+
} else {
6870
for (size_t i = 0; i < device_count; ++i) {
6971
places.emplace_back(platform::CPUPlace());
7072
}
@@ -77,10 +79,10 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
7779
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
7880
: OpProtoAndCheckerMaker(proto, op_checker) {
7981
AddOutput("Out", "vector of Place");
80-
AddAttr<int>("device_count", "device count").SetDefault(1);
81-
AddAttr<std::string>("device_type",
82-
R"(device type must be in ["CPU", "CUDA"])")
83-
.InEnum({"CPU", "CUDA"});
82+
AddAttr<int>("device_count", "device count").SetDefault(0);
83+
AddAttr<std::string>("device_type", "device type")
84+
.InEnum({"CUDA", "CPU", "AUTO"})
85+
.SetDefault("AUTO");
8486
AddComment(R"DOC(
8587
Returns a list of places based on flags. The list will be used for parallel
8688
execution.

python/paddle/v2/fluid/layers/device.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,22 @@
44

55
from ..layer_helper import LayerHelper
66
from ..framework import unique_name
7+
from ..registry import autodoc
78

89
__all__ = ['get_places']
910

1011

11-
def get_places(device_count=0, device_type="CPU"):
12+
@autodoc
13+
def get_places(device_count=None, device_type=None):
1214
helper = LayerHelper('get_places', **locals())
1315
out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
16+
attrs = dict()
17+
if device_count is not None:
18+
attrs['device_count'] = int(device_count)
19+
if device_type is not None:
20+
attrs['device_type'] = str(device_type)
21+
1422
helper.append_op(
15-
type='get_places',
16-
outputs={"Out": [out_places]},
17-
attrs={
18-
"device_type": device_type,
19-
'device_count': device_count,
20-
})
23+
type='get_places', outputs={"Out": [out_places]}, attrs=attrs)
2124

2225
return out_places

0 commit comments

Comments
 (0)