Skip to content

Commit f968b3d

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(torchx/specs) Make torchx.specs.named_resources iterable (#1163)
Summary: There's a bunch of use-cases where it is useful to be able to iterate over all the registered named resources. Implementing the `__iter__()` method makes this possible. NOTE: To keep the API consistent, the `_NamedResourcesLibrary.__iter__()` method materializes the resource objec by calling the factory function. This means that if you iterate over `named_resources` early in the program, you end up losing out on any lazy materialization. Reviewed By: AbishekS Differential Revision: D86984810
1 parent 6bde935 commit f968b3d

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

torchx/specs/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import difflib
1515

1616
import os
17-
from typing import Callable, Dict, Mapping, Optional
17+
from typing import Callable, Dict, Iterator, Mapping, Optional
1818

1919
from torchx.specs.api import (
2020
ALL,
@@ -113,8 +113,9 @@ def __getitem__(self, key: str) -> Resource:
113113
def __contains__(self, key: str) -> bool:
114114
return key in _named_resource_factories
115115

116-
def __iter__(self) -> None:
117-
raise NotImplementedError("named resources doesn't support iterating")
116+
def __iter__(self) -> Iterator[tuple[str, Resource]]:
117+
for key in _named_resource_factories:
118+
yield (key, self[key])
118119

119120

120121
named_resources: _NamedResourcesLibrary = _NamedResourcesLibrary()

torchx/specs/test/api_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,15 @@ def test_copy_resource(self) -> None:
361361
self.assertEqual(new_resource.capabilities["new_key"], "new_value")
362362
self.assertEqual(resource.capabilities["test_key"], "test_value")
363363

364+
def test_named_resources_iterator(self) -> None:
365+
registered_named_resources = {}
366+
for name, resource_obj in named_resources:
367+
self.assertEqual(resource(h=name), resource_obj)
368+
registered_named_resources[name] = resource_obj
369+
370+
# validate that the for-loop was not vacuous
371+
self.assertTrue(registered_named_resources)
372+
364373
def test_named_resources(self) -> None:
365374
self.assertEqual(
366375
named_resources_aws.aws_m5_2xlarge(), named_resources["aws_m5.2xlarge"]

0 commit comments

Comments
 (0)