File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed
Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -38,6 +38,21 @@ void GroupNormKernel(const Context& dev_ctx,
3838 DenseTensor* mean,
3939 DenseTensor* variance);
4040
41+ template <typename T, typename Context>
42+ void GroupNormNDHWCKernel (const Context& dev_ctx,
43+ const DenseTensor& x,
44+ const paddle::optional<DenseTensor>& residual,
45+ const paddle::optional<DenseTensor>& scale,
46+ const paddle::optional<DenseTensor>& bias,
47+ float epsilon,
48+ int groups,
49+ const std::string& data_layout_str,
50+ const std::string& activation,
51+ DenseTensor* y,
52+ DenseTensor* residual_out,
53+ DenseTensor* mean,
54+ DenseTensor* var);
55+
4156#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
4257template <typename T, typename AccT = T>
4358class GroupNormDirectCUDAFunctor {
You can’t perform that action at this time.
0 commit comments