Skip to content

Commit fedb729

Browse files
Simplify initializing 'device' when it's passed as 'null' to tensor factories.
1 parent 17723e6 commit fedb729

File tree

10 files changed

+6
-80
lines changed

10 files changed

+6
-80
lines changed

src/TorchSharp/FFT.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,6 @@ public static Tensor ifftshift(Tensor input, long[] dim = null)
350350
/// <param name="requires_grad">If autograd should record operations on the returned tensor.</param>
351351
public static Tensor fftfreq(long n, double d = 1.0, torch.ScalarType? dtype = null, torch.Device device = null, bool requires_grad = false)
352352
{
353-
if (device is null)
354-
{
355-
device = get_default_device();
356-
}
357353
device = torch.InitializeDevice(device);
358354
if (!dtype.HasValue) {
359355
// Determine the element type dynamically.
@@ -380,10 +376,6 @@ public static Tensor fftfreq(long n, double d = 1.0, torch.ScalarType? dtype = n
380376
/// <param name="requires_grad">If autograd should record operations on the returned tensor.</param>
381377
public static Tensor rfftfreq(long n, double d = 1.0, torch.ScalarType? dtype = null, torch.Device device = null, bool requires_grad = false)
382378
{
383-
if (device is null)
384-
{
385-
device = get_default_device();
386-
}
387379
device = torch.InitializeDevice(device);
388380
if (!dtype.HasValue) {
389381
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ public static partial class torch
2828
/// <param name="requires_grad"> If autograd should record operations on the returned tensor. Default: false.</param>
2929
public static Tensor arange(Scalar start, Scalar stop, Scalar step, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
3030
{
31-
if (device is null)
32-
{
33-
device = get_default_device();
34-
}
3531
device = InitializeDevice(device);
3632

3733
if (!dtype.HasValue) {
@@ -81,10 +77,6 @@ public static Tensor arange(Scalar stop, ScalarType? dtype = null, Device? devic
8177
/// </summary>
8278
public static Tensor eye(long rows, long columns = -1L, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
8379
{
84-
if (device is null)
85-
{
86-
device = get_default_device();
87-
}
8880
device = InitializeDevice(device);
8981
if (!dtype.HasValue) {
9082
// Determine the element type dynamically.
@@ -448,10 +440,6 @@ public static Tensor polar(Tensor abs, Tensor angle)
448440

449441
public static Tensor from_file(string filename, bool? shared = null, long? size = 0, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
450442
{
451-
if (device is null)
452-
{
453-
device = get_default_device();
454-
}
455443
device = InitializeDevice(device);
456444
if (!dtype.HasValue) {
457445
// Determine the element type dynamically.
@@ -468,10 +456,6 @@ public static Tensor from_file(string filename, bool? shared = null, long? size
468456
/// </summary>
469457
public static Tensor linspace(double start, double end, long steps, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
470458
{
471-
if (device is null)
472-
{
473-
device = get_default_device();
474-
}
475459
device = InitializeDevice(device);
476460
if (!dtype.HasValue) {
477461
// Determine the element type dynamically.
@@ -493,10 +477,6 @@ public static Tensor linspace(double start, double end, long steps, ScalarType?
493477
/// </summary>
494478
public static Tensor logspace(double start, double end, long steps, double @base = 10, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
495479
{
496-
if (device is null)
497-
{
498-
device = get_default_device();
499-
}
500480
device = InitializeDevice(device);
501481
if (!dtype.HasValue) {
502482
// Determine the element type dynamically.
@@ -599,7 +579,7 @@ private static void ValidateIntegerRange(long value, ScalarType dtype, string ar
599579

600580
private static ConcurrentDictionary<TorchSharp.PInvoke.GCHandleDeleter, TorchSharp.PInvoke.GCHandleDeleter> deleters;
601581
private static ScalarType default_dtype = ScalarType.Float32;
602-
private static Device default_device = CPU;
582+
private static Device default_device = new Device(DeviceType.CPU, -1);
603583

604584
static torch()
605585
{

src/TorchSharp/Tensor/Factories/empty.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,6 @@ public static Tensor empty(int dim0, int dim1, int dim2, int dim3, ScalarType? d
9898
/// </summary>
9999
public static Tensor empty_strided(long[] size, long[] strides, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
100100
{
101-
if (device is null)
102-
{
103-
device = get_default_device();
104-
}
105101
device = InitializeDevice(device);
106102
if (!dtype.HasValue) {
107103
// Determine the element type dynamically.
@@ -134,10 +130,6 @@ public static Tensor empty_strided(long[] size, long[] strides, ScalarType? dtyp
134130
/// </summary>
135131
private static Tensor _empty(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
136132
{
137-
if (device is null)
138-
{
139-
device = get_default_device();
140-
}
141133
device = InitializeDevice(device);
142134
if (!dtype.HasValue) {
143135
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/full.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,6 @@ public static Tensor full(int dim0, int dim1, int dim2, int dim3, Scalar value,
9797
/// </summary>
9898
private static Tensor _full(ReadOnlySpan<long> size, Scalar value, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9999
{
100-
if (device is null)
101-
{
102-
device = get_default_device();
103-
}
104100
device = InitializeDevice(device);
105101
if (!dtype.HasValue) {
106102
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/ones.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,6 @@ public static Tensor ones(int dim0, int dim1, int dim2, int dim3, ScalarType? dt
9797
/// </summary>
9898
private static Tensor _ones(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9999
{
100-
if (device is null)
101-
{
102-
device = get_default_device();
103-
}
104100
device = InitializeDevice(device);
105101
if (!dtype.HasValue) {
106102
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/rand.cs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -348,10 +348,6 @@ private static Tensor _rand(ReadOnlySpan<long> size, ScalarType? dtype = null, D
348348
if (dtype.HasValue && torch.is_integral(dtype.Value))
349349
throw new ArgumentException($"torch.rand() was passed a bad dtype: {dtype}. It must be floating point or complex.", "dtype");
350350

351-
if (device is null)
352-
{
353-
device = get_default_device();
354-
}
355351
device = InitializeDevice(device);
356352
if (!dtype.HasValue) {
357353
// Determine the element type dynamically.
@@ -480,10 +476,6 @@ private static Tensor _randn(ReadOnlySpan<long> size, ScalarType? dtype = null,
480476
if (dtype.HasValue && torch.is_integral(dtype.Value))
481477
throw new ArgumentException($"torch.randn() was passed a bad dtype: {dtype}. It must be floating point or complex.", "dtype");
482478

483-
if (device is null)
484-
{
485-
device = get_default_device();
486-
}
487479
device = InitializeDevice(device);
488480
if (!dtype.HasValue) {
489481
// Determine the element type dynamically.

src/TorchSharp/Tensor/Factories/zeros.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,6 @@ public static Tensor zeros(int dim0, int dim1, int dim2, int dim3, ScalarType? d
9494

9595
private static Tensor _zeros(ReadOnlySpan<long> size, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
9696
{
97-
if (device is null)
98-
{
99-
device = get_default_device();
100-
}
10197
device = InitializeDevice(device);
10298
if (!dtype.HasValue) {
10399
// Determine the element type dynamically.

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ public Tensor to(DeviceType deviceType, int deviceIndex = -1, bool copy = false,
880880
/// <param name="non_blocking">Try to convert asynchronously with respect to the host if possible, e.g., converting a CPU Tensor with pinned memory to a CUDA Tensor.</param>
881881
public Tensor to(ScalarType type, torch.Device device, bool copy = false, bool disposeAfter = false, bool non_blocking = false)
882882
{
883-
torch.InitializeDevice(device);
883+
device = torch.InitializeDevice(device);
884884
var res = NativeMethods.THSTensor_to_type_and_device(Handle, (sbyte)type, (int)device.type, device.index, copy, non_blocking);
885885
if (res == IntPtr.Zero)
886886
CheckForErrors();

src/TorchSharp/Tensor/torch.SpectralOps.cs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ public static Tensor istft(Tensor input, long n_fft, long hop_length = -1, long
5151
/// </summary>
5252
public static Tensor bartlett_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
5353
{
54-
if (device is null)
55-
{
56-
device = get_default_device();
57-
}
5854
device = InitializeDevice(device);
5955
if (!dtype.HasValue) {
6056
// Determine the element type dynamically.
@@ -78,10 +74,6 @@ public static Tensor bartlett_window(long len, bool periodic = true, ScalarType?
7874
/// </summary>
7975
public static Tensor blackman_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
8076
{
81-
if (device is null)
82-
{
83-
device = get_default_device();
84-
}
8577
device = InitializeDevice(device);
8678
if (!dtype.HasValue) {
8779
// Determine the element type dynamically.
@@ -106,10 +98,6 @@ public static Tensor blackman_window(long len, bool periodic = true, ScalarType?
10698
/// </summary>
10799
public static Tensor hamming_window(long len, bool periodic = true, float alpha = 0.54f, float beta = 0.46f, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
108100
{
109-
if (device is null)
110-
{
111-
device = get_default_device();
112-
}
113101
device = InitializeDevice(device);
114102
if (!dtype.HasValue) {
115103
// Determine the element type dynamically.
@@ -133,10 +121,6 @@ public static Tensor hamming_window(long len, bool periodic = true, float alpha
133121
/// </summary>
134122
public static Tensor hann_window(long len, bool periodic = true, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
135123
{
136-
if (device is null)
137-
{
138-
device = get_default_device();
139-
}
140124
device = InitializeDevice(device);
141125
if (!dtype.HasValue) {
142126
// Determine the element type dynamically.
@@ -160,10 +144,6 @@ public static Tensor hann_window(long len, bool periodic = true, ScalarType? dty
160144
/// </summary>
161145
public static Tensor kaiser_window(long len, bool periodic = true, float beta = 12.0f, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
162146
{
163-
if (device is null)
164-
{
165-
device = get_default_device();
166-
}
167147
device = InitializeDevice(device);
168148
if (!dtype.HasValue) {
169149
// Determine the element type dynamically.

src/TorchSharp/Torch.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,10 @@ public static void InitializeDeviceType(DeviceType deviceType)
298298

299299
public static Device InitializeDevice(Device? device)
300300
{
301-
if (device == null)
302-
device = new Device(DeviceType.CPU, -1);
301+
if (device is null)
302+
{
303+
device = get_default_device();
304+
}
303305
InitializeDeviceType(device.type);
304306
return device;
305307
}

0 commit comments

Comments
 (0)