Skip to content

Commit 6e58312

Browse files
committed
feat: add config-based named_resources support (#1085)
1 parent 2d59334 commit 6e58312

File tree

4 files changed

+345
-2
lines changed

4 files changed

+345
-2
lines changed

docs/source/advanced.rst

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,39 @@ Registering Named Resources
7777

7878
A Named Resource is a set of predefined resource specs that are given a
7979
string name. This is particularly useful
80-
when your cluster has a fixed set of instance types. For instance if your
81-
deep learning training kubernetes cluster on AWS is
80+
when your cluster has a fixed set of instance types.
81+
82+
TorchX supports two ways to define named resources:
83+
84+
1. **Configuration-based**: Define resources in ``.torchxconfig`` files (recommended for most users)
85+
2. **Entry point-based**: Register resources via Python entry points (for package authors)
86+
87+
Configuration-Based Named Resources
88+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
89+
90+
For most users, the easiest way to define custom named resources is through configuration files.
91+
92+
Create a ``.torchxconfig`` file in your project directory:
93+
94+
.. code-block:: ini
95+
96+
[named_resources]
97+
dynamic = {"cpu": 100, "gpu": 8, "memMB": 819200, "devices": {"vpc.amazonaws.com/efa": 1}}
98+
my_custom = {"cpu": 32, "gpu": 4, "memMB": 131072}
99+
100+
You can also use the ``TORCHXCONFIG`` environment variable to specify a custom config file path.
101+
102+
Usage example:
103+
104+
.. code-block:: python
105+
106+
from torchx.specs import resource
107+
my_resource = resource(h="dynamic") # Uses your config-defined resource
108+
109+
Entry Point-Based Named Resources
110+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
111+
112+
For instance if your deep learning training kubernetes cluster on AWS is
82113
comprised only of p3.16xlarge (64 vcpu, 8 gpu, 488GB), then you may want to
83114
enumerate t-shirt sized resource specs for the containers as:
84115

torchx/specs/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858
GENERIC_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
5959
"torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
6060
)
61+
CONFIG_NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = import_attr(
62+
"torchx.specs.named_resources_config", "NAMED_RESOURCES", default={}
63+
)
6164

6265
GiB: int = 1024
6366

@@ -69,6 +72,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
6972
for name, resource in {
7073
**GENERIC_NAMED_RESOURCES,
7174
**AWS_NAMED_RESOURCES,
75+
**CONFIG_NAMED_RESOURCES,
7276
**resource_methods,
7377
}.items():
7478
materialized_resources[name] = resource
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
"""
10+
Configuration-based named resources that can be defined via .torchxconfig file.
11+
This allows users to define custom named resources with specific CPU, GPU, memory,
12+
and device requirements without hardcoding them.
13+
14+
Example .torchxconfig:
15+
[named_resources]
16+
dynamic = {"cpu": 100, "gpu": 8, "memMB": 819200, "devices": {"vpc.amazonaws.com/efa": 1}}
17+
my_custom = {"cpu": 32, "gpu": 4, "memMB": 131072}
18+
"""
19+
20+
import json
21+
import os
22+
from configparser import ConfigParser
23+
from typing import Callable, Dict, Mapping
24+
25+
from torchx.specs.api import Resource
26+
27+
28+
def _load_config_file() -> ConfigParser:
29+
"""Load the .torchxconfig file from TORCHXCONFIG env var or current directory."""
30+
config = ConfigParser()
31+
32+
# Check TORCHXCONFIG environment variable first, then current directory
33+
config_path = os.environ.get("TORCHXCONFIG", ".torchxconfig")
34+
35+
if os.path.exists(config_path):
36+
config.read(config_path)
37+
38+
return config
39+
40+
41+
def _parse_resource_config(config_str: str) -> Resource:
42+
"""Parse a resource configuration string into a Resource object."""
43+
try:
44+
config_dict = json.loads(config_str)
45+
except json.JSONDecodeError as e:
46+
raise ValueError(f"Invalid JSON in resource configuration: {e}")
47+
48+
# Extract standard resource parameters
49+
cpu = config_dict.get("cpu", 1)
50+
gpu = config_dict.get("gpu", 0)
51+
memMB = config_dict.get("memMB", 1024)
52+
53+
# Extract optional parameters
54+
capabilities = config_dict.get("capabilities", {})
55+
devices = config_dict.get("devices", {})
56+
57+
return Resource(
58+
cpu=cpu,
59+
gpu=gpu,
60+
memMB=memMB,
61+
capabilities=capabilities,
62+
devices=devices,
63+
)
64+
65+
66+
def _create_resource_factory(config_str: str) -> Callable[[], Resource]:
67+
"""Create a factory function for a resource configuration."""
68+
69+
def factory() -> Resource:
70+
return _parse_resource_config(config_str)
71+
72+
return factory
73+
74+
75+
def _load_named_resources_from_config() -> Dict[str, Callable[[], Resource]]:
76+
"""Load named resources from the configuration file."""
77+
config = _load_config_file()
78+
named_resources = {}
79+
80+
if config.has_section("named_resources"):
81+
for name, config_str in config.items("named_resources"):
82+
named_resources[name] = _create_resource_factory(config_str)
83+
84+
return named_resources
85+
86+
87+
# Load named resources from configuration
88+
NAMED_RESOURCES: Mapping[str, Callable[[], Resource]] = (
89+
_load_named_resources_from_config()
90+
)
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import json
10+
import unittest
11+
from configparser import ConfigParser
12+
from unittest.mock import patch
13+
14+
from torchx.specs.named_resources_config import (
15+
_create_resource_factory,
16+
_load_config_file,
17+
_load_named_resources_from_config,
18+
_parse_resource_config,
19+
NAMED_RESOURCES,
20+
)
21+
22+
23+
class ConfigNamedResourcesTest(unittest.TestCase):
24+
def test_parse_resource_config_basic(self) -> None:
25+
"""Test parsing basic resource configuration."""
26+
config_str = '{"cpu": 32, "gpu": 4, "memMB": 131072}'
27+
resource = _parse_resource_config(config_str)
28+
29+
self.assertEqual(resource.cpu, 32)
30+
self.assertEqual(resource.gpu, 4)
31+
self.assertEqual(resource.memMB, 131072)
32+
self.assertEqual(resource.capabilities, {})
33+
self.assertEqual(resource.devices, {})
34+
35+
def test_parse_resource_config_with_devices(self) -> None:
36+
"""Test parsing resource configuration with devices."""
37+
config_str = '{"cpu": 100, "gpu": 8, "memMB": 819200, "devices": {"vpc.amazonaws.com/efa": 1}}'
38+
resource = _parse_resource_config(config_str)
39+
40+
self.assertEqual(resource.cpu, 100)
41+
self.assertEqual(resource.gpu, 8)
42+
self.assertEqual(resource.memMB, 819200)
43+
self.assertEqual(resource.devices, {"vpc.amazonaws.com/efa": 1})
44+
45+
def test_parse_resource_config_with_capabilities(self) -> None:
46+
"""Test parsing resource configuration with capabilities."""
47+
config_str = '{"cpu": 64, "gpu": 0, "memMB": 262144, "capabilities": {"node.kubernetes.io/instance-type": "m5.16xlarge"}}'
48+
resource = _parse_resource_config(config_str)
49+
50+
self.assertEqual(resource.cpu, 64)
51+
self.assertEqual(resource.gpu, 0)
52+
self.assertEqual(resource.memMB, 262144)
53+
self.assertEqual(
54+
resource.capabilities, {"node.kubernetes.io/instance-type": "m5.16xlarge"}
55+
)
56+
57+
def test_parse_resource_config_defaults(self) -> None:
58+
"""Test parsing resource configuration with default values."""
59+
config_str = '{"cpu": 16, "memMB": 65536}'
60+
resource = _parse_resource_config(config_str)
61+
62+
self.assertEqual(resource.cpu, 16)
63+
self.assertEqual(resource.gpu, 0) # default
64+
self.assertEqual(resource.memMB, 65536)
65+
66+
def test_parse_resource_config_invalid_json(self) -> None:
67+
"""Test parsing invalid JSON configuration."""
68+
config_str = '{"cpu": 32, "gpu": 4, "memMB": 131072' # missing closing brace
69+
70+
with self.assertRaises(ValueError) as cm:
71+
_parse_resource_config(config_str)
72+
73+
self.assertIn("Invalid JSON", str(cm.exception))
74+
75+
def test_create_resource_factory(self) -> None:
76+
"""Test creating resource factory function."""
77+
config_str = '{"cpu": 8, "gpu": 1, "memMB": 32768}'
78+
factory = _create_resource_factory(config_str)
79+
80+
resource = factory()
81+
self.assertEqual(resource.cpu, 8)
82+
self.assertEqual(resource.gpu, 1)
83+
self.assertEqual(resource.memMB, 32768)
84+
85+
def test_load_config_file_not_found(self) -> None:
86+
"""Test loading config file when none exists."""
87+
with patch("os.path.exists", return_value=False):
88+
config = _load_config_file()
89+
self.assertFalse(config.sections())
90+
91+
def test_load_config_file_current_directory(self) -> None:
92+
"""Test loading config file from current directory."""
93+
with patch.dict("os.environ", {}, clear=True): # Clear TORCHXCONFIG
94+
with patch(
95+
"torchx.specs.named_resources_config.os.path.exists", return_value=True
96+
) as mock_exists:
97+
with patch("configparser.ConfigParser.read") as mock_read:
98+
_load_config_file()
99+
100+
# Verify the method was called with current directory path
101+
mock_exists.assert_called_with(".torchxconfig")
102+
mock_read.assert_called_with(".torchxconfig")
103+
104+
def test_load_config_file_with_torchxconfig_env(self) -> None:
105+
"""Test loading config file from TORCHXCONFIG environment variable."""
106+
temp_config_path = "/tmp/custom_torchx_config"
107+
108+
with patch.dict("os.environ", {"TORCHXCONFIG": temp_config_path}):
109+
with patch(
110+
"torchx.specs.named_resources_config.os.path.exists", return_value=True
111+
):
112+
with patch("configparser.ConfigParser.read") as mock_read:
113+
_load_config_file()
114+
115+
# Verify the method was called with the env var path
116+
mock_read.assert_called_with(temp_config_path)
117+
118+
def test_load_named_resources_from_config_empty(self) -> None:
119+
"""Test loading named resources when no config section exists."""
120+
with patch(
121+
"torchx.specs.named_resources_config._load_config_file"
122+
) as mock_load:
123+
mock_config = ConfigParser()
124+
mock_load.return_value = mock_config
125+
126+
resources = _load_named_resources_from_config()
127+
self.assertEqual(resources, {})
128+
129+
def test_load_named_resources_from_config_with_resources(self) -> None:
130+
"""Test loading named resources from config with valid resources."""
131+
with patch(
132+
"torchx.specs.named_resources_config._load_config_file"
133+
) as mock_load:
134+
mock_config = ConfigParser()
135+
mock_config.add_section("named_resources")
136+
mock_config.set(
137+
"named_resources",
138+
"test_resource",
139+
json.dumps({"cpu": 32, "gpu": 4, "memMB": 131072}),
140+
)
141+
mock_config.set(
142+
"named_resources",
143+
"gpu_resource",
144+
json.dumps(
145+
{
146+
"cpu": 64,
147+
"gpu": 8,
148+
"memMB": 262144,
149+
"devices": {"vpc.amazonaws.com/efa": 2},
150+
}
151+
),
152+
)
153+
mock_load.return_value = mock_config
154+
155+
resources = _load_named_resources_from_config()
156+
157+
self.assertIn("test_resource", resources)
158+
self.assertIn("gpu_resource", resources)
159+
160+
# Test the factory functions
161+
test_res = resources["test_resource"]()
162+
self.assertEqual(test_res.cpu, 32)
163+
self.assertEqual(test_res.gpu, 4)
164+
self.assertEqual(test_res.memMB, 131072)
165+
166+
gpu_res = resources["gpu_resource"]()
167+
self.assertEqual(gpu_res.cpu, 64)
168+
self.assertEqual(gpu_res.gpu, 8)
169+
self.assertEqual(gpu_res.memMB, 262144)
170+
self.assertEqual(gpu_res.devices, {"vpc.amazonaws.com/efa": 2})
171+
172+
def test_load_named_resources_from_config_invalid_json(self) -> None:
173+
"""Test loading named resources with invalid JSON (should fail when factory is called)."""
174+
with patch(
175+
"torchx.specs.named_resources_config._load_config_file"
176+
) as mock_load:
177+
mock_config = ConfigParser()
178+
mock_config.add_section("named_resources")
179+
mock_config.set(
180+
"named_resources",
181+
"valid_resource",
182+
json.dumps({"cpu": 32, "gpu": 4, "memMB": 131072}),
183+
)
184+
mock_config.set(
185+
"named_resources",
186+
"invalid_resource",
187+
'{"cpu": 32, "gpu": 4, "memMB": 131072',
188+
) # invalid JSON
189+
mock_load.return_value = mock_config
190+
191+
resources = _load_named_resources_from_config()
192+
193+
# Should have both resources (validation happens when factory is called)
194+
self.assertIn("valid_resource", resources)
195+
self.assertIn("invalid_resource", resources)
196+
197+
# Valid resource should work
198+
valid_res = resources["valid_resource"]()
199+
self.assertEqual(valid_res.cpu, 32)
200+
201+
# Invalid resource should raise exception when called
202+
with self.assertRaises(ValueError):
203+
resources["invalid_resource"]()
204+
205+
def test_named_resources_module_level(self) -> None:
206+
"""Test that NAMED_RESOURCES is properly loaded at module level."""
207+
# This tests the actual module-level NAMED_RESOURCES
208+
# The exact content depends on the actual .torchxconfig file present
209+
self.assertIsInstance(NAMED_RESOURCES, dict)
210+
211+
# Test that all values are callable factory functions
212+
for name, factory in NAMED_RESOURCES.items():
213+
self.assertTrue(callable(factory))
214+
# Test that calling the factory returns a Resource
215+
resource = factory()
216+
self.assertTrue(hasattr(resource, "cpu"))
217+
self.assertTrue(hasattr(resource, "gpu"))
218+
self.assertTrue(hasattr(resource, "memMB"))

0 commit comments

Comments
 (0)