Skip to content

Commit 969e0b8

Browse files
authored
Merge pull request #1848 from Trusted-AI/development_nmegiddo_patch-1
Move new random sampling method to art.utils
2 parents c67d59f + cbc20fb commit 969e0b8

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

art/utils.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,78 @@ def random_sphere(
578578
return res
579579

580580

581+
def uniform_sample_from_sphere_or_ball(
582+
nb_points: int,
583+
nb_dims: int,
584+
radius: Union[int, float, np.ndarray],
585+
sample_space: str = "ball",
586+
norm: Union[int, float, str] = 2,
587+
) -> np.ndarray:
588+
"""
589+
Generate a sample of <nb_points> distributed independently and uniformly on the sphere (with respect to the given
590+
norm) in dimension <nb_dims> with radius <radius> and centered at the origin. Note that the sphere is the boundary
591+
of the ball, i.e., every point on the sphere has the same distance to the origin.
592+
593+
:param nb_points: Number of random data points
594+
:param nb_dims: Dimensionality of the sphere
595+
:param radius: Radius of the sphere
596+
:param sample_space: One of 'b', 's', 'sphere', 'ball'
597+
:param norm: Current support: 1, 2, np.inf, "inf"
598+
:return: The sampled points from the sphere (i.e., boundary of the ball)
599+
"""
600+
assert sample_space in ["b", "s", "sphere", "ball"]
601+
602+
if norm == 1:
603+
if sample_space in ["s", "sphere"]:
604+
y = np.random.exponential(1, (nb_points, nb_dims))
605+
sums = np.sum(y, axis=1)
606+
scal = np.outer(sums, np.ones(nb_dims))
607+
y = y / scal
608+
else:
609+
y = np.random.exponential(1, (nb_points, nb_dims + 1))
610+
sums = np.sum(y, axis=1)
611+
scal = np.outer(sums, np.ones(nb_dims + 1))
612+
y = y / scal
613+
y = y[:, :nb_dims]
614+
615+
y = y * np.random.choice([-1, 1], (nb_points, nb_dims))
616+
if type(radius) in [int, float]:
617+
res = y * radius
618+
else:
619+
radii = np.outer(radius, np.ones(nb_dims))
620+
res = y * radii
621+
622+
elif norm == 2:
623+
x = np.random.normal(0.0, 1.0, (nb_points, nb_dims))
624+
scal = radius / np.sqrt(np.sum(x * x, axis=1))
625+
scal = np.transpose(np.array(list(scal) * nb_dims).reshape((nb_dims, nb_points)))
626+
res = x * scal
627+
if sample_space in ["b", "ball"]:
628+
rnd = np.random.rand(nb_points)
629+
scal = np.float_power(rnd, 1 / nb_dims)
630+
scal = np.outer(scal, np.ones(nb_dims))
631+
res = res * scal
632+
633+
elif norm in [np.inf, "inf"]:
634+
if sample_space in ["b", "ball"]:
635+
x = np.random.uniform(-1.0, 1.0, (nb_points, nb_dims))
636+
else:
637+
x = np.random.uniform(0, 1.0, (nb_points, nb_dims))
638+
rnd = np.random.random((nb_points, nb_dims))
639+
maxes = np.max(rnd, axis=1)
640+
maxes = np.outer(maxes, np.ones(nb_dims))
641+
x = np.maximum((rnd >= maxes), x) * np.random.choice([-1, 1], (nb_points, nb_dims))
642+
if type(radius) in [int, float]:
643+
res = x * radius
644+
else:
645+
radii = np.outer(radius, np.ones(nb_dims))
646+
res = x * radii
647+
else:
648+
raise NotImplementedError(f"Norm {norm} not supported")
649+
650+
return res
651+
652+
581653
def original_to_tanh(
582654
x_original: np.ndarray,
583655
clip_min: Union[float, np.ndarray],

tests/test_utils.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import numpy as np
2424
import tensorflow as tf
2525

26-
from art.utils import projection, random_sphere, to_categorical, least_likely_class, check_and_transform_label_format
26+
from art.utils import projection, random_sphere, uniform_sample_from_sphere_or_ball, to_categorical, least_likely_class
2727
from art.utils import load_dataset, load_iris, load_mnist, load_nursery, load_cifar10
2828
from art.utils import second_most_likely_class, random_targets, get_label_conf, get_labels_np_array, preprocess
29-
from art.utils import compute_success_array, compute_success
29+
from art.utils import compute_success_array, compute_success, check_and_transform_label_format
3030
from art.utils import segment_by_class, performance_diff
3131
from art.utils import is_probability
3232

@@ -179,6 +179,34 @@ def test_random_sphere(self):
179179
x = random_sphere(1, 10000, 1, np.inf)
180180
self.assertTrue(np.abs(np.max(np.abs(x), axis=1) - 1.0) < 1e-2)
181181

182+
def test_uniform_sample_from_sphere_or_ball(self):
183+
x = uniform_sample_from_sphere_or_ball(nb_points=10, nb_dims=10, radius=1, sample_space="ball", norm=1)
184+
185+
self.assertEqual(x.shape, (10, 10))
186+
self.assertTrue(np.all(np.sum(np.abs(x), axis=1) <= 1.0))
187+
188+
x = uniform_sample_from_sphere_or_ball(nb_points=10, nb_dims=10, radius=1, sample_space="ball", norm=np.inf)
189+
self.assertEqual(x.shape, (10, 10))
190+
self.assertTrue(np.all(np.abs(x) < 1.0))
191+
192+
x = uniform_sample_from_sphere_or_ball(nb_points=10, nb_dims=10, radius=0.5, sample_space="ball", norm=1)
193+
self.assertTrue(np.all(np.sum(np.abs(x), axis=1) <= 0.5))
194+
195+
x = uniform_sample_from_sphere_or_ball(nb_points=10, nb_dims=10, radius=0.5, sample_space="ball", norm=2)
196+
self.assertTrue(np.all(np.linalg.norm(x, axis=1) < 0.5))
197+
198+
x = uniform_sample_from_sphere_or_ball(nb_points=10, nb_dims=10, radius=0.5, sample_space="ball", norm=np.inf)
199+
self.assertTrue(np.all(np.abs(x) < 0.5))
200+
201+
x = uniform_sample_from_sphere_or_ball(nb_points=1, nb_dims=10000, radius=1, sample_space="ball", norm=1)
202+
self.assertTrue(np.abs(np.sum(np.abs(x), axis=1) - 1.0) < 1e-2)
203+
204+
x = uniform_sample_from_sphere_or_ball(nb_points=1, nb_dims=10000, radius=1, sample_space="ball", norm=2)
205+
self.assertTrue(np.abs(np.linalg.norm(x, axis=1) - 1.0) < 1e-2)
206+
207+
x = uniform_sample_from_sphere_or_ball(nb_points=1, nb_dims=10000, radius=1, sample_space="ball", norm=np.inf)
208+
self.assertTrue(np.abs(np.max(np.abs(x), axis=1) - 1.0) < 1e-2)
209+
182210
def test_to_categorical(self):
183211
y = np.array([3, 1, 4, 1, 5, 9])
184212
y_ = to_categorical(y)

0 commit comments

Comments
 (0)