Skip to content

Commit 437a95c

Browse files
committed
added default values to Addbmm
1 parent 3f57f43 commit 437a95c

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

TorchSharp/NN/Module.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ internal Module(IntPtr handle)
5252
this.handle = new HType(handle, true);
5353
}
5454

55+
internal Module()
56+
{
57+
this.handle = new HType(IntPtr.Zero, true);
58+
}
59+
5560
[DllImport("libTorchSharp")]
5661
extern static IntPtr THSNN_new_module(IntPtr names, IntPtr parameters, IntPtr with_grad, int length);
5762

TorchSharp/Tensor/ITorchTensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ public interface ITorchTensor : IDisposable
6262

6363
void AddInPlace(ITorchTensor target, int scalar = 1);
6464

65-
ITorchTensor Addbmm(ITorchTensor batch1, ITorchTensor batch2, float beta, float alpha);
65+
ITorchTensor Addbmm(ITorchTensor batch1, ITorchTensor batch2, float beta = 1, float alpha = 1);
6666

6767
ITorchTensor Argmax(long dimension, bool keepDimension = false);
6868

69-
ITorchTensor Baddbmm(ITorchTensor batch2, ITorchTensor mat, float beta, float alpha);
69+
ITorchTensor Baddbmm(ITorchTensor batch2, ITorchTensor mat, float beta = 1, float alpha = 1);
7070

7171
ITorchTensor Bmm(ITorchTensor batch2);
7272

TorchSharp/Tensor/TorchTensor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ public void AddInPlace(ITorchTensor target, int scalar = 1)
334334
[DllImport("libTorchSharp")]
335335
extern static IntPtr THSTensor_addbmm(IntPtr mat, IntPtr batch1, IntPtr batch2, float beta, float alpha);
336336

337-
public ITorchTensor Addbmm(ITorchTensor batch1, ITorchTensor batch2, float beta, float alpha)
337+
public ITorchTensor Addbmm(ITorchTensor batch1, ITorchTensor batch2, float beta = 1, float alpha = 1)
338338
{
339339
return new TorchTensor(THSTensor_addbmm(handle, batch1.Handle, batch2.Handle, beta, alpha));
340340
}
@@ -358,7 +358,7 @@ public ITorchTensor Argmax(long dimension, bool keepDim = false)
358358
[DllImport("libTorchSharp")]
359359
extern static IntPtr THSTensor_baddbmm(IntPtr batch1, IntPtr batch2, IntPtr mat, float beta, float alpha);
360360

361-
public ITorchTensor Baddbmm(ITorchTensor batch2, ITorchTensor mat, float beta, float alpha)
361+
public ITorchTensor Baddbmm(ITorchTensor batch2, ITorchTensor mat, float beta = 1, float alpha = 1)
362362
{
363363
return new TorchTensor(THSTensor_addbmm(handle, batch2.Handle, mat.Handle, beta, alpha));
364364
}

0 commit comments

Comments
 (0)