Skip to content

Commit e34a760

Browse files
committed
Added ability to return jit moduls' input and output types.
1 parent d0890b6 commit e34a760

File tree

7 files changed

+287
-20
lines changed

7 files changed

+287
-20
lines changed

Test/TorchSharp.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using TorchSharp.JIT;
34
using TorchSharp.Tensor;
45

56
namespace TorchSharp.Test
@@ -88,6 +89,38 @@ public void ScoreModel()
8889
Assert.IsNotNull(result);
8990
}
9091

92+
[TestMethod]
93+
public void ScoreModelCheckInput()
94+
{
95+
var module = JIT.Module.Load(@"E:\Source\Repos\libtorch\model.pt");
96+
Assert.IsNotNull(module);
97+
98+
var num = module.GetNumberOfInputs();
99+
100+
for (int i = 0; i < num; i++)
101+
{
102+
var type = module.GetInputType(i);
103+
104+
Assert.IsNotNull(type as DynamicType);
105+
}
106+
}
107+
108+
[TestMethod]
109+
public void ScoreModelCheckOutput()
110+
{
111+
var module = JIT.Module.Load(@"E:\Source\Repos\libtorch\model.pt");
112+
Assert.IsNotNull(module);
113+
114+
var num = module.GetNumberOfOutputs();
115+
116+
for (int i = 0; i < num; i++)
117+
{
118+
var type = module.GetOutputType(i);
119+
120+
Assert.IsNotNull(type as DynamicType);
121+
}
122+
}
123+
91124
[TestMethod]
92125
public void CreateLinear()
93126
{

TorchSharp/Generated/TorchTensor.generated.cs

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
using System;
1+

2+
3+
4+
5+
using System;
26
using System.Linq;
37
using System.Runtime.InteropServices;
48
using System.Text;
59

610
namespace TorchSharp.Tensor {
711

12+
813
/// <summary>
914
/// Tensor of type Byte.
1015
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -179,7 +184,7 @@ static public ITorchTensor<byte> Ones(long[] size, string device = "cpu", bool r
179184
{
180185
fixed (long* psizes = size)
181186
{
182-
return new ByteTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Byte, device, requiresGrad));
187+
return new ByteTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Byte, device, requiresGrad));
183188
}
184189
}
185190
}
@@ -196,7 +201,7 @@ static public ITorchTensor<byte> RandomN(long[] size, string device = "cpu", boo
196201
{
197202
fixed (long* psizes = size)
198203
{
199-
return new ByteTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Byte, device, requiresGrad));
204+
return new ByteTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Byte, device, requiresGrad));
200205
}
201206
}
202207
}
@@ -292,6 +297,7 @@ public override string ToString()
292297
return sb.ToString();
293298
}
294299
}
300+
295301
/// <summary>
296302
/// Tensor of type Short.
297303
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -466,7 +472,7 @@ static public ITorchTensor<short> Ones(long[] size, string device = "cpu", bool
466472
{
467473
fixed (long* psizes = size)
468474
{
469-
return new ShortTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Short, device, requiresGrad));
475+
return new ShortTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Short, device, requiresGrad));
470476
}
471477
}
472478
}
@@ -483,7 +489,7 @@ static public ITorchTensor<short> RandomN(long[] size, string device = "cpu", bo
483489
{
484490
fixed (long* psizes = size)
485491
{
486-
return new ShortTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Short, device, requiresGrad));
492+
return new ShortTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Short, device, requiresGrad));
487493
}
488494
}
489495
}
@@ -579,6 +585,7 @@ public override string ToString()
579585
return sb.ToString();
580586
}
581587
}
588+
582589
/// <summary>
583590
/// Tensor of type Int.
584591
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -753,7 +760,7 @@ static public ITorchTensor<int> Ones(long[] size, string device = "cpu", bool re
753760
{
754761
fixed (long* psizes = size)
755762
{
756-
return new IntTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Int, device, requiresGrad));
763+
return new IntTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Int, device, requiresGrad));
757764
}
758765
}
759766
}
@@ -770,7 +777,7 @@ static public ITorchTensor<int> RandomN(long[] size, string device = "cpu", bool
770777
{
771778
fixed (long* psizes = size)
772779
{
773-
return new IntTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Int, device, requiresGrad));
780+
return new IntTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Int, device, requiresGrad));
774781
}
775782
}
776783
}
@@ -866,6 +873,7 @@ public override string ToString()
866873
return sb.ToString();
867874
}
868875
}
876+
869877
/// <summary>
870878
/// Tensor of type Long.
871879
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1040,7 +1048,7 @@ static public ITorchTensor<long> Ones(long[] size, string device = "cpu", bool r
10401048
{
10411049
fixed (long* psizes = size)
10421050
{
1043-
return new LongTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Long, device, requiresGrad));
1051+
return new LongTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Long, device, requiresGrad));
10441052
}
10451053
}
10461054
}
@@ -1057,7 +1065,7 @@ static public ITorchTensor<long> RandomN(long[] size, string device = "cpu", boo
10571065
{
10581066
fixed (long* psizes = size)
10591067
{
1060-
return new LongTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Long, device, requiresGrad));
1068+
return new LongTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Long, device, requiresGrad));
10611069
}
10621070
}
10631071
}
@@ -1153,6 +1161,7 @@ public override string ToString()
11531161
return sb.ToString();
11541162
}
11551163
}
1164+
11561165
/// <summary>
11571166
/// Tensor of type Double.
11581167
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1327,7 +1336,7 @@ static public ITorchTensor<double> Ones(long[] size, string device = "cpu", bool
13271336
{
13281337
fixed (long* psizes = size)
13291338
{
1330-
return new DoubleTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Double, device, requiresGrad));
1339+
return new DoubleTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Double, device, requiresGrad));
13311340
}
13321341
}
13331342
}
@@ -1344,7 +1353,7 @@ static public ITorchTensor<double> RandomN(long[] size, string device = "cpu", b
13441353
{
13451354
fixed (long* psizes = size)
13461355
{
1347-
return new DoubleTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Double, device, requiresGrad));
1356+
return new DoubleTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Double, device, requiresGrad));
13481357
}
13491358
}
13501359
}
@@ -1440,6 +1449,7 @@ public override string ToString()
14401449
return sb.ToString();
14411450
}
14421451
}
1452+
14431453
/// <summary>
14441454
/// Tensor of type Float.
14451455
/// This tensor maps to a Torch variable (see torch/csrc/autograd/variable.h).
@@ -1614,7 +1624,7 @@ static public ITorchTensor<float> Ones(long[] size, string device = "cpu", bool
16141624
{
16151625
fixed (long* psizes = size)
16161626
{
1617-
return new FloatTensor (THS_ones ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Float, device, requiresGrad));
1627+
return new FloatTensor (THS_ones ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad));
16181628
}
16191629
}
16201630
}
@@ -1631,7 +1641,7 @@ static public ITorchTensor<float> RandomN(long[] size, string device = "cpu", bo
16311641
{
16321642
fixed (long* psizes = size)
16331643
{
1634-
return new FloatTensor (THS_randn ((IntPtr)psizes, size.Length, (short)ATenScalarMapping.Float, device, requiresGrad));
1644+
return new FloatTensor (THS_randn ((IntPtr)psizes, size.Length, (sbyte)ATenScalarMapping.Float, device, requiresGrad));
16351645
}
16361646
}
16371647
}
@@ -1727,8 +1737,9 @@ public override string ToString()
17271737
return sb.ToString();
17281738
}
17291739
}
1740+
17301741

1731-
internal enum ATenScalarMapping : short
1742+
public enum ATenScalarMapping : short
17321743
{
17331744
Byte = 0,
17341745
Short = 2,
@@ -1744,30 +1755,37 @@ internal static ITorchTensor<T> ToTorchTensor<T>(this IntPtr rawTensor)
17441755
{
17451756
switch (true)
17461757
{
1758+
17471759
case bool _ when typeof(T) == typeof(byte):
17481760
{
17491761
return new ByteTensor(rawTensor) as ITorchTensor<T>;
17501762
}
1763+
17511764
case bool _ when typeof(T) == typeof(short):
17521765
{
17531766
return new ShortTensor(rawTensor) as ITorchTensor<T>;
17541767
}
1768+
17551769
case bool _ when typeof(T) == typeof(int):
17561770
{
17571771
return new IntTensor(rawTensor) as ITorchTensor<T>;
17581772
}
1773+
17591774
case bool _ when typeof(T) == typeof(long):
17601775
{
17611776
return new LongTensor(rawTensor) as ITorchTensor<T>;
17621777
}
1778+
17631779
case bool _ when typeof(T) == typeof(double):
17641780
{
17651781
return new DoubleTensor(rawTensor) as ITorchTensor<T>;
17661782
}
1783+
17671784
case bool _ when typeof(T) == typeof(float):
17681785
{
17691786
return new FloatTensor(rawTensor) as ITorchTensor<T>;
17701787
}
1788+
17711789
default: throw new NotImplementedException($"Creating tensor of type {typeof(T)} is not supported.");
17721790
}
17731791
}

TorchSharp/Generated/TorchTensor.tt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ foreach (var type in TorchTypeDef.Types) {
186186
{
187187
fixed (long* psizes = size)
188188
{
189-
return new <#=type.Name#>Tensor (THS_ones ((<#=type.Ptr#>)psizes, size.Length, (short)ATenScalarMapping.<#=type.Name#>, device, requiresGrad));
189+
return new <#=type.Name#>Tensor (THS_ones ((<#=type.Ptr#>)psizes, size.Length, (sbyte)ATenScalarMapping.<#=type.Name#>, device, requiresGrad));
190190
}
191191
}
192192
}
@@ -203,7 +203,7 @@ foreach (var type in TorchTypeDef.Types) {
203203
{
204204
fixed (long* psizes = size)
205205
{
206-
return new <#=type.Name#>Tensor (THS_randn ((<#=type.Ptr#>)psizes, size.Length, (short)ATenScalarMapping.<#=type.Name#>, device, requiresGrad));
206+
return new <#=type.Name#>Tensor (THS_randn ((<#=type.Ptr#>)psizes, size.Length, (sbyte)ATenScalarMapping.<#=type.Name#>, device, requiresGrad));
207207
}
208208
}
209209
}
@@ -301,7 +301,7 @@ foreach (var type in TorchTypeDef.Types) {
301301
}
302302
<# } #>
303303

304-
internal enum ATenScalarMapping : short
304+
public enum ATenScalarMapping : short
305305
{
306306
Byte = 0,
307307
Short = 2,

TorchSharp/JIT/Module.cs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ internal HType() : base(IntPtr.Zero, true)
2424
}
2525

2626
[DllImport("LibTorchSharp")]
27-
extern static void NN_JIT_Dispose(HType handle);
27+
extern static void JIT_Module_Dispose(HType handle);
2828

2929
protected override bool ReleaseHandle()
3030
{
31-
NN_JIT_Dispose(this);
31+
JIT_Module_Dispose(this);
3232
return true;
3333
}
3434

@@ -88,7 +88,7 @@ static public Module Load(string filename)
8888
[DllImport("LibTorchSharp")]
8989
extern static string JIT_getModuleName(HType module, int index);
9090

91-
public virtual string[] GetSubModulesNames()
91+
public string[] GetSubModulesNames()
9292
{
9393
var numModules = JIT_getNumModules(handle);
9494
string[] result = new string[numModules];
@@ -101,6 +101,55 @@ public virtual string[] GetSubModulesNames()
101101
return result;
102102
}
103103

104+
[DllImport("LibTorchSharp")]
105+
extern static int JIT_getNumberOfInputs(HType module);
106+
107+
public int GetNumberOfInputs()
108+
{
109+
return JIT_getNumberOfInputs(handle);
110+
}
111+
112+
[DllImport("LibTorchSharp")]
113+
extern static int JIT_getNumberOfOutputs(HType module);
114+
115+
public int GetNumberOfOutputs()
116+
{
117+
return JIT_getNumberOfOutputs(handle);
118+
}
119+
120+
[DllImport("LibTorchSharp")]
121+
extern static IntPtr JIT_getInputType(HType module, int index);
122+
123+
public Type GetInputType(int index)
124+
{
125+
var type = new Type(JIT_getInputType(handle, index));
126+
127+
return GetType(type);
128+
}
129+
130+
[DllImport("LibTorchSharp")]
131+
extern static IntPtr JIT_getOutputType(HType module, int index);
132+
133+
public Type GetOutputType(int index)
134+
{
135+
var type = new Type(JIT_getOutputType(handle, index));
136+
137+
return GetType(type);
138+
}
139+
140+
private Type GetType(Type type)
141+
{
142+
switch (type.Kind)
143+
{
144+
case Type.TypeKind.DynamicType:
145+
return type.AsDynamicType();
146+
case Type.TypeKind.TensorType:
147+
return type.AsDynamicType();
148+
default:
149+
return type;
150+
}
151+
}
152+
104153
[DllImport("LibTorchSharp")]
105154
extern static IntPtr JIT_forward(Module.HType module, IntPtr tensor);
106155

TorchSharp/JIT/Type/DynamicType .cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
3+
namespace TorchSharp.JIT
4+
{
5+
public sealed class DynamicType : Type
6+
{
7+
internal DynamicType(IntPtr handle) : base(handle)
8+
{
9+
this.handle = new HType(handle, true);
10+
}
11+
12+
internal DynamicType(Type type) : base()
13+
{
14+
handle = type.handle;
15+
type.handle = new HType(IntPtr.Zero, true);
16+
type.Dispose();
17+
}
18+
}
19+
}

0 commit comments

Comments
 (0)