Skip to content

Commit ccb4029

Browse files
Relax size constraint when creating a tensor from an array.
1 parent d35ee8d commit ccb4029

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/TorchSharp/Tensor/Factories/Tensor.Factories.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ private static Tensor _tensor_generic(Array rawArray, ReadOnlySpan<long> dimensi
143143
if (origType == (sbyte)ScalarType.ComplexFloat32)
144144
prod *= 2;
145145

146-
if (prod != rawArray.LongLength)
146+
if (prod > rawArray.LongLength)
147147
throw new ArgumentException($"mismatched total size creating a tensor from an array: {prod} vs. {rawArray.LongLength}");
148148
}
149149

test/TorchSharpTest/TestTorchTensor.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,16 @@ public void FromArrayFactory()
783783
() => Assert.Equal(ScalarType.Bool, t.dtype));
784784
}
785785

786+
{
787+
var array = new bool[18]; // Too long, on purpose
788+
var t = torch.tensor(array, new long[] { 8 }, device: device);
789+
Assert.Multiple(
790+
() => Assert.Equal(8, t.NumberOfElements),
791+
() => Assert.Equal(device.type, t.device_type),
792+
() => Assert.Equal(1, t.ndim),
793+
() => Assert.Equal(ScalarType.Bool, t.dtype));
794+
}
795+
786796
{
787797
var array = new int[8];
788798
var t = torch.tensor(array, device: device);
@@ -801,6 +811,16 @@ public void FromArrayFactory()
801811
() => Assert.Equal(ScalarType.Float32, t.dtype));
802812
}
803813

814+
{
815+
var array = new float[18]; // Too long, on purpose
816+
var t = torch.tensor(array, new long[] { 8 }, device: device);
817+
Assert.Multiple(
818+
() => Assert.Equal(8, t.NumberOfElements),
819+
() => Assert.Equal(device.type, t.device_type),
820+
() => Assert.Equal(1, t.ndim),
821+
() => Assert.Equal(ScalarType.Float32, t.dtype));
822+
}
823+
804824
{
805825
var array = new double[1, 2];
806826
var t = torch.from_array(array, device: device);

0 commit comments

Comments
 (0)