Skip to content

Commit e5b1e82

Browse files
Manual merge.
1 parent 738259e commit e5b1e82

File tree

3 files changed

+14
-19
lines changed

3 files changed

+14
-19
lines changed

src/TorchSharp/NN/ParamLessModule.cs

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
2424

2525
// Rather than spending cycles only to discover that this module has neither
2626
// parameters nor buffers, just shortcut the move completely.
27-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
28-
29-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
30-
31-
protected internal override nn.Module _to(ScalarType dtype) => this;
27+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
28+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
29+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
3230

3331
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
3432
{
@@ -60,11 +58,10 @@ protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxe
6058

6159
// Rather than spending cycles only to discover that this module has neither
6260
// parameters nor buffers, just shortcut the move completely.
63-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
64-
65-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
61+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
62+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
63+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
6664

67-
protected internal override nn.Module _to(ScalarType dtype) => this;
6865

6966
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
7067
{
@@ -94,13 +91,11 @@ protected ParamLessModule(string name) : base(name) { }
9491

9592
protected ParamLessModule(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) {}
9693

97-
// Rather than spending cycles only to discover that this module has neither
94+
// Rather than spending cycles only to discover that this module has neither
9895
// parameters nor buffers, just shortcut the move completely.
99-
protected internal override nn.Module _to(Device device, ScalarType dtype) => this;
100-
101-
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) => this;
102-
103-
protected internal override nn.Module _to(ScalarType dtype) => this;
96+
protected internal override nn.Module _to(Device device, ScalarType dtype, bool non_blocking) => this;
97+
protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex, bool non_blocking) => this;
98+
protected internal override nn.Module _to(ScalarType dtype, bool non_blocking) => this;
10499

105100
public override void register_buffer(string name, Tensor tensor, bool persistent = true)
106101
{

src/TorchSharp/NN/Parameter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public Parameter(Tensor data, bool requires_grad = true) :
2929
var scope = data.OwningDisposeScope;
3030
if (scope is not null) {
3131
this.OwningDisposeScope = scope;
32-
scope.Include(this);
32+
scope.Attach(this);
3333
scope.Detach(data);
3434
}
3535
}

src/TorchSharp/NN/Upsample.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ public sealed class Upsample : ParamLessModule<Tensor, Tensor>
1717
{
1818
internal Upsample(long[]? size, double[]? scale_factor, UpsampleMode mode, bool? align_corners, bool? recompute_scale_factor) : base(nameof(Upsample))
1919
{
20-
this.size = size;
21-
this.scale_factor = scale_factor;
20+
this._size = size;
21+
this._scale_factor = scale_factor;
2222
this.mode = mode;
2323
this.align_corners = align_corners;
2424
this.recompute_scale_factor = recompute_scale_factor;
@@ -31,7 +31,7 @@ internal Upsample(long[]? size, double[]? scale_factor, UpsampleMode mode, bool?
3131
/// <returns></returns>
3232
public override Tensor forward(Tensor input)
3333
{
34-
return torch.nn.functional.interpolate(input, size, scale_factor, (InterpolationMode)mode, align_corners, recompute_scale_factor ?? false);
34+
return torch.nn.functional.interpolate(input, _size, _scale_factor, (InterpolationMode)mode, align_corners, recompute_scale_factor ?? false);
3535
}
3636

3737
public bool? recompute_scale_factor { get; set; }

0 commit comments

Comments
 (0)