Skip to content

Commit f8230ab

Browse files
authored
Merge pull request #16 from interesaaat/LibTorchSharpFirstTest
Added squeeze, index_select, and div by int
2 parents 0dc54ba + 48fea15 commit f8230ab

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

TorchSharp/Tensor/TorchTensor.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,14 @@ public TorchTensor Grad()
293293
return new TorchTensor(THSTensor_grad(handle));
294294
}
295295

296+
[DllImport("libTorchSharp")]
297+
extern static IntPtr THSTensor_index_select(IntPtr src, long dimension, IntPtr index);
298+
299+
public TorchTensor IndexSelect(long dimension, TorchTensor index)
300+
{
301+
return new TorchTensor(THSTensor_index_select(handle, dimension, index.Handle));
302+
}
303+
296304
[DllImport("libTorchSharp")]
297305
extern static IntPtr THSTensor_reshape(IntPtr src, IntPtr shape, int length);
298306

@@ -307,6 +315,14 @@ public TorchTensor Reshape(params long[] shape)
307315
}
308316
}
309317

318+
[DllImport("libTorchSharp")]
319+
extern static IntPtr THSTensor_squeeze(IntPtr src, long dimension);
320+
321+
public TorchTensor Squeeze(long dimension)
322+
{
323+
return new TorchTensor(THSTensor_squeeze(handle, dimension));
324+
}
325+
310326
[DllImport("libTorchSharp")]
311327
extern static IntPtr THSTensor_t(IntPtr src);
312328

@@ -417,6 +433,14 @@ public void DivInPlace(TorchTensor target)
417433
THSTensor_div_(handle, target.Handle);
418434
}
419435

436+
[DllImport("libTorchSharp")]
437+
extern static IntPtr THSTensor_divS(IntPtr src, int trg);
438+
439+
public TorchTensor Div(int target)
440+
{
441+
return new TorchTensor(THSTensor_divS(handle, target));
442+
}
443+
420444
[DllImport("libTorchSharp")]
421445
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
422446

@@ -573,6 +597,11 @@ public TorchTensor Sum(long[] dimensions, bool keepDimension = false)
573597
return left.Div(right);
574598
}
575599

600+
public static TorchTensor operator /(TorchTensor left, int right)
601+
{
602+
return left.Div(right);
603+
}
604+
576605
/// <summary>
577606
/// Get a string representation of the tensor.
578607
/// </summary>

0 commit comments

Comments
 (0)