Skip to content

Commit d4ff8f9

Browse files
Adding 'set/get_default_device()'
1 parent b6f55da commit d4ff8f9

File tree

11 files changed

+157
-2
lines changed

11 files changed

+157
-2
lines changed

src/TorchSharp/FFT.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ public static Tensor ifftshift(Tensor input, long[] dim = null)
350350
/// <param name="requires_grad">If autograd should record operations on the returned tensor.</param>
351351
public static Tensor fftfreq(long n, double d = 1.0, torch.ScalarType? dtype = null, torch.Device device = null, bool requires_grad = false)
352352
{
353+
if (device is null)
354+
{
355+
device = get_default_device();
356+
}
353357
device = torch.InitializeDevice(device);
354358
if (!dtype.HasValue) {
355359
// Determine the element type dynamically.
@@ -376,6 +380,10 @@ public static Tensor fftfreq(long n, double d = 1.0, torch.ScalarType? dtype = n
376380
/// <param name="requires_grad">If autograd should record operations on the returned tensor.</param>
377381
public static Tensor rfftfreq(long n, double d = 1.0, torch.ScalarType? dtype = null, torch.Device device = null, bool requires_grad = false)
378382
{
383+
if (device is null)
384+
{
385+
device = get_default_device();
386+
}
379387
device = torch.InitializeDevice(device);
380388
if (!dtype.HasValue) {
381389
// Determine the element type dynamically.

src/TorchSharp/NN/Module.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,10 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic
10631063
if (!dtype.HasValue)
10641064
dtype = get_default_dtype();
10651065

1066-
if (device == null)
1067-
device = torch.CPU;
1066+
if (device is null)
1067+
{
1068+
device = get_default_device();
1069+
}
10681070

10691071
return (device, dtype.Value);
10701072
}

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ public static partial class torch
2828
/// <param name="requires_grad"> If autograd should record operations on the returned tensor. Default: false.</param>
2929
public static Tensor arange(Scalar start, Scalar stop, Scalar step, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
3030
{
31+
if (device is null)
32+
{
33+
device = get_default_device();
34+
}
3135
device = InitializeDevice(device);
3236

3337
if (!dtype.HasValue) {
@@ -77,6 +81,10 @@ public static Tensor arange(Scalar stop, ScalarType? dtype = null, Device? devic
7781
/// </summary>
7882
public static Tensor eye(long rows, long columns = -1L, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
7983
{
84+
if (device is null)
85+
{
86+
device = get_default_device();
87+
}
8088
device = InitializeDevice(device);
8189
if (!dtype.HasValue) {
8290
// Determine the element type dynamically.
@@ -440,6 +448,10 @@ public static Tensor polar(Tensor abs, Tensor angle)
440448

441449
public static Tensor from_file(string filename, bool? shared = null, long? size = 0, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
442450
{
451+
if (device is null)
452+
{
453+
device = get_default_device();
454+
}
443455
device = InitializeDevice(device);
444456
if (!dtype.HasValue) {
445457
// Determine the element type dynamically.
@@ -456,6 +468,10 @@ public static Tensor from_file(string filename, bool? shared = null, long? size
456468
/// </summary>
457469
public static Tensor linspace(double start, double end, long steps, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
458470
{
471+
if (device is null)
472+
{
473+
device = get_default_device();
474+
}
459475
device = InitializeDevice(device);
460476
if (!dtype.HasValue) {
461477
// Determine the element type dynamically.
@@ -477,6 +493,10 @@ public static Tensor linspace(double start, double end, long steps, ScalarType?
477493
/// </summary>
478494
public static Tensor logspace(double start, double end, long steps, double @base = 10, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
479495
{
496+
if (device is null)
497+
{
498+
device = get_default_device();
499+
}
480500
device = InitializeDevice(device);
481501
if (!dtype.HasValue) {
482502
// Determine the element type dynamically.
@@ -579,6 +599,7 @@ private static void ValidateIntegerRange(long value, ScalarType dtype, string ar
579599

580600
private static ConcurrentDictionary<TorchSharp.PInvoke.GCHandleDeleter, TorchSharp.PInvoke.GCHandleDeleter> deleters;
581601
private static ScalarType default_dtype = ScalarType.Float32;
602+
private static Device default_device = CPU;
582603

583604
static torch()
584605
{

src/TorchSharp/Tensor/Factories/empty.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ public static Tensor empty(int dim0, int dim1, int dim2, int dim3, ScalarType? d
9898
/// </summary>
9999
public static Tensor empty_strided(long[] size, long[] strides, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
100100
{
101+
if (device is null)
102+
{
103+
device = get_default_device();
104+
}
101105
device = InitializeDevice(device);
102106
if (!dtype.HasValue) {
103107
// Determine the element type dynamically.
@@ -130,6 +134,10 @@ public static Tensor empty_strided(long[] size, long[] strides, ScalarType? dtyp
130134
/// </summary>
131135
private static Tensor _empty(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
132136
{
137+
if (device is null)
138+
{
139+
device = get_default_device();
140+
}
133141
device = InitializeDevice(device);
134142
if (!dtype.HasValue) {
135143
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/full.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ public static Tensor full(int dim0, int dim1, int dim2, int dim3, Scalar value,
9797
/// </summary>
9898
private static Tensor _full(ReadOnlySpan<long> size, Scalar value, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9999
{
100+
if (device is null)
101+
{
102+
device = get_default_device();
103+
}
100104
device = InitializeDevice(device);
101105
if (!dtype.HasValue) {
102106
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/ones.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ public static Tensor ones(int dim0, int dim1, int dim2, int dim3, ScalarType? dt
9797
/// </summary>
9898
private static Tensor _ones(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9999
{
100+
if (device is null)
101+
{
102+
device = get_default_device();
103+
}
100104
device = InitializeDevice(device);
101105
if (!dtype.HasValue) {
102106
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/rand.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ private static Tensor _rand(ReadOnlySpan<long> size, ScalarType? dtype = null, D
348348
if (dtype.HasValue && torch.is_integral(dtype.Value))
349349
throw new ArgumentException($"torch.rand() was passed a bad dtype: {dtype}. It must be floating point or complex.", "dtype");
350350

351+
if (device is null)
352+
{
353+
device = get_default_device();
354+
}
351355
device = InitializeDevice(device);
352356
if (!dtype.HasValue) {
353357
// Determine the element type dynamically.
@@ -476,6 +480,10 @@ private static Tensor _randn(ReadOnlySpan<long> size, ScalarType? dtype = null,
476480
if (dtype.HasValue && torch.is_integral(dtype.Value))
477481
throw new ArgumentException($"torch.randn() was passed a bad dtype: {dtype}. It must be floating point or complex.", "dtype");
478482

483+
if (device is null)
484+
{
485+
device = get_default_device();
486+
}
479487
device = InitializeDevice(device);
480488
if (!dtype.HasValue) {
481489
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/zeros.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ public static Tensor zeros(int dim0, int dim1, int dim2, int dim3, ScalarType? d
9494

9595
private static Tensor _zeros(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9696
{
97+
if (device is null)
98+
{
99+
device = get_default_device();
100+
}
97101
device = InitializeDevice(device);
98102
if (!dtype.HasValue) {
99103
// Determine the element type dynamically.

src/TorchSharp/Tensor/torch.SpectralOps.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ public static Tensor istft(Tensor input, long n_fft, long hop_length = -1, long
5151
/// </summary>
5252
public static Tensor bartlett_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
5353
{
54+
if (device is null)
55+
{
56+
device = get_default_device();
57+
}
5458
device = InitializeDevice(device);
5559
if (!dtype.HasValue) {
5660
// Determine the element type dynamically.
@@ -74,6 +78,10 @@ public static Tensor bartlett_window(long len, bool periodic = true, ScalarType?
7478
/// </summary>
7579
public static Tensor blackman_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
7680
{
81+
if (device is null)
82+
{
83+
device = get_default_device();
84+
}
7785
device = InitializeDevice(device);
7886
if (!dtype.HasValue) {
7987
// Determine the element type dynamically.
@@ -98,6 +106,10 @@ public static Tensor blackman_window(long len, bool periodic = true, ScalarType?
98106
/// </summary>
99107
public static Tensor hamming_window(long len, bool periodic = true, float alpha = 0.54f, float beta = 0.46f, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
100108
{
109+
if (device is null)
110+
{
111+
device = get_default_device();
112+
}
101113
device = InitializeDevice(device);
102114
if (!dtype.HasValue) {
103115
// Determine the element type dynamically.
@@ -121,6 +133,10 @@ public static Tensor hamming_window(long len, bool periodic = true, float alpha
121133
/// </summary>
122134
public static Tensor hann_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
123135
{
136+
if (device is null)
137+
{
138+
device = get_default_device();
139+
}
124140
device = InitializeDevice(device);
125141
if (!dtype.HasValue) {
126142
// Determine the element type dynamically.
@@ -144,6 +160,10 @@ public static Tensor hann_window(long len, bool periodic = true, ScalarType? dty
144160
/// </summary>
145161
public static Tensor kaiser_window(long len, bool periodic = true, float beta = 12.0f, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
146162
{
163+
if (device is null)
164+
{
165+
device = get_default_device();
166+
}
147167
device = InitializeDevice(device);
148168
if (!dtype.HasValue) {
149169
// Determine the element type dynamically.

src/TorchSharp/Tensor/torch.Tensors.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,23 @@ public static void set_default_dtype(ScalarType dtype)
6161
/// <returns></returns>
6262
[Pure]public static ScalarType get_default_dtype() => default_dtype;
6363

64+
[Pure]public static Device get_default_device() => default_device;
65+
6466
// https://pytorch.org/docs/stable/generated/torch.set_default_tensor_type
6567
public static void set_default_tensor_type(Tensor t) => set_default_dtype(t.dtype);
6668

69+
public static void set_default_device(Device device)
70+
{
71+
if (device == null)
72+
throw new ArgumentNullException(nameof(device));
73+
default_device = device;
74+
}
75+
76+
public static void set_default_device(string device)
77+
{
78+
set_default_device(new Device(device));
79+
}
80+
6781
// https://pytorch.org/docs/stable/generated/torch.numel
6882
/// <summary>
6983
/// Get the number of elements in the input tensor.

0 commit comments

Comments
 (0)