
Commit ad7ad4e

1. Fixed a problem with the to() implementation for modules whose parameters and/or buffers were declared on a base class.
2. Addressed an issue where ParameterDict and ParameterList did not handle _to() properly.
1 parent 59152f3 commit ad7ad4e
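
For context on item 1, a sketch of the scenario the fix targets (the BaseModule/Derived names below are hypothetical, not taken from this repository): a parameter declared on a base class must still be found and moved when to() is called on the derived module.

// Illustrative repro sketch only; per the commit message, before this commit
// 'weight' could be left behind by to() because it is declared on the base class.
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

abstract class BaseModule : nn.Module<Tensor, Tensor>
{
    // The parameter lives on the *base* class -- the case the fix targets.
    protected Parameter weight = new Parameter(torch.randn(4, 4));
    protected BaseModule(string name) : base(name) { }
}

sealed class Derived : BaseModule
{
    public Derived() : base(nameof(Derived)) { RegisterComponents(); }
    public override Tensor forward(Tensor input) => input.matmul(weight);
}

// var m = new Derived();
// m.to(torch.float64);   // after this commit, 'weight' is converted as well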

File tree

12 files changed: +292 −187 lines

src/TorchSharp/NN/Convolution/Convolution.cs

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ protected Convolution(string name, long in_channels, long out_channels, long[] k
     this.padding_mode = padding_mode;
 
     // Set this so the constructor doesn't give a non-null error, and the actual value is set in the
-    // SetPadding function called right after.
+    // SetPadding function called right after.
     this._reversed_padding_repeated_twice = Array.Empty<long>();
     if (padding_type.HasValue)
         SetPadding(padding_type.Value);

src/TorchSharp/NN/Module.cs

Lines changed: 112 additions & 70 deletions
Large diffs are not rendered by default.
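
The Module.cs hunk is collapsed here, but the ParameterDict.cs and ParameterList.cs diffs below override a new protected virtual _toEpilog(ScalarType?, Device, bool), which suggests the refactor routes the various _to() overloads through a single epilog that container modules can specialize. A minimal, self-contained sketch of that pattern, offered as an assumption about the shape of the change rather than the actual library code:

// Hypothetical sketch of the "single virtual epilog" pattern; names mirror
// TorchSharp's, but this is not the code in Module.cs.
using System;

abstract class ModuleSketch
{
    public ModuleSketch to(string device)
    {
        Console.WriteLine($"native move to {device}");  // stand-in for the real tensor move
        _toEpilog(device);                              // one managed extension point
        return this;
    }

    // Containers override this single method instead of every _to() overload,
    // so the overloads cannot fall out of sync.
    protected virtual void _toEpilog(string device) { }
}

sealed class ParameterListSketch : ModuleSketch
{
    protected override void _toEpilog(string device) =>
        Console.WriteLine($"rebuilding parameters on {device}");
}

static class EpilogDemo
{
    static void Main() => new ParameterListSketch().to("cuda:0");
}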

src/TorchSharp/NN/Normalization/NormBase.cs

Lines changed: 12 additions & 11 deletions
@@ -3,6 +3,7 @@
 using static TorchSharp.torch;
 using static TorchSharp.torch.nn;
 using static TorchSharp.PInvoke.NativeMethods;
+
 #nullable enable
 namespace TorchSharp
 {
@@ -14,13 +15,13 @@ namespace Modules
 {
     public abstract class NormBase : torch.nn.Module<Tensor, Tensor>
     {
-        public NormBase(long num_features,
-            double eps,
-            double? momentum,
-            bool affine,
-            bool track_running_stats,
-            Device? device,
-            ScalarType? dtype,
+        public NormBase(long num_features,
+            double eps,
+            double? momentum,
+            bool affine,
+            bool track_running_stats,
+            Device? device,
+            ScalarType? dtype,
             string name) : base(name)
         {
             this.num_features = num_features;
@@ -115,15 +116,15 @@ public Tensor? num_batches_tracked {
         ConditionallyRegisterBuffer(nameof(num_batches_tracked), _num_batches_tracked);
     }
 }
-
+
 public long num_features { get; private set; }
-
+
 public double eps { get; set; }
-
+
 public double? momentum { get; set; }
 
 public bool affine { get; private set; }
-
+
 public bool track_running_stats { get; private set; }
 
 [ComponentName(Name = nameof(bias))]

src/TorchSharp/NN/ParameterDict.cs

Lines changed: 41 additions & 26 deletions
@@ -14,7 +14,7 @@ namespace Modules
 {
     /// <summary>
     /// Holds parameters in a dictionary.
-    ///
+    ///
     /// ParameterDict can be indexed like a regular dictionary, but the parameters it
     /// contains are properly registered, and will be visible by all Module methods.
     ///
@@ -60,34 +60,43 @@ protected override void RegisterComponents()
 
     private bool _registered = false;
 
-    protected internal override Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
-    {
-        base._to(deviceType, deviceIndex, non_blocking);
-        _toEpilog();
-        return this;
-    }
-
-    protected internal override Module _to(torch.Device device, torch.ScalarType dtype, bool non_blocking)
-    {
-        base._to(device, dtype, non_blocking);
-        _toEpilog();
-        return this;
-    }
-
-    protected internal override Module _to(torch.ScalarType dtype, bool non_blocking)
-    {
-        base._to(dtype, non_blocking);
-        _toEpilog();
-        return this;
-    }
-
-    void _toEpilog()
+    protected override void _toEpilog(torch.ScalarType? dtype, torch.Device device, bool non_blocking)
     {
         for (int i = 0; i < _list.Count; i++) {
             string name = _list[i].Item1;
-            var param = base.get_parameter(name);
-            _list[i] = (name, param);
-            _dict[name] = param;
+            var param = _list[i].Item2;
+
+            using var grad = param.grad;
+
+            if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) &&
+                (grad is null || !grad.toWillCopy(dtype ?? param.dtype, device ?? param.device)))
+                continue;
+
+            Parameter p;
+            torch.ScalarType paramType =
+                dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
+            // When moving the parameter, we don't want the autograd to track this movement on the graph.
+            // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+            // disable grad we would need to call .detach() on the moved tensor.
+            using (var d = torch.no_grad()) {
+                p = new Parameter(
+                    data: param.to(paramType, device ?? param.device),
+                    requires_grad: param.requires_grad);
+                _ = p.DetachFromDisposeScope();
+
+                // Copy the gradient over as well, if it exists
+                if (grad is not null) {
+                    using var newGrad = grad.to(paramType, device ?? param.device)
+                        .with_requires_grad(grad.requires_grad);
+                    p.grad = newGrad;
+                }
+            }
+
+            param?.Dispose();
+
+            _list[i] = (name, p);
+            _dict[name] = p;
         }
     }
 
@@ -136,6 +145,12 @@ public override Parameter get_parameter(string target)
         return null;
     }
 
+    public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse)
+    {
+        // Ignore the 'recurse' parameter.
+        return _dict.Select(d => (d.Key, d.Value));
+    }
+
     public void Add((string, Parameter) item)
     {
         _dict.Add(item.Item1, item.Item2);
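
A short usage sketch of the new ParameterDict behavior, using only members visible in the hunk above (Add, the string indexer, and the grad setter): gradients are now converted along with their parameters.

// Usage sketch; assumes a TorchSharp build containing this commit.
using System;
using TorchSharp;
using TorchSharp.Modules;

class ParameterDictDemo
{
    static void Main()
    {
        var pd = new ParameterDict();
        pd.Add(("w", new Parameter(torch.randn(2, 3))));

        pd["w"].grad = torch.zeros(2, 3);   // give the parameter a gradient to carry along

        pd.to(torch.float64);               // dtype move now goes through _toEpilog

        Console.WriteLine(pd["w"].dtype);        // Float64
        Console.WriteLine(pd["w"].grad?.dtype);  // Float64 -- the gradient moved too
    }
}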

src/TorchSharp/NN/ParameterList.cs

Lines changed: 40 additions & 29 deletions
@@ -33,35 +33,6 @@ protected override void RegisterComponents()
     _registered = true;
 }
 
-
-protected internal override Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
-{
-    base._to(deviceType, deviceIndex, non_blocking);
-    _toEpilog();
-    return this;
-}
-
-protected internal override Module _to(torch.Device device, torch.ScalarType dtype, bool non_blocking)
-{
-    base._to(device, dtype, non_blocking);
-    _toEpilog();
-    return this;
-}
-
-protected internal override Module _to(torch.ScalarType dtype, bool non_blocking)
-{
-    base._to(dtype, non_blocking);
-    _toEpilog();
-    return this;
-}
-
-void _toEpilog()
-{
-    for (int i = 0; i < _list.Count; i++) {
-        _list[i] = base.get_parameter($"{i}");
-    }
-}
-
 public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse = true)
 {
     return Enumerable.Range(0, _list.Count).Select(i => ($"{i}", _list[i]));
@@ -80,6 +51,46 @@ public override Parameter get_parameter(string target)
     return null;
 }
 
+protected override void _toEpilog(torch.ScalarType? dtype, torch.Device device, bool non_blocking)
+{
+    for (int i = 0; i < _list.Count; i++) {
+
+        string name = $"{i}";
+        var param = _list[i];
+
+        using var grad = param.grad;
+
+        if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) &&
+            (grad is null || !grad.toWillCopy(dtype ?? param.dtype, device ?? param.device)))
+            continue;
+
+        Parameter p;
+        torch.ScalarType paramType =
+            dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
+        // When moving the parameter, we don't want the autograd to track this movement on the graph.
+        // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+        // disable grad we would need to call .detach() on the moved tensor.
+        using (var d = torch.no_grad()) {
+            p = new Parameter(
+                data: param.to(paramType, device ?? param.device),
+                requires_grad: param.requires_grad);
+            _ = p.DetachFromDisposeScope();
+
+            // Copy the gradient over as well, if it exists
+            if (grad is not null) {
+                using var newGrad = grad.to(paramType, device ?? param.device)
+                    .with_requires_grad(grad.requires_grad);
+                p.grad = newGrad;
+            }
+        }
+
+        param?.Dispose();
+
+        _list[i] = p;
+    }
+}
+
 private bool _registered = false;
 
 public Parameter this[int index] {
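
The same path exercised through ParameterList (a sketch under the same assumptions as the ParameterDict example above). Because the rebuild happens inside torch.no_grad(), the replacement parameter stays a leaf and keeps its requires_grad flag.

// Usage sketch mirroring the ParameterDict example.
using System;
using TorchSharp;
using TorchSharp.Modules;

class ParameterListDemo
{
    static void Main()
    {
        var pl = new ParameterList();
        pl.Add(new Parameter(torch.randn(8)));

        pl.to(torch.float64);                    // rebuilds pl[0] via _toEpilog

        Console.WriteLine(pl[0].dtype);          // Float64
        Console.WriteLine(pl[0].requires_grad);  // True -- preserved by the rebuild
        Console.WriteLine(pl[0].is_leaf);        // True -- thanks to the no_grad block
    }
}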

src/TorchVision/models/AlexNet.cs

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ public static partial class models
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.alexnet(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -105,7 +105,7 @@ public AlexNet(int numClasses, float dropout = 0.5f, string? weights_file = null
 
     if (!string.IsNullOrEmpty(weights_file)) {
 
-        this.load(weights_file, skip: skipfc ? new[] { "classifier.6.weight", "classifier.6.bias" } : null);
+        this.load(weights_file!, skip: skipfc ? new[] { "classifier.6.weight", "classifier.6.bias" } : null);
     }
 
     if (device != null && device.type != DeviceType.CPU)
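
The substantive change here (repeated in GoogleNet.cs, InceptionV3.cs, and ResNet.cs below) is the null-forgiving ! on weights_file. A minimal illustration of why it is needed, assuming a target framework such as netstandard2.0 where string.IsNullOrEmpty lacks the [NotNullWhen(false)] annotation and therefore cannot narrow weights_file to non-null:

#nullable enable
using System;

class NullForgivingDemo
{
    static void Load(string path) => Console.WriteLine($"loading {path}");

    static void Main()
    {
        string? weights_file = Environment.GetEnvironmentVariable("WEIGHTS");
        if (!string.IsNullOrEmpty(weights_file)) {
            // Without [NotNullWhen(false)] on IsNullOrEmpty, the compiler still
            // treats 'weights_file' as maybe-null here; '!' asserts otherwise.
            Load(weights_file!);
        }
    }
}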

src/TorchVision/models/GoogleNet.cs

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ public static partial class models
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.inception_v3(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -170,7 +170,7 @@ public GoogleNet(int numClasses = 1000,
             break;
         }
     }
-    this.load(weights_file, skip: skipfc ? new[] { "fc.weight", "fc.bias", "AuxLogits.fc.weight", "AuxLogits.fc.bias" } : null);
+    this.load(weights_file!, skip: skipfc ? new[] { "fc.weight", "fc.bias", "AuxLogits.fc.weight", "AuxLogits.fc.bias" } : null);
 }
 
 if (device != null && device.type != DeviceType.CPU)

src/TorchVision/models/InceptionV3.cs

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@ public static partial class models
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.inception_v3(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -170,7 +170,7 @@ public InceptionV3(int numClasses = 1000,
             break;
         }
     }
-    this.load(weights_file, skip: skipfc ? new[] { "fc.weight", "fc.bias", "AuxLogits.fc.weight", "AuxLogits.fc.bias" } : null);
+    this.load(weights_file!, skip: skipfc ? new[] { "fc.weight", "fc.bias", "AuxLogits.fc.weight", "AuxLogits.fc.bias" } : null);
 }
 
 if (device != null && device.type != DeviceType.CPU)

src/TorchVision/models/ResNet.cs

Lines changed: 11 additions & 11 deletions
@@ -30,7 +30,7 @@ public static partial class models
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet18(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -86,7 +86,7 @@ public static Modules.ResNet resnet18(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet34(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -142,7 +142,7 @@ public static Modules.ResNet resnet34(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet50(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -197,7 +197,7 @@ public static Modules.ResNet resnet50(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.wide_resnet50_2(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -251,7 +251,7 @@ public static Modules.ResNet wide_resnet50_2(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnext50_32x4d(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -305,7 +305,7 @@ public static Modules.ResNet resnext50_32x4d(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet101(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -360,7 +360,7 @@ public static Modules.ResNet resnet101(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnext101_32x8d(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -413,7 +413,7 @@ public static Modules.ResNet resnext101_32x8d(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnext101_32x8d(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -466,7 +466,7 @@ public static Modules.ResNet resnext101_64x4d(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet101(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -521,7 +521,7 @@ public static Modules.ResNet wide_resnet101_2(
 ///
 /// from torchvision import models
 /// import exportsd
-///
+///
 /// model = models.resnet152(pretrained=True)
 /// f = open("model_weights.dat", "wb")
 /// exportsd.save_state_dict(model.state_dict(), f)
@@ -825,7 +825,7 @@ public ResNet(string name,
 
 } else {
 
-    this.load(weights_file, skip: skipfc ? new[] { "fc.weight", "fc.bias" } : null);
+    this.load(weights_file!, skip: skipfc ? new[] { "fc.weight", "fc.bias" } : null);
 }
 
 if (device != null && device.type != DeviceType.CPU)
