Skip to content

Commit 601626a

Browse files
authored
[cherry-pick][code-gen] Support code-gen for opmaker of sparse op (#46993) (#47417)
* support generating code of opmaker for backward op invoke forward op (#46912) * [code-gen] Support code-gen for opmaker of sparse op (#46993) * support generating code of opmaker for backward op invoke forward op * gsupport code-gen of opmaker for sparse op * refind logic of choose phi kernrel * fix complie budg * fix code_gen bug * fix bug * fix kernel signature code-gen * fix complie bug of VarType * fix complie bug of VarType * fix test_sparse_conv_op * fix test_sparse_norm_op * [Phi] Refactor logic of judging whether having a phi kernrel (#46920) * refind logic of choose phi kernrel * fix complie budg * update cmake
1 parent 23c05f2 commit 601626a

39 files changed

+829
-611
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ paddle/fluid/pybind/eager_op_function.cc
7171

7272
# these files (directories) are generated before build system generation
7373
paddle/fluid/operators/generated_op.cc
74+
paddle/fluid/operators/generated_sparse_op.cc
7475
paddle/phi/ops/compat/generated_sig.cc
76+
paddle/phi/ops/compat/generated_sparse_sig.cc
7577
paddle/phi/api/yaml/parsed_apis/
7678
python/paddle/utils/code_gen/
7779
paddle/fluid/pybind/tmp_eager_op_function_impl.h

paddle/fluid/eager/auto_code_generator/eager_generator.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ static std::unordered_set<std::string> black_ops_list = {"run_program",
5555
"fused_gate_attention",
5656
"fused_feedforward",
5757
"fused_attention",
58-
"fused_gemm_epilogue"};
58+
"fused_gemm_epilogue",
59+
"sparse_divide_scalar",
60+
"sparse_scale"};
5961

6062
static std::string LegalizeVariableName(const std::string& var_name) {
6163
std::string ret = var_name;
@@ -3161,6 +3163,12 @@ static void DygraphCodeGeneration(const std::string& output_dir,
31613163
continue;
31623164
}
31633165

3166+
// Skip the sparse op
3167+
if (op_type.compare(0, 7, "sparse_") == 0 && op_type != "sparse_momentum" &&
3168+
op_type != "sparse_attention") {
3169+
continue;
3170+
}
3171+
31643172
GradNodeGenerationInfo bwd_info;
31653173

31663174
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ cc_test(
190190
cc_library(
191191
var_type_traits
192192
SRCS var_type_traits.cc
193-
DEPS framework_proto scope tensor_array sparse_coo_tensor)
193+
DEPS framework_proto scope tensor_array sparse_coo_tensor sparse_csr_tensor)
194194
if(WITH_GPU)
195195
target_link_libraries(var_type_traits dynload_cuda)
196196
endif()

paddle/fluid/framework/framework.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ message VarType {
156156
PSTRING = 29;
157157
// the data type of phi::SparseCooTensor
158158
SPARSE_COO = 30;
159+
// the data type of phi::SparseCsrTensor
160+
SPARSE_CSR = 31;
159161
}
160162

161163
required Type type = 1;
@@ -189,6 +191,7 @@ message VarType {
189191
optional TensorDesc strings = 9;
190192
optional TensorDesc vocab = 10;
191193
optional TensorDesc sparse_coo = 11;
194+
optional TensorDesc sparse_csr = 12;
192195
}
193196

194197
message VarDesc {

paddle/fluid/framework/infershape_utils.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
106106
return var_type == proto::VarType::SPARSE_COO;
107107
}
108108

109+
bool IsSparseCsrTensorInput(const std::string& name) const override {
110+
auto var_type = ctx_.GetInputVarType(name);
111+
return var_type == proto::VarType::SPARSE_CSR;
112+
}
113+
109114
bool IsDenseTensorOutput(const std::string& name) const override {
110115
auto var_types = ctx_.GetOutputsVarType(name);
111116
return std::all_of(var_types.begin(),

paddle/fluid/framework/operator.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
529529
return var->IsType<phi::SparseCooTensor>();
530530
}
531531

532+
bool IsSparseCsrTensorInput(const std::string& name) const override {
533+
const auto* var = ctx_.InputVar(name);
534+
return var->IsType<phi::SparseCsrTensor>();
535+
}
536+
532537
bool IsDenseTensorOutput(const std::string& name) const override {
533538
auto vars = ctx_.MultiOutputVar(name);
534539
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {

paddle/fluid/framework/tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License. */
1818
#include "paddle/fluid/framework/mixed_vector.h"
1919
#include "paddle/phi/core/dense_tensor.h"
2020
#include "paddle/phi/core/sparse_coo_tensor.h"
21+
#include "paddle/phi/core/sparse_csr_tensor.h"
2122

2223
namespace paddle {
2324
namespace framework {

paddle/fluid/framework/var_type_traits.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ namespace phi {
5555
class DenseTensor;
5656
class SelectedRows;
5757
class SparseCooTensor;
58+
class SparseCsrTensor;
5859
} // namespace phi
5960

6061
// Users should add forward declarations here
@@ -182,6 +183,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
182183
Tensor,
183184
phi::SelectedRows,
184185
phi::SparseCooTensor,
186+
phi::SparseCsrTensor,
185187
std::vector<Scope *>,
186188
LoDRankTable,
187189
Strings,

paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ bool PluginArgumentMappingContext::IsSparseCooTensorInput(
108108
const std::string& name) const {
109109
return false;
110110
}
111+
bool PluginArgumentMappingContext::IsSparseCsrTensorInput(
112+
const std::string& name) const {
113+
return false;
114+
}
111115
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
112116
const std::string& name) const {
113117
return false;

paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {
4848

4949
bool IsSparseCooTensorInput(const std::string& name) const override;
5050

51+
bool IsSparseCsrTensorInput(const std::string& name) const override;
52+
5153
bool IsDenseTensorVectorInput(const std::string& name) const override;
5254

5355
bool IsDenseTensorOutput(const std::string& name) const override;

0 commit comments

Comments
 (0)