Skip to content

Commit fb8de14

Browse files
Merge pull request #1335 from NiklasGustafsson/missing
Fixed issue #1334
2 parents bd1f04e + 4232194 commit fb8de14

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

RELEASENOTES.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@ When creating a tensor from a 1-D array, and passing in a shape, there is now an
1010

1111
__API Changes__:
1212

13-
#1326 Allow arrays used to create tensors to be larger than the tensor. Create tensors from a Memory instance<br/>
13+
#1326 Allow arrays used to create tensors to be larger than the tensor. Create tensors from a Memory instance.<br/>
1414

1515
__Bug Fixes__:
1616

17+
#1334 MultivariateNormal.log_prob() exception in TorchSharp but works in pytorch.<br/>
1718

1819
# NuGet Version 0.102.5
1920

src/TorchSharp/Distributions/MultiVariateNormal.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ private Tensor BatchMahalanobis(Tensor bL, Tensor bx)
218218
var old_batch_dims = outer_batch_dims + bL_batch_dims;
219219
var new_batch_dims = outer_batch_dims + 2 * bL_batch_dims;
220220

221-
var bx_new_shape = TakeAllBut(bx.shape, outer_batch_dims).ToList();
221+
var bx_new_shape = bx.shape.Take(outer_batch_dims).ToList();
222222

223223
for (int i = 0; i < bL.ndim - 2; i++) {
224224
var sL = bL.shape[i];

test/TorchSharpTest/TestDistributions.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,21 @@ public void TestMultivariateNormal()
13321332
}
13331333
}
13341334

1335+
[Fact]
1336+
public void TestMultivariateNormal_1334()
1337+
{
1338+
var actionMean = torch.tensor(new double[]{0.2025, -0.0714, 0.1417});
1339+
var covMat = torch.tensor(new double[,]
1340+
{
1341+
{ 0.36, 0, 0 },
1342+
{ 0, 0.36, 0 },
1343+
{ 0, 0, 0.36 },
1344+
});
1345+
var dist = torch.distributions.MultivariateNormal(actionMean, covariance_matrix: covMat);
1346+
torch.Tensor action = dist.sample();
1347+
torch.Tensor actionLogProb = dist.log_prob(action);
1348+
}
1349+
13351350
[Fact]
13361351
public void TestWeibull()
13371352
{

0 commit comments

Comments
 (0)