Skip to content

Commit 6dfe886

Browse files
author
Niklas Gustafsson
committed
Fixed logic to keep track of which device ParameterDict and ParameterList modules are held on.
1 parent a046a66 commit 6dfe886

File tree

5 files changed

+38
-5
lines changed

5 files changed

+38
-5
lines changed

src/TorchSharp/JIT/ScriptModule.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,6 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
190190
return this;
191191
}
192192

193-
private DeviceType _deviceType = DeviceType.CPU;
194-
private int _deviceIndex = -1;
195-
196193
/// <summary>
197194
/// Convert the parameters and buffers.
198195
/// </summary>

src/TorchSharp/NN/Module.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,8 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex, bo
203203
return this;
204204
}
205205

206-
private DeviceType _deviceType = DeviceType.CPU;
207-
private int _deviceIndex = -1;
206+
protected DeviceType _deviceType = DeviceType.CPU;
207+
protected int _deviceIndex = -1;
208208

209209
/// <summary>
210210
/// Convert the parameters and buffers.

src/TorchSharp/NN/ParameterDict.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ protected override void _toEpilog(torch.ScalarType? dtype, torch.Device device,
9898
_list[i] = (name, p);
9999
_dict[name] = p;
100100
}
101+
102+
if (device is not null) {
103+
_deviceType = device.type;
104+
_deviceIndex = device.index;
105+
}
101106
}
102107

103108
/// <summary>

src/TorchSharp/NN/ParameterList.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ protected override void _toEpilog(torch.ScalarType? dtype, torch.Device? device,
9090

9191
_list[i] = p;
9292
}
93+
94+
if (device is not null) {
95+
_deviceType = device.type;
96+
_deviceIndex = device.index;
97+
}
9398
}
9499

95100
private bool _registered = false;

test/TorchSharpTest/NN.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3170,6 +3170,32 @@ public void TestCustomModuleWithDeviceMove()
31703170
// Reset and then try again with moving back to CPU
31713171
module.zero_grad();
31723172

3173+
// Try moving back to CPU
3174+
module.to(torch.CPU);
3175+
x = torch.randn(2, 2);
3176+
y = torch.randn(2);
3177+
torch.nn.functional.mse_loss(module.call(x), y).backward();
3178+
foreach (var (pName, parm) in module.named_parameters()) {
3179+
var grad = parm.grad;
3180+
Assert.NotNull(grad);
3181+
}
3182+
}
3183+
if (torch.mps_is_available()) {
3184+
var module = new TestModule1(torch.randn(2, 2), true);
3185+
3186+
// Move the device to MPS, and make sure gradients are calculated for all the parameters
3187+
module.to(torch.MPS);
3188+
var x = torch.randn(2, 2, device: torch.MPS);
3189+
var y = torch.randn(2, device: torch.MPS);
3190+
torch.nn.functional.mse_loss(module.call(x), y).backward();
3191+
foreach (var (pName, parm) in module.named_parameters()) {
3192+
var grad = parm.grad;
3193+
Assert.NotNull(grad);
3194+
}
3195+
3196+
// Reset and then try again with moving back to CPU
3197+
module.zero_grad();
3198+
31733199
// Try moving back to CPU
31743200
module.to(torch.CPU);
31753201
x = torch.randn(2, 2);

0 commit comments

Comments
 (0)