Skip to content

Commit bd1f04e

Browse files
Merge pull request #1331 from NiklasGustafsson/missing
Support creating tensors from a Memory<T> instance.
2 parents 5109469 + ec6b818 commit bd1f04e

File tree

13 files changed

+193
-15
lines changed

13 files changed

+193
-15
lines changed

RELEASENOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the
66

77
__Breaking Changes__:
88

9+
When creating a tensor from a 1-D array, and passing in a shape, there is now an ambiguity between the IList and Memory overloads of `torch.tensor()`. The ambiguity is resolved by removing the `dimensions` argument if it is redundant, or by an explicit cast to IList if it is not.
10+
911
__API Changes__:
1012

13+
#1326 Allow arrays used to create tensors to be larger than the tensor. Create tensors from a Memory instance<br/>
14+
1115
__Bug Fixes__:
1216

1317

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public static Tensor normal(double mean, double std, ReadOnlySpan<long> size, Sc
134134
/// <returns>A constructed tensor with elements of `dtype`</returns>
135135
/// <exception cref="ArgumentException"></exception>
136136
private static Tensor _tensor_generic(Array rawArray, ReadOnlySpan<long> dimensions, sbyte origType, ScalarType? dtype, Device? device, bool requires_grad, bool clone = true, string[]? names = null)
137-
{
137+
{
138138
{
139139
// Validate the sizes before handing over storage to native code...
140140
var prod = 1L;
@@ -164,6 +164,68 @@ private static Tensor _tensor_generic(Array rawArray, ReadOnlySpan<long> dimensi
164164
dtype = dtype.HasValue ? dtype : (ScalarType)origType;
165165

166166
unsafe {
167+
void *ptr = null;
168+
IntPtr iPtr = (IntPtr)ptr;
169+
170+
fixed (long* shape = dimensions) {
171+
var handle = THSTensor_new(dataArrayAddr, deleter, (IntPtr)shape, dimensions.Length, origType, (sbyte)dtype.Value, (int)device.type, device.index, requires_grad);
172+
173+
if (handle == IntPtr.Zero) {
174+
GC.Collect();
175+
GC.WaitForPendingFinalizers();
176+
handle = THSTensor_new(dataArrayAddr, deleter, (IntPtr)shape, dimensions.Length, origType, (sbyte)dtype.Value, (int)device.type, device.index, requires_grad);
177+
}
178+
179+
if (handle == IntPtr.Zero) { CheckForErrors(); }
180+
var tensor = new Tensor(handle);
181+
182+
if (names != null && names.Length > 0) {
183+
tensor.rename_(names);
184+
}
185+
186+
return tensor;
187+
}
188+
}
189+
}
190+
191+
private static Tensor _tensor_generic<T>(Memory<T> rawArray, ReadOnlySpan<long> dimensions, sbyte origType, ScalarType? dtype, Device? device, bool requires_grad, bool clone = true, string[]? names = null)
192+
{
193+
if (clone)
194+
{
195+
return _tensor_generic(rawArray.ToArray(), dimensions, origType, dtype, device, requires_grad, false, names);
196+
}
197+
198+
{
199+
// Validate the sizes before handing over storage to native code...
200+
var prod = 1L;
201+
foreach (var sz in dimensions) prod *= sz;
202+
203+
if (origType == (sbyte)ScalarType.ComplexFloat32)
204+
prod *= 2;
205+
206+
if (prod > rawArray.Length)
207+
throw new ArgumentException($"mismatched total size creating a tensor from an array: {prod} vs. {rawArray.Length}");
208+
}
209+
210+
device = InitializeDevice(device);
211+
212+
dtype = dtype.HasValue ? dtype : (ScalarType)origType;
213+
214+
unsafe {
215+
216+
var dataHandle = rawArray.Pin();
217+
var dataArrayAddr = (IntPtr)dataHandle.Pointer;
218+
219+
TorchSharp.PInvoke.GCHandleDeleter deleter = null!;
220+
deleter = new TorchSharp.PInvoke.GCHandleDeleter((IntPtr ptr) => {
221+
dataHandle.Dispose();
222+
deleters.TryRemove(deleter, out deleter!);
223+
});
224+
deleters.TryAdd(deleter, deleter); // keep the delegate alive
225+
226+
void *ptr = null;
227+
IntPtr iPtr = (IntPtr)ptr;
228+
167229
fixed (long* shape = dimensions) {
168230
var handle = THSTensor_new(dataArrayAddr, deleter, (IntPtr)shape, dimensions.Length, origType, (sbyte)dtype.Value, (int)device.type, device.index, requires_grad);
169231

src/TorchSharp/Tensor/Factories/tensor_Complex.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,14 @@ public static Tensor tensor(System.Numerics.Complex[,,,] rawArray, ScalarType? d
271271
{
272272
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.ComplexFloat64, dtype, device, requires_grad, names: names);
273273
}
274+
275+
/// <summary>
276+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
277+
/// </summary>
278+
[Pure]
279+
public static Tensor tensor(Memory<System.Numerics.Complex> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
280+
{
281+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.ComplexFloat64, dtype, device, requires_grad, names: names);
282+
}
274283
}
275284
}

src/TorchSharp/Tensor/Factories/tensor_Half.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ public static Tensor tensor(Half[,,,] rawArray, ScalarType? dtype = null, Device
113113
{
114114
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Float16, dtype, device, requires_grad, names: names);
115115
}
116+
117+
/// <summary>
118+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
119+
/// </summary>
120+
[Pure]
121+
public static Tensor tensor(Memory<Half> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
122+
{
123+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Float16, dtype, device, requires_grad, names: names);
124+
}
116125
#endif
117126
}
118127
}

src/TorchSharp/Tensor/Factories/tensor_byte.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ public static Tensor tensor(byte[,,,] rawArray, ScalarType? dtype = null, Device
125125
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Byte, dtype, device, requires_grad, names: names);
126126
}
127127

128+
/// <summary>
129+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
130+
/// </summary>
131+
[Pure]
132+
public static Tensor tensor(Memory<byte> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
133+
{
134+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Byte, dtype, device, requires_grad, names: names);
135+
}
136+
128137
/// <summary>
129138
/// Cast a tensor to a byte tensor.
130139
/// </summary>

src/TorchSharp/Tensor/Factories/tensor_double.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ public static Tensor tensor(double[,,,] rawArray, ScalarType? dtype = null, Devi
150150
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Float64, dtype, device, requires_grad, names: names);
151151
}
152152

153+
/// <summary>
154+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
155+
/// </summary>
156+
[Pure]
157+
public static Tensor tensor(Memory<double> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
158+
{
159+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Float64, dtype, device, requires_grad, names: names);
160+
}
161+
153162
/// <summary>
154163
/// Cast a tensor to a 64-bit floating point tensor.
155164
/// </summary>

src/TorchSharp/Tensor/Factories/tensor_float.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,15 @@ public static Tensor tensor(float[,,,] rawArray, ScalarType? dtype = null, Devic
159159
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Float32, dtype, device, requires_grad, names: names);
160160
}
161161

162+
/// <summary>
163+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
164+
/// </summary>
165+
[Pure]
166+
public static Tensor tensor(Memory<float> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
167+
{
168+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Float32, dtype, device, requires_grad, names: names);
169+
}
170+
162171
/// <summary>
163172
/// Cast a tensor to a 32-bit floating point tensor.
164173
/// </summary>

src/TorchSharp/Tensor/Factories/tensor_int.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,15 @@ public static Tensor tensor(int[,,,] rawArray, ScalarType? dtype = null, Device?
125125
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Int32, dtype, device, requires_grad, names: names);
126126
}
127127

128+
/// <summary>
129+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
130+
/// </summary>
131+
[Pure]
132+
public static Tensor tensor(Memory<int> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
133+
{
134+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Int32, dtype, device, requires_grad, names: names);
135+
}
136+
128137
/// <summary>
129138
/// Cast a tensor to a 32-bit integer tensor.
130139
/// </summary>

src/TorchSharp/Tensor/Factories/tensor_long.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ public static Tensor tensor(long[,,,] rawArray, ScalarType? dtype = null, Device
134134
return _tensor_i64(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, dtype, device, requires_grad, names: names);
135135
}
136136

137+
/// <summary>
138+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
139+
/// </summary>
140+
[Pure]
141+
public static Tensor tensor(Memory<long> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
142+
{
143+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Int64, dtype, device, requires_grad, names: names);
144+
}
145+
137146
/// <summary>
138147
/// Cast a tensor to a 64-bit integer tensor.
139148
/// </summary>

src/TorchSharp/Tensor/Factories/tensor_sbyte.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,14 @@ public static Tensor tensor(sbyte[,,,] rawArray, ScalarType? dtype = null, Devic
124124
{
125125
return _tensor_generic(rawArray, stackalloc long[] { rawArray.GetLongLength(0), rawArray.GetLongLength(1), rawArray.GetLongLength(2), rawArray.GetLongLength(3) }, (sbyte)ScalarType.Int8, dtype, device, requires_grad, names: names);
126126
}
127+
128+
/// <summary>
129+
/// Create a tensor from an array of values, shaping it based on the shape passed in.
130+
/// </summary>
131+
[Pure]
132+
public static Tensor tensor(Memory<sbyte> rawArray, ReadOnlySpan<long> dimensions, ScalarType? dtype = null, Device? device = null, bool requires_grad = false, string[]? names = null)
133+
{
134+
return _tensor_generic(rawArray, dimensions, (sbyte)ScalarType.Int8, dtype, device, requires_grad, names: names);
135+
}
127136
}
128137
}

0 commit comments

Comments
 (0)