
Commit bf73b9f

Started implementing functional::normalize()
1 parent df3b503 commit bf73b9f

File tree

2 files changed, +21 −0 lines


src/Native/LibTorchSharp/THSNormalization.cpp

Lines changed: 9 additions & 0 deletions
@@ -607,6 +607,15 @@ Tensor THSNN_batch_norm(const Tensor input, Tensor running_mean, const Tensor ru
     CATCH_TENSOR(torch::nn::functional::batch_norm(*input, *running_mean, *running_var, opts));
 }
 
+Tensor THSNN_normalize(const Tensor input, const double p, const int64_t dim, const double eps)
+{
+    auto opts = torch::nn::functional::NormalizeFuncOptions()
+        .p(p)
+        .dim(dim)
+        .eps(eps);
+    CATCH_TENSOR(torch::nn::functional::normalize(*input, opts));
+}
+
 Tensor THSNN_group_norm(const Tensor input, const int64_t num_groups, const Tensor weight, const Tensor bias, const double eps)
 {
     auto opts = torch::nn::functional::GroupNormFuncOptions(num_groups)
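For context: torch::nn::functional::normalize, which the new THSNN_normalize shim forwards to, rescales the input so that each vector taken along dim has unit p-norm, with eps guarding against division by zero. The documented formula (quoted here for reference, not part of the diff) is:

\[
\hat{v} = \frac{v}{\max\!\left(\lVert v \rVert_p,\ \varepsilon\right)}
\]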

src/TorchSharp/NN/Normalization/Functional.cs

Lines changed: 12 additions & 0 deletions
@@ -10,6 +10,18 @@ public static partial class nn
     {
         public static partial class functional
         {
+            /// <summary>
+            /// Perform normalization of inputs over the specified dimension.
+            /// </summary>
+            /// <param name="input">Input tensor of any shape.</param>
+            /// <param name="p">The exponent value in the norm formulation.</param>
+            /// <param name="dim">The dimension to reduce.</param>
+            /// <param name="eps">Small value to avoid division by zero.</param>
+            public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
+            {
+                return null;
+            }
+
             /// <summary>
             /// Applies Batch Normalization for each channel across a batch of data.
             /// </summary>
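The managed method added here is still a stub (it returns null), which matches the commit message "Started implementing". A minimal sketch of how the binding might eventually be completed is shown below, assuming the P/Invoke pattern used by the other THSNN_* bindings in TorchSharp; the [DllImport("LibTorchSharp")] declaration, the Tensor.Handle property, the internal Tensor(IntPtr) constructor, and the torch.CheckForErrors() helper are assumptions drawn from that pattern, not part of this commit.

// Hypothetical completion of functional.normalize() -- not part of this commit.
using System;
using System.Runtime.InteropServices;

public static partial class functional
{
    // Assumed P/Invoke declaration for the native shim added in THSNormalization.cpp (assumption).
    [DllImport("LibTorchSharp")]
    private static extern IntPtr THSNN_normalize(IntPtr input, double p, long dim, double eps);

    public static Tensor normalize(Tensor input, double p = 2.0, long dim = 1L, double eps = 1e-12)
    {
        // Forward to the native wrapper; a null handle signals a libtorch error,
        // which is surfaced as a managed exception (assumed error-handling pattern).
        var res = THSNN_normalize(input.Handle, p, dim, eps);
        if (res == IntPtr.Zero) { torch.CheckForErrors(); }
        return new Tensor(res);
    }
}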
