
Commit 1736f0f

Added bespoke _to() implementations for a number of built-in modules. This is a minor performance improvement.
1 parent 2b01a71 commit 1736f0f
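For orientation (not part of the commit), a minimal usage sketch of the path these overrides speed up: calling .to() on one of the affected built-in modules now dispatches to that module's own _to() override, which swaps just the known weight/bias parameters instead of reflecting over every property in the generic _toEpilog(). The module, dtype, and device below are illustrative.

    using TorchSharp;
    using static TorchSharp.torch;

    var lin = nn.Linear(128, 64);

    // dtype conversion: handled by the bespoke _to(ScalarType, bool) override on Linear.
    lin.to(ScalarType.Float16);

    // device move: handled by the bespoke _to(DeviceType, int, bool) override on Linear.
    if (cuda.is_available())
        lin.to(CUDA);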

7 files changed (+286, -42 lines)


src/TorchSharp/NN/Bilinear.cs

Lines changed: 33 additions & 0 deletions
@@ -69,6 +69,39 @@ public Parameter weight {
             }
         }
 
+        // Rather than spending cycles discovering what parameters exist, we can just hardcode it.
+        protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
+        {
+            var device = new Device(deviceType, deviceIndex);
+            if (_weight is not null && ReplaceParameter(_weight.dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(_bias.dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, _weight.device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, _bias.device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
         [ComponentName(Name = BiasComponentName)]
         private Parameter? _bias;
         [ComponentName(Name = WeightComponentName)]

src/TorchSharp/NN/Convolution/Convolution.cs

Lines changed: 33 additions & 0 deletions
@@ -154,6 +154,39 @@ public Parameter weight {
             }
         }
 
+        // Rather than spending cycles discovering what parameters exist, we can just hardcode it.
+        protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
+        {
+            var device = new Device(deviceType, deviceIndex);
+            if (_weight is not null && ReplaceParameter(_weight.dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(_bias.dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, _weight.device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, _bias.device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
         [ComponentName(Name = BiasComponentName)]
         protected Parameter? _bias;
         [ComponentName(Name = WeightComponentName)]

src/TorchSharp/NN/Linear.cs

Lines changed: 33 additions & 0 deletions
@@ -79,6 +79,39 @@ public Parameter weight {
             }
         }
 
+        // Rather than spending cycles discovering what parameters exist, we can just hardcode it.
+        protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, device, _weight, out var w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, device, _bias, out var b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
+        {
+            var device = new Device(deviceType, deviceIndex);
+            if (_weight is not null && ReplaceParameter(_weight.dtype, device, _weight, out var w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(_bias.dtype, device, _bias, out var b)) {
+                bias = b!;
+            }
+            return this;
+        }
+        protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, _weight.device, _weight, out var w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, _bias.device, _bias, out var b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+
         [ComponentName(Name = BiasComponentName)]
         private Parameter? _bias;
         [ComponentName(Name = WeightComponentName)]

src/TorchSharp/NN/Module.cs

Lines changed: 57 additions & 41 deletions
@@ -253,44 +253,21 @@ protected virtual void _toEpilog(ScalarType? dtype, Device? device, bool non_blo
 
             var props = GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
 
-            var propsByName = new Dictionary<string, PropertyInfo>();
-            foreach (var p in props) {
-                // There may be duplicates, and this just overwrites it.
-                propsByName[p.Name] = p;
-            }
+            // var propsByName = new Dictionary<string, PropertyInfo>();
+            // foreach (var p in props) {
+            //     // There may be duplicates, and this just overwrites it.
+            //     propsByName[p.Name] = p;
+            // }
+
+            var propsByName = props.ToDictionary(prop => prop.Name);
 
             foreach (var (name, param) in named_parameters(false).ToList()) {
-                using var grad = param.grad;
-
-                if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) &&
-                    (grad is null || !grad.toWillCopy(dtype ?? param.dtype, device ?? param.device)))
-                    continue;
-
-                Parameter p;
-                ScalarType paramType =
-                    dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
-
-                // When moving the parameter, we don't want the autograd to track this movement on the graph.
-                // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
-                // disable grad we would need to call .detach() on the moved tensor.
-                using (var d = torch.no_grad()) {
-                    p = new Parameter(
-                        data: param.to(paramType, device ?? param.device),
-                        requires_grad: param.requires_grad);
-                    _ = p.DetachFromDisposeScope();
-
-                    // Copy the gradient over as well, if it exists
-                    if (grad is not null) {
-                        using var newGrad = grad.to(paramType, device ?? param.device)
-                            .with_requires_grad(grad.requires_grad);
-                        p.grad = newGrad;
-                    }
-                }
+
+                if (!ReplaceParameter(dtype, device, param, out var p)) continue;
 
                 if (propsByName.TryGetValue(name, out var property)) {
                     property.SetValue(this, p);
-                }
-                else {
+                } else {
                     param?.Dispose();
 
                     ConditionallyRegisterParameter(name, p);
@@ -304,17 +281,11 @@ protected virtual void _toEpilog(ScalarType? dtype, Device? device, bool non_blo
 
             foreach (var (name, buffer) in named_buffers(false).ToList()) {
 
-                if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+                if (!ReplaceBuffer(dtype, device, buffer, out var t)) continue;
 
-                ScalarType bufferType =
-                    dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
-
-                // Buffers don't get grads so we don't need to detach them afterwards
-                var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true).DetachFromDisposeScope();
                 if (propsByName.TryGetValue(name, out var property)) {
                     property.SetValue(this, t);
-                }
-                else {
+                } else {
                     ConditionallyRegisterBuffer(name, t);
                     if (fieldsByComponentName.TryGetValue(name, out var field))
                         field.SetValue(this, t);
@@ -327,6 +298,51 @@ protected virtual void _toEpilog(ScalarType? dtype, Device? device, bool non_blo
             }
         }
 
+        protected static bool ReplaceBuffer(ScalarType? dtype, Device? device, Tensor buffer, out Tensor? result)
+        {
+            result = null;
+
+            if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) return false;
+
+            ScalarType bufferType =
+                dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
+
+            // Buffers don't get grads so we don't need to detach them afterwards
+            result = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true).DetachFromDisposeScope();
+            return true;
+        }
+
+        protected static bool ReplaceParameter(ScalarType? dtype, Device? device, Parameter param, out Parameter? p)
+        {
+            Tensor? grad = param.grad;
+            p = null;
+
+            if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) &&
+                (grad is null || !grad.toWillCopy(dtype ?? param.dtype, device ?? param.device)))
+                return false;
+
+            ScalarType paramType =
+                dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
+            // When moving the parameter, we don't want the autograd to track this movement on the graph.
+            // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+            // disable grad we would need to call .detach() on the moved tensor.
+            using (var d = torch.no_grad()) {
+                p = new Parameter(
+                    data: param.to(paramType, device ?? param.device),
+                    requires_grad: param.requires_grad);
+                _ = p.DetachFromDisposeScope();
+
+                // Copy the gradient over as well, if it exists
+                if (grad is not null) {
+                    using var newGrad = grad.to(paramType, device ?? param.device)
+                        .with_requires_grad(grad.requires_grad);
+                    p.grad = newGrad;
+                }
+            }
+            return true;
+        }
+
         private static IEnumerable<FieldInfo> GetFieldsRecursive(Type type, BindingFlags bindingFlags) {
 
             Type? currentType = type;
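One detail worth noting from ReplaceParameter above: an existing gradient is converted and carried over along with the parameter itself, inside a no_grad scope so the move is not tracked by autograd. A hedged illustration, not from the commit, assuming a backward pass has already populated the gradient:

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    var lin = nn.Linear(4, 2);
    lin.forward(torch.rand(8, 4)).sum().backward();   // populate lin.weight.grad

    lin.to(ScalarType.Float64);

    // Both the parameter and its gradient were converted by ReplaceParameter.
    Console.WriteLine(lin.weight.dtype);        // Float64
    Console.WriteLine(lin.weight.grad?.dtype);  // Float64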

src/TorchSharp/NN/Normalization/GroupNorm.cs

Lines changed: 34 additions & 1 deletion
@@ -33,7 +33,7 @@ internal GroupNorm(long num_groups, long num_channels, double eps, bool affine,
 
         public override Tensor forward(Tensor tensor)
         {
-            if (tensor.Dimensions < 3)
+            if (tensor.Dimensions < 3)
                 throw new ArgumentException($"Invalid number of dimensions for GroupNorm argument: {tensor.Dimensions}");
             return F.group_norm(tensor, num_groups, weight, bias, eps);
         }
@@ -66,6 +66,39 @@ public Parameter weight {
             }
         }
 
+        // Rather than spending cycles discovering what parameters exist, we can just hardcode it.
+        protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
+        {
+            var device = new Device(deviceType, deviceIndex);
+            if (_weight is not null && ReplaceParameter(_weight.dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(_bias.dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, _weight.device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, _bias.device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
         [ComponentName(Name = nameof(bias))]
         private Parameter? _bias;
         [ComponentName(Name = nameof(weight))]

src/TorchSharp/NN/Normalization/LayerNorm.cs

Lines changed: 33 additions & 0 deletions
@@ -84,6 +84,39 @@ public Parameter weight {
             }
         }
 
+        // Rather than spending cycles discovering what parameters exist, we can just hardcode it.
+        protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking)
+        {
+            var device = new Device(deviceType, deviceIndex);
+            if (_weight is not null && ReplaceParameter(_weight.dtype, device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(_bias.dtype, device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
+        protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) {
+            if (_weight is not null && ReplaceParameter(dtype, _weight.device, _weight, out Parameter? w)) {
+                weight = w!;
+            }
+            if (_bias is not null && ReplaceParameter(dtype, _bias.device, _bias, out Parameter? b)) {
+                bias = b!;
+            }
+            return this;
+        }
+
         [ComponentName(Name = BiasComponentName)]
         private Parameter? _bias;
         [ComponentName(Name = WeightComponentName)]
