|
14 | 14 |
|
15 | 15 | #include "lite/kernels/xpu/gru_compute.h"
|
16 | 16 |
|
17 |
| -// REGISTER_LITE_KERNEL( |
18 |
| -// gru, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::GRUCompute, def) |
19 |
| -// .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))}) |
20 |
| -// .BindInput("H0", {LiteType::GetTensorTy(TARGET(kXPU))}) |
21 |
| -// .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kXPU))}) |
22 |
| -// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))}) |
23 |
| -// .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kXPU))}) |
24 |
| -// .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kXPU))}) |
25 |
| -// .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kXPU))}) |
26 |
| -// .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kXPU))}) |
27 |
| -// .Finalize(); |
| 17 | +REGISTER_LITE_KERNEL( |
| 18 | + gru, kXPU, kFloat, kNCHW, paddle::lite::kernels::xpu::GRUCompute, def) |
| 19 | + .BindInput("Input", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 20 | + .BindInput("H0", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 21 | + .BindInput("Weight", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 22 | + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 23 | + .BindOutput("BatchGate", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 24 | + .BindOutput("BatchResetHiddenPrev", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 25 | + .BindOutput("BatchHidden", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 26 | + .BindOutput("Hidden", {LiteType::GetTensorTy(TARGET(kXPU))}) |
| 27 | + .Finalize(); |
0 commit comments