Skip to content

Commit ceb9e62

Browse files
committed
Reverse complex number delegation order to prevent extra heap allocation
1 parent 09e0def commit ceb9e62

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

src/TorchSharp/Tensor/Factories/tensor_double.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,19 @@ public static Tensor tensor(double scalar, ScalarType? dtype = null, Device? dev
2626
/// </summary>
2727
public static Tensor tensor((double Real, double Imaginary) scalar, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
2828
{
29-
device = InitializeDevice(device);
30-
var handle = THSTensor_newComplexFloat64Scalar(scalar.Real, scalar.Imaginary, (int)device.type, device.index, requires_grad);
31-
if (handle == IntPtr.Zero) { CheckForErrors(); }
32-
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
29+
return tensor(scalar.Real, scalar.Imaginary, dtype, device, requires_grad);
3330
}
3431

3532
/// <summary>
3633
/// Create a scalar complex number tensor from independent real and imaginary components
3734
/// </summary>
3835
public static Tensor tensor(double real, double imaginary, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
3936
{
40-
return tensor((real, imaginary), dtype, device, requires_grad);
37+
38+
device = InitializeDevice(device);
39+
var handle = THSTensor_newComplexFloat64Scalar(real, imaginary, (int)device.type, device.index, requires_grad);
40+
if (handle == IntPtr.Zero) { CheckForErrors(); }
41+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
4142
}
4243

4344
/// <summary>

src/TorchSharp/Tensor/Factories/tensor_float.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,18 @@ public static Tensor tensor(float scalar, Device? device = null, bool requires_g
2626
/// </summary>
2727
public static Tensor tensor(float real, float imaginary, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
2828
{
29-
return tensor((real, imaginary), dtype: dtype, device: device);
29+
device = InitializeDevice(device);
30+
var handle = THSTensor_newComplexFloat32Scalar(real, imaginary, (int)device.type, device.index, requires_grad);
31+
if (handle == IntPtr.Zero) { CheckForErrors(); }
32+
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
3033
}
3134

3235
/// <summary>
3336
/// Create a scalar complex number tensor from a tuple of (real, imaginary)
3437
/// </summary>
3538
public static Tensor tensor((float Real, float Imaginary) scalar, ScalarType? dtype = null, Device? device = null, bool requires_grad = false)
3639
{
37-
device = InitializeDevice(device);
38-
var handle = THSTensor_newComplexFloat32Scalar(scalar.Real, scalar.Imaginary, (int)device.type, device.index, requires_grad);
39-
if (handle == IntPtr.Zero) { CheckForErrors(); }
40-
return InstantiateTensorWithLeakSafeTypeChange(handle, dtype);
40+
return tensor(scalar.Real, scalar.Imaginary, dtype: dtype, device: device);
4141
}
4242

4343
/// <summary>

0 commit comments

Comments
 (0)