Skip to content

Commit 6bdf60b

Browse files
Merge pull request #1319 from NiklasGustafsson/missing
Adding 'set/get_default_device()'
2 parents b6f55da + fedb729 commit 6bdf60b

File tree

7 files changed

+87
-5
lines changed

7 files changed

+87
-5
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ The constructor of dispose scopes is no longer `public`. Use `torch.NewDisposeSc
1212

1313
__API Changes__:
1414

15+
#1317 How to set default device type in torchsharp.<br/>
1516
#1314 Grant read-only access to DataLoader attributes<br/>
1617
#1313 Add 'non_blocking' argument to tensor and module 'to()' signatures.<br/>
1718
#1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.<br/>

src/TorchSharp/NN/Module.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,10 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic
10631063
if (!dtype.HasValue)
10641064
dtype = get_default_dtype();
10651065

1066-
if (device == null)
1067-
device = torch.CPU;
1066+
if (device is null)
1067+
{
1068+
device = get_default_device();
1069+
}
10681070

10691071
return (device, dtype.Value);
10701072
}

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ private static void ValidateIntegerRange(long value, ScalarType dtype, string ar
579579

580580
private static ConcurrentDictionary<TorchSharp.PInvoke.GCHandleDeleter, TorchSharp.PInvoke.GCHandleDeleter> deleters;
581581
private static ScalarType default_dtype = ScalarType.Float32;
582+
private static Device default_device = new Device(DeviceType.CPU, -1);
582583

583584
static torch()
584585
{

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ public Tensor to(DeviceType deviceType, int deviceIndex = -1, bool copy = false,
880880
/// <param name="non_blocking">Try to convert asynchronously with respect to the host if possible, e.g., converting a CPU Tensor with pinned memory to a CUDA Tensor.</param>
881881
public Tensor to(ScalarType type, torch.Device device, bool copy = false, bool disposeAfter = false, bool non_blocking = false)
882882
{
883-
torch.InitializeDevice(device);
883+
device = torch.InitializeDevice(device);
884884
var res = NativeMethods.THSTensor_to_type_and_device(Handle, (sbyte)type, (int)device.type, device.index, copy, non_blocking);
885885
if (res == IntPtr.Zero)
886886
CheckForErrors();

src/TorchSharp/Tensor/torch.Tensors.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,23 @@ public static void set_default_dtype(ScalarType dtype)
6161
/// <returns></returns>
6262
[Pure]public static ScalarType get_default_dtype() => default_dtype;
6363

64+
[Pure]public static Device get_default_device() => default_device;
65+
6466
// https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type
6567
public static void set_default_tensor_type(Tensor t) => set_default_dtype(t.dtype);
6668

69+
public static void set_default_device(Device device)
70+
{
71+
if (device == null)
72+
throw new ArgumentNullException(nameof(device));
73+
default_device = device;
74+
}
75+
76+
public static void set_default_device(string device)
77+
{
78+
set_default_device(new Device(device));
79+
}
80+
6781
// https://pytorch.org/docs/stable/generated/torch.numel
6882
/// <summary>
6983
/// Get the number of elements in the input tensor.

src/TorchSharp/Torch.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,10 @@ public static void InitializeDeviceType(DeviceType deviceType)
298298

299299
public static Device InitializeDevice(Device? device)
300300
{
301-
if (device == null)
302-
device = new Device(DeviceType.CPU, -1);
301+
if (device is null)
302+
{
303+
device = get_default_device();
304+
}
303305
InitializeDeviceType(device.type);
304306
return device;
305307
}

test/TorchSharpTest/TestTorchTensor.cs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8375,5 +8375,67 @@ public void TestSaveAndLoadModuleWithLarger2GBTensorCUDA()
83758375
}
83768376
}
83778377
}
8378+
8379+
[Fact]
8380+
public void DefaultDTypeCreation()
8381+
{
8382+
var dt = torch.get_default_dtype();
8383+
8384+
var t = torch.zeros(5,5);
8385+
Assert.Equal(torch.float32, t.dtype);
8386+
8387+
try {
8388+
torch.set_default_dtype(torch.float64);
8389+
8390+
t = torch.zeros(5,5);
8391+
Assert.Equal(torch.float64, t.dtype);
8392+
8393+
t = torch.ones(5,5);
8394+
Assert.Equal(torch.float64, t.dtype);
8395+
8396+
t = torch.rand(5,5);
8397+
Assert.Equal(torch.float64, t.dtype);
8398+
8399+
t = torch.randn(5,5);
8400+
Assert.Equal(torch.float64, t.dtype);
8401+
8402+
t = torch.logspace(5, 15, 20);
8403+
Assert.Equal(torch.float64, t.dtype);
8404+
}
8405+
finally {
8406+
torch.set_default_dtype(dt);
8407+
}
8408+
}
8409+
8410+
[Fact]
8411+
public void DefaultDeviceCreation()
8412+
{
8413+
var dt = torch.get_default_device();
8414+
8415+
var t = torch.zeros(5,5);
8416+
Assert.Equal(DeviceType.CPU, t.device_type);
8417+
8418+
try {
8419+
torch.set_default_device(torch.META);
8420+
8421+
t = torch.zeros(5,5);
8422+
Assert.Equal(DeviceType.META, t.device_type);
8423+
8424+
t = torch.ones(5,5);
8425+
Assert.Equal(DeviceType.META, t.device_type);
8426+
8427+
t = torch.rand(5,5);
8428+
Assert.Equal(DeviceType.META, t.device_type);
8429+
8430+
t = torch.randn(5,5);
8431+
Assert.Equal(DeviceType.META, t.device_type);
8432+
8433+
t = torch.logspace(5, 15, 20);
8434+
Assert.Equal(DeviceType.META, t.device_type);
8435+
}
8436+
finally {
8437+
torch.set_default_device(dt);
8438+
}
8439+
}
83788440
}
83798441
}

0 commit comments

Comments
 (0)