File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -39,11 +39,16 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
39
39
PADDLE_ENFORCE_EQ (
40
40
rank, offsets_tensor->dims ()[0 ],
41
41
" Offsets size should be equal to dimension size of input tensor." );
42
- const int * offsets_data = offsets_tensor->data <int >();
43
- res.resize (rank);
44
- for (size_t i = 0 ; i < rank; ++i) {
45
- res[i] = offsets_data[i];
42
+ const int * offsets_data;
43
+ framework::Tensor cpu_tmp_tensor;
44
+ if (platform::is_cpu_place (offsets_tensor->place ())) {
45
+ offsets_data = offsets_tensor->data <int >();
46
+ } else {
47
+ framework::TensorCopySync (*offsets_tensor, platform::CPUPlace (),
48
+ &cpu_tmp_tensor);
49
+ offsets_data = cpu_tmp_tensor.data <int >();
46
50
}
51
+ res = std::vector<int >(offsets_data, offsets_data + rank);
47
52
} else {
48
53
res = ctx.Attr <std::vector<int >>(" offsets" );
49
54
PADDLE_ENFORCE_EQ (
You can’t perform that action at this time.
0 commit comments