Skip to content

Commit 434f6f9

Browse files
Added support for fuse_linear_bn_eval
1 parent 5e8ef8f commit 434f6f9

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

src/TorchSharp/NN/Linear.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ public sealed class Linear : torch.nn.Module<Tensor, Tensor>
1717
const string WeightComponentName = nameof(weight);
1818
const string BiasComponentName = nameof(bias);
1919

20+
/// <summary>
/// Construct a Linear module directly from an existing weight (and optional bias) parameter.
/// </summary>
/// <param name="weight">Weight parameter; expected shape is [out_features, in_features], matching torch.nn.Linear.</param>
/// <param name="bias">Optional additive bias parameter.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="weight"/> is null.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="weight"/> is not two-dimensional.</exception>
internal Linear(Parameter weight, Parameter? bias = null) : base(nameof(Linear))
{
    if (weight is null) throw new ArgumentNullException(nameof(weight));
    // Fail early with a clear message rather than an IndexOutOfRangeException below.
    if (weight.shape.Length != 2) throw new ArgumentException("Linear weight must have two dimensions: [out_features, in_features].", nameof(weight));

    this.in_features = weight.shape[1];
    this.out_features = weight.shape[0];

    this.weight = weight;
    if (bias is not null) {
        this.bias = bias;
    }
}
30+
2031
internal Linear(long inputSize, long outputSize, bool hasBias = true, Device? device = null, ScalarType? dtype = null) : base(nameof(Linear))
2132
{
2233
this.in_features = inputSize;
@@ -85,7 +96,7 @@ public static partial class torch
8596
public static partial class nn
8697
{
8798
/// <summary>
88-
/// Applies a linear transformation to the incoming data.
99+
/// Create a Linear module initialized with random weights and bias.
89100
/// </summary>
90101
/// <param name="inputSize">Size of each input sample</param>
91102
/// <param name="outputSize">Size of each output sample</param>
@@ -97,6 +108,16 @@ public static Linear Linear(long inputSize, long outputSize, bool hasBias = true
97108
return new Linear(inputSize, outputSize, hasBias, device, dtype);
98109
}
99110

111+
/// <summary>
/// Create a Linear module wrapping the given weight and (optional) bias parameters.
/// </summary>
/// <param name="weight">The linear weight attribute.</param>
/// <param name="bias">The additive linear bias. Optional.</param>
/// <returns>A Linear module that uses the supplied parameters directly.</returns>
public static Linear Linear(Parameter weight, Parameter? bias = null) => new Linear(weight, bias);
120+
100121
public static partial class functional
101122
{
102123
/// <summary>

src/TorchSharp/Torch.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,16 @@ public static (Parameter weight, Parameter bias) fuse_linear_bn_weights(
493493

494494
return scope.MoveToOuter(weight, bias);
495495
}
496+
497+
/// <summary>
/// Fuse a Linear module and a BatchNorm module into a single, equivalent Linear module.
/// Valid only in eval mode, since the fusion relies on the batch norm's running
/// statistics rather than per-batch statistics.
/// </summary>
/// <param name="linear">The Linear module to fuse.</param>
/// <param name="bn">The BatchNorm module to fuse.</param>
/// <returns>A new Linear module computing the composition bn(linear(x)).</returns>
/// <exception cref="ArgumentNullException">Thrown when either module is null.</exception>
/// <exception cref="InvalidOperationException">Thrown when either module is in training mode.</exception>
public static Linear fuse_linear_bn_eval(Linear linear, BatchNorm bn)
{
    if (linear is null) throw new ArgumentNullException(nameof(linear));
    if (bn is null) throw new ArgumentNullException(nameof(bn));

    if (linear.training || bn.training)
        throw new InvalidOperationException("Fusing operators is valid only for eval mode.");

    // NOTE(review): running_mean/running_var are null-forgiven here; they are only
    // tracked when the batch norm was created with track_running_stats — confirm callers.
    var (weight, bias) = fuse_linear_bn_weights(linear.weight, linear.bias, bn.running_mean!, bn.running_var!, bn.eps, bn.weight, bn.bias!);

    return Linear(weight, bias);
}
496506
}
497507
}
498508

test/TorchSharpTest/NN.cs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,29 @@ public void FunctionalBilinearNoBias()
325325
Assert.Equal(40, forward.shape[1]);
326326
Assert.Equal(device.type, forward.device_type);
327327
}
328+
}
329+
330+
[Fact]
public void TestLinearFused()
{
    // Arrange: both modules must be in eval mode for fusion to be valid.
    var lin = Linear(15, 15);
    var bn = BatchNorm1d(15);
    lin.eval();
    bn.eval();

    Assert.NotNull(lin);
    Assert.NotNull(lin.bias);

    var input = torch.rand(8, 15);
    var expected = bn.forward(lin.forward(input));

    // Act: the fused module should reproduce bn(linear(x)).
    var fused = torch.nn.utils.fuse_linear_bn_eval(lin, bn);
    var output = fused.forward(input);

    // Assert: fusion is numerically equivalent, not bit-exact, so compare with tolerance.
    Assert.True(expected.allclose(output, rtol: 1e-3, atol: 1e-3));
}
329352

330353
[Fact]

0 commit comments

Comments
 (0)