Skip to content

Commit ac9c239

Browse files
committed
flatten weights
1 parent 30c9497 commit ac9c239

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

art/metrics/metrics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ def wasserstein_distance(
377377
u_values = u_values.flatten().reshape(u_values.shape[0], -1)
378378
v_values = v_values.flatten().reshape(v_values.shape[0], -1)
379379

380+
if u_weights is not None and v_weights is not None:
381+
u_weights = u_weights.flatten().reshape(u_weights.shape[0], -1)
382+
v_weights = v_weights.flatten().reshape(v_weights.shape[0], -1)
383+
380384
wd = np.zeros(u_values.shape[0])
381385

382386
for i in range(u_values.shape[0]):

tests/metrics/test_metrics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,16 +373,20 @@ def test_wasserstein_distance(self):
373373

374374
x_train = x_train[0:nb_train]
375375
x_test = x_test[0:nb_test]
376+
weights = np.ones_like(x_train)
376377

377378
wd_0 = wasserstein_distance(x_train[:batch_size], x_train[:batch_size])
378379
wd_1 = wasserstein_distance(x_train[:batch_size], x_test[:batch_size])
380+
wd_2 = wasserstein_distance(x_train[:batch_size], x_train[:batch_size], weights[:batch_size],
381+
weights[:batch_size])
379382

380383
np.testing.assert_array_equal(wd_0, np.asarray([0.0, 0.0, 0.0]))
381384
np.testing.assert_array_almost_equal(wd_1, np.asarray([0.04564, 0.01235, 0.04787]), decimal=4)
382385

383386
np.testing.assert_array_equal(x_train.shape, np.asarray([nb_train, 28, 28, 1]))
384387
np.testing.assert_array_equal(x_test.shape, np.asarray([nb_test, 28, 28, 1]))
385388

389+
np.testing.assert_array_equal(wd_2, np.asarray([0.0, 0.0, 0.0]))
386390

387391
if __name__ == "__main__":
388392
unittest.main()

0 commit comments

Comments
 (0)