Skip to content

Commit 41d50b1

Browse files
committed
Added mm, addmm and bmm.
1 parent 6a4c25c commit 41d50b1

File tree

2 files changed

+184
-16
lines changed

2 files changed

+184
-16
lines changed

TorchSharp/Tensor/TorchTensor.generated.cs

Lines changed: 156 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,14 @@ public ITorchTensor<byte> Addbmm(ITorchTensor<byte> batch1, ITorchTensor<byte> b
360360
return new ByteTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
361361
}
362362

363+
[DllImport("libTorchSharp")]
364+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
365+
366+
public ITorchTensor<byte> Addmm(ITorchTensor<byte> mat1, ITorchTensor<byte> mat2, float beta, float alpha)
367+
{
368+
return new ByteTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
369+
}
370+
363371
[DllImport("libTorchSharp")]
364372
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
365373

@@ -376,6 +384,14 @@ public ITorchTensor<byte> Baddbmm(ITorchTensor<byte> batch2, ITorchTensor<byte>
376384
return new ByteTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
377385
}
378386

387+
[DllImport("libTorchSharp")]
388+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
389+
390+
public ITorchTensor<byte> Bmm(ITorchTensor<byte> batch2)
391+
{
392+
return new ByteTensor(THSTensor_bmm(handle, batch2.Handle));
393+
}
394+
379395
[DllImport("libTorchSharp")]
380396
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
381397

@@ -393,11 +409,19 @@ public ITorchTensor<byte> Exp()
393409
}
394410

395411
[DllImport("libTorchSharp")]
396-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
412+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
397413

398414
public ITorchTensor<byte> MatMul(ITorchTensor<byte> target)
399415
{
400-
return new ByteTensor(THSTensor_matMul(handle, target.Handle));
416+
return new ByteTensor(THSTensor_matmul(handle, target.Handle));
417+
}
418+
419+
[DllImport("libTorchSharp")]
420+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
421+
422+
public ITorchTensor<byte> Mm(ITorchTensor<byte> target)
423+
{
424+
return new ByteTensor(THSTensor_mm(handle, target.Handle));
401425
}
402426

403427
[DllImport("libTorchSharp")]
@@ -848,6 +872,14 @@ public ITorchTensor<short> Addbmm(ITorchTensor<short> batch1, ITorchTensor<short
848872
return new ShortTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
849873
}
850874

875+
[DllImport("libTorchSharp")]
876+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
877+
878+
public ITorchTensor<short> Addmm(ITorchTensor<short> mat1, ITorchTensor<short> mat2, float beta, float alpha)
879+
{
880+
return new ShortTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
881+
}
882+
851883
[DllImport("libTorchSharp")]
852884
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
853885

@@ -864,6 +896,14 @@ public ITorchTensor<short> Baddbmm(ITorchTensor<short> batch2, ITorchTensor<shor
864896
return new ShortTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
865897
}
866898

899+
[DllImport("libTorchSharp")]
900+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
901+
902+
public ITorchTensor<short> Bmm(ITorchTensor<short> batch2)
903+
{
904+
return new ShortTensor(THSTensor_bmm(handle, batch2.Handle));
905+
}
906+
867907
[DllImport("libTorchSharp")]
868908
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
869909

@@ -881,11 +921,19 @@ public ITorchTensor<short> Exp()
881921
}
882922

883923
[DllImport("libTorchSharp")]
884-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
924+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
885925

886926
public ITorchTensor<short> MatMul(ITorchTensor<short> target)
887927
{
888-
return new ShortTensor(THSTensor_matMul(handle, target.Handle));
928+
return new ShortTensor(THSTensor_matmul(handle, target.Handle));
929+
}
930+
931+
[DllImport("libTorchSharp")]
932+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
933+
934+
public ITorchTensor<short> Mm(ITorchTensor<short> target)
935+
{
936+
return new ShortTensor(THSTensor_mm(handle, target.Handle));
889937
}
890938

891939
[DllImport("libTorchSharp")]
@@ -1336,6 +1384,14 @@ public ITorchTensor<int> Addbmm(ITorchTensor<int> batch1, ITorchTensor<int> batc
13361384
return new IntTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
13371385
}
13381386

1387+
[DllImport("libTorchSharp")]
1388+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
1389+
1390+
public ITorchTensor<int> Addmm(ITorchTensor<int> mat1, ITorchTensor<int> mat2, float beta, float alpha)
1391+
{
1392+
return new IntTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
1393+
}
1394+
13391395
[DllImport("libTorchSharp")]
13401396
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
13411397

@@ -1352,6 +1408,14 @@ public ITorchTensor<int> Baddbmm(ITorchTensor<int> batch2, ITorchTensor<int> mat
13521408
return new IntTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
13531409
}
13541410

1411+
[DllImport("libTorchSharp")]
1412+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
1413+
1414+
public ITorchTensor<int> Bmm(ITorchTensor<int> batch2)
1415+
{
1416+
return new IntTensor(THSTensor_bmm(handle, batch2.Handle));
1417+
}
1418+
13551419
[DllImport("libTorchSharp")]
13561420
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
13571421

@@ -1369,11 +1433,19 @@ public ITorchTensor<int> Exp()
13691433
}
13701434

13711435
[DllImport("libTorchSharp")]
1372-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
1436+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
13731437

13741438
public ITorchTensor<int> MatMul(ITorchTensor<int> target)
13751439
{
1376-
return new IntTensor(THSTensor_matMul(handle, target.Handle));
1440+
return new IntTensor(THSTensor_matmul(handle, target.Handle));
1441+
}
1442+
1443+
[DllImport("libTorchSharp")]
1444+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
1445+
1446+
public ITorchTensor<int> Mm(ITorchTensor<int> target)
1447+
{
1448+
return new IntTensor(THSTensor_mm(handle, target.Handle));
13771449
}
13781450

13791451
[DllImport("libTorchSharp")]
@@ -1824,6 +1896,14 @@ public ITorchTensor<long> Addbmm(ITorchTensor<long> batch1, ITorchTensor<long> b
18241896
return new LongTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
18251897
}
18261898

1899+
[DllImport("libTorchSharp")]
1900+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
1901+
1902+
public ITorchTensor<long> Addmm(ITorchTensor<long> mat1, ITorchTensor<long> mat2, float beta, float alpha)
1903+
{
1904+
return new LongTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
1905+
}
1906+
18271907
[DllImport("libTorchSharp")]
18281908
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
18291909

@@ -1840,6 +1920,14 @@ public ITorchTensor<long> Baddbmm(ITorchTensor<long> batch2, ITorchTensor<long>
18401920
return new LongTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
18411921
}
18421922

1923+
[DllImport("libTorchSharp")]
1924+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
1925+
1926+
public ITorchTensor<long> Bmm(ITorchTensor<long> batch2)
1927+
{
1928+
return new LongTensor(THSTensor_bmm(handle, batch2.Handle));
1929+
}
1930+
18431931
[DllImport("libTorchSharp")]
18441932
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
18451933

@@ -1857,11 +1945,19 @@ public ITorchTensor<long> Exp()
18571945
}
18581946

18591947
[DllImport("libTorchSharp")]
1860-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
1948+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
18611949

18621950
public ITorchTensor<long> MatMul(ITorchTensor<long> target)
18631951
{
1864-
return new LongTensor(THSTensor_matMul(handle, target.Handle));
1952+
return new LongTensor(THSTensor_matmul(handle, target.Handle));
1953+
}
1954+
1955+
[DllImport("libTorchSharp")]
1956+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
1957+
1958+
public ITorchTensor<long> Mm(ITorchTensor<long> target)
1959+
{
1960+
return new LongTensor(THSTensor_mm(handle, target.Handle));
18651961
}
18661962

18671963
[DllImport("libTorchSharp")]
@@ -2312,6 +2408,14 @@ public ITorchTensor<double> Addbmm(ITorchTensor<double> batch1, ITorchTensor<dou
23122408
return new DoubleTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
23132409
}
23142410

2411+
[DllImport("libTorchSharp")]
2412+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
2413+
2414+
public ITorchTensor<double> Addmm(ITorchTensor<double> mat1, ITorchTensor<double> mat2, float beta, float alpha)
2415+
{
2416+
return new DoubleTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
2417+
}
2418+
23152419
[DllImport("libTorchSharp")]
23162420
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
23172421

@@ -2328,6 +2432,14 @@ public ITorchTensor<double> Baddbmm(ITorchTensor<double> batch2, ITorchTensor<do
23282432
return new DoubleTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
23292433
}
23302434

2435+
[DllImport("libTorchSharp")]
2436+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
2437+
2438+
public ITorchTensor<double> Bmm(ITorchTensor<double> batch2)
2439+
{
2440+
return new DoubleTensor(THSTensor_bmm(handle, batch2.Handle));
2441+
}
2442+
23312443
[DllImport("libTorchSharp")]
23322444
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
23332445

@@ -2345,11 +2457,19 @@ public ITorchTensor<double> Exp()
23452457
}
23462458

23472459
[DllImport("libTorchSharp")]
2348-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
2460+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
23492461

23502462
public ITorchTensor<double> MatMul(ITorchTensor<double> target)
23512463
{
2352-
return new DoubleTensor(THSTensor_matMul(handle, target.Handle));
2464+
return new DoubleTensor(THSTensor_matmul(handle, target.Handle));
2465+
}
2466+
2467+
[DllImport("libTorchSharp")]
2468+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
2469+
2470+
public ITorchTensor<double> Mm(ITorchTensor<double> target)
2471+
{
2472+
return new DoubleTensor(THSTensor_mm(handle, target.Handle));
23532473
}
23542474

23552475
[DllImport("libTorchSharp")]
@@ -2800,6 +2920,14 @@ public ITorchTensor<float> Addbmm(ITorchTensor<float> batch1, ITorchTensor<float
28002920
return new FloatTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
28012921
}
28022922

2923+
[DllImport("libTorchSharp")]
2924+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
2925+
2926+
public ITorchTensor<float> Addmm(ITorchTensor<float> mat1, ITorchTensor<float> mat2, float beta, float alpha)
2927+
{
2928+
return new FloatTensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
2929+
}
2930+
28032931
[DllImport("libTorchSharp")]
28042932
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
28052933

@@ -2816,6 +2944,14 @@ public ITorchTensor<float> Baddbmm(ITorchTensor<float> batch2, ITorchTensor<floa
28162944
return new FloatTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
28172945
}
28182946

2947+
[DllImport("libTorchSharp")]
2948+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
2949+
2950+
public ITorchTensor<float> Bmm(ITorchTensor<float> batch2)
2951+
{
2952+
return new FloatTensor(THSTensor_bmm(handle, batch2.Handle));
2953+
}
2954+
28192955
[DllImport("libTorchSharp")]
28202956
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
28212957

@@ -2833,11 +2969,19 @@ public ITorchTensor<float> Exp()
28332969
}
28342970

28352971
[DllImport("libTorchSharp")]
2836-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
2972+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
28372973

28382974
public ITorchTensor<float> MatMul(ITorchTensor<float> target)
28392975
{
2840-
return new FloatTensor(THSTensor_matMul(handle, target.Handle));
2976+
return new FloatTensor(THSTensor_matmul(handle, target.Handle));
2977+
}
2978+
2979+
[DllImport("libTorchSharp")]
2980+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
2981+
2982+
public ITorchTensor<float> Mm(ITorchTensor<float> target)
2983+
{
2984+
return new FloatTensor(THSTensor_mm(handle, target.Handle));
28412985
}
28422986

28432987
[DllImport("libTorchSharp")]

TorchSharp/Tensor/TorchTensor.tt

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,14 @@ foreach (var type in TorchTypeDef.Types) {
367367
return new <#=type.Name#>Tensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
368368
}
369369

370+
[DllImport("libTorchSharp")]
371+
extern static IntPtr THSTensor_addmm(IntPtr mat, IntPtr mat1, IntPtr mat2, float beta, float alpha);
372+
373+
public ITorchTensor<<#=type.Storage#>> Addmm(ITorchTensor<<#=type.Storage#>> mat1, ITorchTensor<<#=type.Storage#>> mat2, float beta, float alpha)
374+
{
375+
return new <#=type.Name#>Tensor(THSTensor_addmm(handle, mat1.Handle, mat2.Handle, beta, alpha));
376+
}
377+
370378
[DllImport("libTorchSharp")]
371379
extern static IntPtr THSTensor_argmax(IntPtr src, long dimension, bool keep_dim);
372380

@@ -383,6 +391,14 @@ foreach (var type in TorchTypeDef.Types) {
383391
return new <#=type.Name#>Tensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
384392
}
385393

394+
[DllImport("libTorchSharp")]
395+
extern static IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
396+
397+
public ITorchTensor<<#=type.Storage#>> Bmm(ITorchTensor<<#=type.Storage#>> batch2)
398+
{
399+
return new <#=type.Name#>Tensor(THSTensor_bmm(handle, batch2.Handle));
400+
}
401+
386402
[DllImport("libTorchSharp")]
387403
extern static IntPtr THSTensor_eq(IntPtr src, IntPtr trg);
388404

@@ -400,11 +416,19 @@ foreach (var type in TorchTypeDef.Types) {
400416
}
401417

402418
[DllImport("libTorchSharp")]
403-
extern static IntPtr THSTensor_matMul(IntPtr src, IntPtr target);
419+
extern static IntPtr THSTensor_matmul(IntPtr src, IntPtr target);
404420

405421
public ITorchTensor<<#=type.Storage#>> MatMul(ITorchTensor<<#=type.Storage#>> target)
406422
{
407-
return new <#=type.Name#>Tensor(THSTensor_matMul(handle, target.Handle));
423+
return new <#=type.Name#>Tensor(THSTensor_matmul(handle, target.Handle));
424+
}
425+
426+
[DllImport("libTorchSharp")]
427+
extern static IntPtr THSTensor_mm(IntPtr src, IntPtr target);
428+
429+
public ITorchTensor<<#=type.Storage#>> Mm(ITorchTensor<<#=type.Storage#>> target)
430+
{
431+
return new <#=type.Name#>Tensor(THSTensor_mm(handle, target.Handle));
408432
}
409433

410434
[DllImport("libTorchSharp")]
@@ -571,9 +595,9 @@ foreach (var type in TorchTypeDef.Types) {
571595
[DllImport("libTorchSharp")]
572596
extern static void THSTensor_initUniform(IntPtr src, double low, double high);
573597

574-
internal staticvoid InitUniform<T>(this ITorchTensor<T> tensor, double low = 0, double high = 1)
598+
internal static void InitUniform<T>(this ITorchTensor<T> tensor, double low = 0, double high = 1)
575599
{
576-
THSTensor_initUniform(tensor.Handle, low, high).ToTorchTensor<T>();
600+
THSTensor_initUniform(tensor.Handle, low, high);
577601
}
578602
}
579603
}

0 commit comments

Comments
 (0)