Skip to content

Commit 12737a6

Browse files
authored
【CUDA Kernel No.74、75】lookup_table、lookup_table_grad算子Kernel修复 -part (#75645)
* Add lookup_table_kernel.h
* Add lookup_table_grad_kernel.h
* Bug fix
* Delete CPUKernel
* 修改路径、去除cpu声明
* Update lookup_table_grad_kernel.cc
* CI
* Fix
* Fix codestyle
* Fix ci
* Fix ci
* save
* Revert "Fix ci" This reverts commit 644eacc.
* Revert "Fix ci" This reverts commit f7bf793.
* Revert "Fix codestyle" This reverts commit cbba545.
* fix
* 忘记改了一个报错
1 parent 680b1c6 commit 12737a6

File tree

6 files changed

+114
-8
lines changed

6 files changed

+114
-8
lines changed

paddle/phi/kernels/cpu/lookup_table_grad_kernel.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414

1515
#include <string>
1616
#include <vector>
17-
17+
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/dense_tensor.h"
1919
#include "paddle/phi/core/enforce.h"
20+
#include "paddle/phi/core/kernel_registry.h"
2021
#include "paddle/phi/kernels/funcs/blas/blas.h"
2122
#include "paddle/phi/kernels/funcs/eigen/common.h"
2223

23-
#include "paddle/phi/backends/cpu/cpu_context.h"
24-
#include "paddle/phi/core/kernel_registry.h"
25-
2624
namespace phi {
2725

2826
constexpr int64_t kNoPadding = -1;

paddle/phi/kernels/cpu/lookup_table_kernel.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414

1515
#include <string>
1616
#include <vector>
17-
17+
#include "paddle/phi/backends/cpu/cpu_context.h"
1818
#include "paddle/phi/core/dense_tensor.h"
1919
#include "paddle/phi/core/enforce.h"
20+
#include "paddle/phi/core/kernel_registry.h"
2021
#include "paddle/phi/kernels/funcs/blas/blas.h"
2122
#include "paddle/phi/kernels/funcs/eigen/common.h"
2223

23-
#include "paddle/phi/backends/cpu/cpu_context.h"
24-
#include "paddle/phi/core/kernel_registry.h"
25-
2624
namespace phi {
2725

2826
constexpr int64_t kNoPadding = -1;

paddle/phi/kernels/gpu/lookup_table_grad_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/phi/kernels/gpu/lookup_table_grad_kernel.h"
1516
#include "paddle/phi/backends/gpu/gpu_primitives.h"
1617
#include "paddle/phi/common/memory_utils.h"
1718
#include "paddle/phi/core/kernel_registry.h"
paddle/phi/kernels/gpu/lookup_table_grad_kernel.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/phi/core/dense_tensor.h"
18+
#include "paddle/phi/core/device_context.h"
19+
#include "paddle/phi/core/selected_rows.h"
20+
21+
namespace phi {
22+
23+
// Declaration only (no body in this header) of the lookup_table gradient
// kernel whose last parameter is a DenseTensor*.
//
// T and Context are template parameters; Context is the type of `dev_ctx`.
// `w`, `ids_in` and `out_grad` are taken by const reference (read-only here);
// `w_grad` is the only non-const argument and is presumably where the result
// is written — confirm against the definition.
//
// NOTE(review): the signature uses std::string, std::vector and int64_t, but
// the includes shown for this header are only dense_tensor.h,
// device_context.h and selected_rows.h. Presumably these names arrive
// transitively; consider including <string> and <vector> directly.
template <typename T, typename Context>
24+
void LookupTableGradCUDAKernel(const Context& dev_ctx,
25+
const DenseTensor& w,
26+
const DenseTensor& ids_in,
27+
const DenseTensor& out_grad,
28+
bool is_sparse,
29+
bool is_distributed,
30+
int64_t padding_idx,
31+
bool remote_prefetch,
32+
const std::string& entry_config,
33+
bool is_test,
34+
const std::string& entry,
35+
const std::string& table_class,
36+
const std::vector<std::string>& table_names,
37+
int trainer_id,
38+
bool grad_inplace,
39+
const std::vector<std::string>& epmap,
40+
const std::vector<int64_t>& height_sections,
41+
DenseTensor* w_grad);
42+
43+
// Declaration only (no body in this header) of the lookup_table gradient
// kernel whose last parameter is a SelectedRows*.
//
// The parameter list matches LookupTableGradCUDAKernel except for the final
// argument: `w_grad` is a SelectedRows* here instead of a DenseTensor*. It is
// the only non-const argument and is presumably where the result is written —
// confirm against the definition.
template <typename T, typename Context>
44+
void LookupTableSparseGradCUDAKernel(
45+
const Context& dev_ctx,
46+
const DenseTensor& w,
47+
const DenseTensor& ids_in,
48+
const DenseTensor& out_grad,
49+
bool is_sparse,
50+
bool is_distributed,
51+
int64_t padding_idx,
52+
bool remote_prefetch,
53+
const std::string& entry_config,
54+
bool is_test,
55+
const std::string& entry,
56+
const std::string& table_class,
57+
const std::vector<std::string>& table_names,
58+
int trainer_id,
59+
bool grad_inplace,
60+
const std::vector<std::string>& epmap,
61+
const std::vector<int64_t>& height_sections,
62+
SelectedRows* w_grad);
63+
64+
} // namespace phi

paddle/phi/kernels/gpu/lookup_table_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/phi/kernels/gpu/lookup_table_kernel.h"
1516
#include "paddle/phi/backends/gpu/gpu_primitives.h"
1617
#include "paddle/phi/common/memory_utils.h"
1718
#include "paddle/phi/core/kernel_registry.h"
paddle/phi/kernels/gpu/lookup_table_kernel.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
#include <string>
19+
#include <vector>
20+
#include "paddle/phi/core/dense_tensor.h"
21+
#include "paddle/phi/core/enforce.h"
22+
23+
namespace phi {
24+
25+
// Declaration only (no body in this header) of the lookup_table kernel.
//
// T and Context are template parameters; Context is the type of `dev_ctx`.
// `w` and `ids_in` are taken by const reference (read-only here); `out` is the
// only non-const argument and is presumably where the result is written —
// confirm against the definition.
//
// NOTE(review): nothing in this declaration appears to need enforce.h, which
// this header includes — confirm before removing it.
template <typename T, typename Context>
26+
void LookupTableCUDAKernel(const Context &dev_ctx,
27+
const DenseTensor &w,
28+
const DenseTensor &ids_in,
29+
bool is_sparse,
30+
bool is_distributed,
31+
int64_t padding_idx,
32+
bool remote_prefetch,
33+
const std::string &entry_config,
34+
bool is_test,
35+
const std::string &entry,
36+
const std::string &table_class,
37+
const std::vector<std::string> &table_names,
38+
int trainer_id,
39+
bool grad_inplace,
40+
const std::vector<std::string> &epmap,
41+
const std::vector<int64_t> &height_sections,
42+
DenseTensor *out);
43+
44+
} // namespace phi

0 commit comments

Comments
 (0)