Skip to content

Commit 6b6fda0

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(torchx/specs) Make torchx.specs.named_resources iterable (#1163)
Summary: There's a bunch of use-cases where it is useful to be able to iterate over all the registered named resources. Implementing the `__iter__()` method makes this possible. Reviewed By: AbishekS Differential Revision: D86984810
1 parent 6bde935 commit 6b6fda0

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

torchx/specs/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import difflib
1515

1616
import os
17-
from typing import Callable, Dict, Mapping, Optional
17+
from typing import Callable, Dict, Iterator, Mapping, Optional
1818

1919
from torchx.specs.api import (
2020
ALL,
@@ -113,8 +113,22 @@ def __getitem__(self, key: str) -> Resource:
113113
def __contains__(self, key: str) -> bool:
114114
return key in _named_resource_factories
115115

116-
def __iter__(self) -> None:
117-
raise NotImplementedError("named resources doesn't support iterating")
116+
def __iter__(self) -> Iterator[str]:
117+
"""Iterates through the names of the registered named_resources.
118+
119+
Usage:
120+
121+
.. doctest::
122+
123+
from torchx import specs
124+
125+
for resource_name in specs.named_resources:
126+
resource = specs.resource(h=resource_name)
127+
assert isinstance(resource, specs.Resource)
128+
129+
"""
130+
for key in _named_resource_factories:
131+
yield (key)
118132

119133

120134
named_resources: _NamedResourcesLibrary = _NamedResourcesLibrary()

torchx/specs/test/api_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,16 @@ def test_copy_resource(self) -> None:
361361
self.assertEqual(new_resource.capabilities["new_key"], "new_value")
362362
self.assertEqual(resource.capabilities["test_key"], "test_value")
363363

364+
def test_named_resources_iterator(self) -> None:
365+
registered_named_resources = set()
366+
for resource_name in named_resources:
367+
# just make sure we can create the resource using the name
368+
self.assertIsInstance(resource(h=resource_name), Resource)
369+
registered_named_resources.add(resource_name)
370+
371+
# validate that the for-loop was not vacuous
372+
self.assertTrue(registered_named_resources)
373+
364374
def test_named_resources(self) -> None:
365375
self.assertEqual(
366376
named_resources_aws.aws_m5_2xlarge(), named_resources["aws_m5.2xlarge"]

0 commit comments

Comments
 (0)