Skip to content

Commit 0ccb9cb

Browse files
[cherry-pick] adapt c_embedding to phi namespace for custom devices (#60774) (#61045)
Co-authored-by: Tian <[email protected]>
1 parent 60325a1 commit 0ccb9cb

File tree

4 files changed

+201
-5
lines changed

4 files changed

+201
-5
lines changed

paddle/phi/kernels/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ if(WITH_MKLDNN)
199199
"fusion/onednn/*.cc")
200200
endif()
201201

202+
if(WITH_CUSTOM_DEVICE)
203+
set(cc_search_pattern ${cc_search_pattern} "custom/*.cc")
204+
endif()
205+
202206
file(
203207
GLOB kernel_cc
204208
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/c_embedding_grad_kernel.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/api/backward/backward_api.h"
18+
#include "paddle/phi/api/include/api.h"
19+
#include "paddle/phi/backends/all_context.h"
20+
#include "paddle/phi/common/float16.h"
21+
#include "paddle/phi/core/kernel_registry.h"
22+
23+
namespace phi {
24+
25+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
26+
template <typename T, typename Context>
27+
void CEmbeddingGradKernel(const Context& dev_ctx,
28+
const DenseTensor& w,
29+
const DenseTensor& ids,
30+
const DenseTensor& out_grad,
31+
int64_t start_index,
32+
DenseTensor* w_grad) {
33+
w_grad->Resize(w.dims());
34+
dev_ctx.template Alloc(w_grad, w.dtype());
35+
const auto& index_type = ids.dtype();
36+
if (index_type == phi::DataType::INT32 ||
37+
index_type == phi::DataType::INT64) {
38+
auto K = ids.numel();
39+
auto N = w.dims()[0];
40+
auto D = w.dims()[1];
41+
42+
auto x_tmp = std::make_shared<phi::DenseTensor>();
43+
x_tmp->ShareDataWith(ids).Resize({K});
44+
auto w_tmp = std::make_shared<phi::DenseTensor>();
45+
w_tmp->set_meta(w.meta());
46+
dev_ctx.Alloc(w_tmp.get(), w_tmp->dtype());
47+
auto out_grad_tmp = std::make_shared<phi::DenseTensor>();
48+
out_grad_tmp->ShareDataWith(out_grad).Resize({K, D});
49+
paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp),
50+
out_grad_tensor(out_grad_tmp);
51+
52+
auto start_index_tensor = paddle::experimental::full_like(
53+
x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
54+
auto end_index_tensor = paddle::experimental::full_like(
55+
x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
56+
auto ids_mask_tensor = paddle::experimental::logical_and(
57+
x_tensor.greater_equal(start_index_tensor),
58+
x_tensor.less_than(end_index_tensor));
59+
auto real_ids_tensor = (x_tensor - start_index_tensor)
60+
.multiply(paddle::experimental::cast(
61+
ids_mask_tensor, x_tensor.dtype()));
62+
auto out_grad_tensor_mul_mask =
63+
paddle::experimental::reshape(out_grad_tensor, {K, D})
64+
.multiply(paddle::experimental::reshape(
65+
paddle::experimental::cast(ids_mask_tensor, w.dtype()),
66+
{K, 1}));
67+
paddle::Tensor w_grad_tensor;
68+
paddle::experimental::embedding_grad(real_ids_tensor,
69+
w_tensor,
70+
out_grad_tensor_mul_mask,
71+
-1,
72+
false,
73+
&w_grad_tensor);
74+
w_grad->ShareDataWith(
75+
*reinterpret_cast<phi::DenseTensor*>(w_grad_tensor.impl().get()));
76+
77+
} else {
78+
PADDLE_THROW(phi::errors::Unavailable(
79+
"Custom Device c_embedding_grad ids only support int32 or int64."));
80+
}
81+
}
82+
#endif
83+
} // namespace phi
84+
85+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
86+
PD_REGISTER_KERNEL(c_embedding_grad,
87+
Custom,
88+
ALL_LAYOUT,
89+
phi::CEmbeddingGradKernel,
90+
float,
91+
phi::dtype::float16,
92+
phi::dtype::bfloat16) {}
93+
#endif
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/c_embedding_kernel.h"
16+
#include "glog/logging.h"
17+
#include "paddle/phi/api/backward/backward_api.h"
18+
#include "paddle/phi/api/include/api.h"
19+
#include "paddle/phi/backends/all_context.h"
20+
#include "paddle/phi/common/float16.h"
21+
#include "paddle/phi/core/kernel_registry.h"
22+
23+
namespace phi {
24+
25+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
26+
template <typename T, typename Context>
27+
void CEmbeddingKernel(const Context& dev_ctx,
28+
const DenseTensor& w,
29+
const DenseTensor& ids,
30+
int64_t start_index,
31+
int64_t vocab_size,
32+
DenseTensor* out) {
33+
const auto& index_type = ids.dtype();
34+
if (index_type == phi::DataType::INT32 ||
35+
index_type == phi::DataType::INT64) {
36+
auto out_dims = out->dims();
37+
auto K = ids.numel();
38+
auto N = w.dims()[0];
39+
auto D = w.dims()[1];
40+
41+
auto x_tmp = std::make_shared<phi::DenseTensor>();
42+
x_tmp->ShareDataWith(ids).Resize({K});
43+
auto w_tmp = std::make_shared<phi::DenseTensor>();
44+
w_tmp->ShareDataWith(w).Resize({N, D});
45+
paddle::Tensor x_tensor(x_tmp), w_tensor(w_tmp);
46+
47+
auto start_index_tensor = paddle::experimental::full_like(
48+
x_tensor, start_index, x_tensor.dtype(), x_tensor.place());
49+
auto end_index_tensor = paddle::experimental::full_like(
50+
x_tensor, start_index + N, x_tensor.dtype(), x_tensor.place());
51+
auto ids_mask_tensor = paddle::experimental::logical_and(
52+
x_tensor.greater_equal(start_index_tensor),
53+
x_tensor.less_than(end_index_tensor));
54+
auto ids_tensor = (x_tensor - start_index_tensor)
55+
.multiply(paddle::experimental::cast(
56+
ids_mask_tensor, x_tensor.dtype()));
57+
auto out_tensor =
58+
paddle::experimental::reshape(
59+
paddle::experimental::cast(ids_mask_tensor, w_tensor.dtype()),
60+
{K, 1})
61+
.multiply(paddle::experimental::reshape(
62+
paddle::experimental::embedding(
63+
ids_tensor, w_tensor, -1, false),
64+
{K, D}));
65+
out->ShareDataWith(
66+
*reinterpret_cast<phi::DenseTensor*>(out_tensor.impl().get()))
67+
.Resize(out_dims);
68+
} else {
69+
PADDLE_THROW(phi::errors::Unavailable(
70+
"Custom Device c_embedding ids only support int32 or int64."));
71+
}
72+
}
73+
#endif
74+
} // namespace phi
75+
76+
#ifdef PADDLE_WITH_CUSTOM_DEVICE
77+
PD_REGISTER_KERNEL(c_embedding,
78+
Custom,
79+
ALL_LAYOUT,
80+
phi::CEmbeddingKernel,
81+
float,
82+
phi::dtype::float16,
83+
phi::dtype::bfloat16) {}
84+
#endif

test/legacy_test/c_embedding_op_base.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ def get_c_embedding(start, end, table, ids):
3434
return output
3535

3636

37-
def c_embedding_wrapper(table, index, start_index=0):
38-
return paddle._legacy_C_ops.c_embedding(
39-
table, index, "start_index", start_index
40-
)
37+
def c_embedding_wrapper(table, index, start_index=0, vocab_size=-1):
38+
return paddle._C_ops.c_embedding(table, index, start_index, vocab_size)
4139

4240

4341
class TestCEmbeddingCPU(OpTest):
@@ -58,11 +56,15 @@ def initcase(self):
5856
)
5957
self.start_index = 10
6058
self.end_index = self.start_index + 17
59+
self.vocab_size = 34
6160

6261
self.inputs = {'W': table, 'Ids': ids}
6362
np_out = get_c_embedding(self.start_index, self.end_index, table, ids)
6463
self.outputs = {'Out': np_out.reshape((2, 4, 64))}
65-
self.attrs = {'start_index': self.start_index}
64+
self.attrs = {
65+
'start_index': self.start_index,
66+
'vocab_size': self.vocab_size,
67+
}
6668
if core.is_compiled_with_xpu():
6769
self.__class__.use_xpu = True
6870

@@ -87,12 +89,20 @@ def test_check_output(self):
8789
self.check_output_with_place(core.CUDAPlace(0))
8890
elif core.is_compiled_with_xpu():
8991
self.check_output_with_place(core.XPUPlace(0))
92+
else:
93+
current_place = paddle.framework._current_expected_place()
94+
if isinstance(current_place, paddle.CustomPlace):
95+
self.check_output_with_place(current_place)
9096

9197
def test_check_grad(self):
9298
if core.is_compiled_with_cuda():
9399
self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
94100
elif core.is_compiled_with_xpu():
95101
self.check_grad_with_place(core.XPUPlace(0), ['W'], 'Out')
102+
else:
103+
current_place = paddle.framework._current_expected_place()
104+
if isinstance(current_place, paddle.CustomPlace):
105+
self.check_grad_with_place(current_place, ['W'], 'Out')
96106

97107
def init_dtype(self):
98108
if core.is_compiled_with_cuda():
@@ -101,6 +111,11 @@ def init_dtype(self):
101111
elif core.is_compiled_with_xpu():
102112
self.dtype = "float32"
103113
self.ids_dtype = "int64"
114+
else:
115+
current_place = paddle.framework._current_expected_place()
116+
if isinstance(current_place, paddle.CustomPlace):
117+
self.dtype = "float32"
118+
self.ids_dtype = "int64"
104119

105120

106121
class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):

0 commit comments

Comments
 (0)