Skip to content

Commit 5948070

Browse files
authored
Merge pull request #54 from BerkeleyLearnVerify/dfremont-fixes
Fix ScenicSampler with more than 10 objects (again)
2 parents abbb042 + 77a9868 commit 5948070

File tree

4 files changed

+56
-9
lines changed

4 files changed

+56
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "verifai"
3-
version = "2.1.0"
3+
version = "2.1.1"
44
description = "A toolkit for the formal design and analysis of systems that include artificial intelligence (AI) and machine learning (ML) components."
55
authors = [
66
{ name = "Tommaso Dreossi" },

src/verifai/features/features.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,21 @@ def __repr__(self):
721721
return f'ScalarArray({self.domain}, {self.shape})'
722722

723723
class Struct(Domain):
724-
"""A domain consisting of named sub-domains."""
724+
"""A domain consisting of named sub-domains.
725+
726+
The order of the sub-domains is arbitrary: two Structs are considered equal
727+
if they have the same named sub-domains, regardless of order. As the order
728+
is an implementation detail, accessing the values of sub-domains in points
729+
sampled from a Struct should be done by name:
730+
731+
>>> struct = Struct({'a': Box((0, 1)), 'b': Box((2, 3))})
732+
>>> point = struct.uniformPoint()
733+
>>> point.b
734+
(2.20215292046797,)
735+
736+
Within a given version of VerifAI, the sub-domain order is consistent, so
737+
that the order of columns in error tables is also consistent.
738+
"""
725739

726740
def __init__(self, domains):
727741
self.namedDomains = tuple(sorted(domains.items(), key=lambda i: i[0]))

src/verifai/samplers/scenic_sampler.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ def spaceForScenario(scenario, ignoredProperties):
194194
assert scenario.egoObject is scenario.objects[0]
195195
doms = (domainForObject(obj, ignoredProperties)
196196
for obj in scenario.objects)
197-
objects = Struct({ f'object{i}': dom for i, dom in enumerate(doms) })
197+
objects = Struct({
198+
ScenicSampler.nameForObject(i): dom
199+
for i, dom in enumerate(doms)
200+
})
198201

199202
# create domains for global parameters
200203
paramDoms = {}
@@ -279,15 +282,30 @@ def nextSample(self, feedback=None):
279282
return self.pointForScene(self.lastScene)
280283

281284
def pointForScene(self, scene):
282-
"""Convert a sampled Scenic :obj:`Scene` to a point in our feature space."""
285+
"""Convert a sampled Scenic :obj:`~scenic.core.scenarios.Scene` to a point in our feature space.
286+
287+
The `FeatureSpace` used by this sampler consists of 2 features:
288+
289+
* ``objects``, which is a `Struct` consisting of attributes ``object0``,
290+
``object1``, etc. with the properties of the corresponding objects
291+
in the Scenic program. The names of these attributes may change in a
292+
future version of VerifAI: use the `nameForObject` function to
293+
generate them.
294+
* ``params``, which is a `Struct` storing the values of the
295+
:term:`global parameters` of the Scenic program (use
296+
`paramDictForSample` to extract them).
297+
"""
283298
lengths, dom = self.space.domains
284299
assert lengths is None
285300
assert scene.egoObject is scene.objects[0]
286301
objDomain = dom.domainNamed['objects']
287302
assert len(objDomain.domains) == len(scene.objects)
288-
objects = (pointForObject(objDomain.domainNamed[f'object{i}'], obj)
289-
for i, obj in enumerate(scene.objects))
290-
objPoint = objDomain.makePoint(*objects)
303+
objects = {
304+
self.nameForObject(i):
305+
pointForObject(objDomain.domainNamed[self.nameForObject(i)], obj)
306+
for i, obj in enumerate(scene.objects)
307+
}
308+
objPoint = objDomain.makePoint(**objects)
291309

292310
paramDomain = dom.domainNamed['params']
293311
params = {}
@@ -298,8 +316,17 @@ def pointForScene(self, scene):
298316

299317
return self.space.makePoint(objects=objPoint, params=paramPoint)
300318

319+
@staticmethod
320+
def nameForObject(i):
321+
"""Name used in the `FeatureSpace` for the Scenic object with index i.
322+
323+
That is, if ``scene`` is a :obj:`~scenic.core.scenarios.Scene`, the object
324+
``scene.objects[i]``.
325+
"""
326+
return f'object{i}'
327+
301328
def paramDictForSample(self, sample):
302-
"""Recover the dict of global parameters from a `ScenicSampler` sample."""
329+
"""Recover the dict of :term:`global parameters` from a `ScenicSampler` sample."""
303330
params = sample.params._asdict()
304331
corrected = {}
305332
for newName, quotedParam in self.quotedParams.items():

tests/scenic/test_scenic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,15 @@ def test_object_order(new_Object):
7171
sample = sampler.nextSample()
7272
objects = sample.objects
7373
assert len(objects) == 11
74-
for i, obj in enumerate(objects):
74+
for i in range(len(objects)):
75+
name = ScenicSampler.nameForObject(i)
76+
obj = getattr(objects, name)
7577
assert obj.position[:2] == pytest.approx((2*i, 0))
7678

79+
flat = sampler.space.flatten(sample)
80+
unflat = sampler.space.unflatten(flat)
81+
assert unflat == sample
82+
7783
## Active sampling
7884

7985
def test_active_sampling(new_Object):

0 commit comments

Comments
 (0)