Skip to content

Commit 1c137ff

Browse files
authored
feat(utils): add to_bool for consistent boolean conversion (#1259)
* Introduces a `to_bool` utility function that converts string representations ("true", "false") to boolean values. * This function is now used by `is_async`, `skip_bundle`, and `allow_dirty` to ensure consistent boolean parsing from string inputs.
1 parent fafdf23 commit 1c137ff

File tree

4 files changed

+43
-5
lines changed

4 files changed

+43
-5
lines changed

axlearn/cloud/common/bundler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
copy_blobs,
6363
get_pyproject_version,
6464
parse_kv_flags,
65+
to_bool,
6566
)
6667
from axlearn.common.config import REQUIRED, Configurable, Required, config_class
6768
from axlearn.common.file_system import copy, exists, makedirs
@@ -341,9 +342,17 @@ def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> Config
341342
cfg: BaseDockerBundler.Config = super().from_spec(spec, fv=fv)
342343
kwargs = parse_kv_flags(spec, delimiter="=")
343344
cache_from = canonicalize_to_list(kwargs.pop("cache_from", None))
345+
skip_bundle = to_bool(kwargs.pop("skip_bundle", False))
346+
allow_dirty = to_bool(kwargs.pop("allow_dirty", False))
344347
# Non-config specs are treated as build args.
345348
build_args = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k not in cfg}
346-
return cfg.set(build_args=build_args, cache_from=cache_from, **kwargs)
349+
return cfg.set(
350+
build_args=build_args,
351+
cache_from=cache_from,
352+
skip_bundle=skip_bundle,
353+
allow_dirty=allow_dirty,
354+
**kwargs,
355+
)
347356

348357
# pylint: disable-next=arguments-renamed
349358
def id(self, tag: str) -> str:

axlearn/cloud/common/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,19 @@ def merge(base: dict, overrides: dict):
292292
return base
293293

294294

295+
def to_bool(value: Any) -> bool:
296+
"""Converts a string representation of truth to a bool."""
297+
if isinstance(value, bool):
298+
return value
299+
elif isinstance(value, str):
300+
val_lower = value.lower()
301+
if val_lower == "true":
302+
return True
303+
elif val_lower == "false":
304+
return False
305+
raise ValueError(f"Invalid truth value: '{value}'")
306+
307+
295308
_Row = list[Any]
296309

297310

axlearn/cloud/common/utils_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,24 @@ def test_canonicalize(self, v_seq: Sequence[str], v_str: str, v_list: str, delim
222222
def test_merge(self, base, overrides, expected):
223223
self.assertEqual(expected, utils.merge(base, overrides))
224224

225+
@parameterized.parameters(
226+
("true", True),
227+
("True", True),
228+
("false", False),
229+
("False", False),
230+
(True, True),
231+
(False, False),
232+
("yes", ValueError),
233+
(1, ValueError),
234+
)
235+
def test_to_bool(self, value, expected):
236+
if isinstance(expected, type) and issubclass(expected, Exception):
237+
with self.assertRaises(expected):
238+
utils.to_bool(value)
239+
else:
240+
result = utils.to_bool(value)
241+
self.assertEqual(result, expected)
242+
225243
def test_infer_resources(self):
226244
@config_class
227245
class DummyConfig(ConfigBase):

axlearn/cloud/gcp/bundler.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from axlearn.cloud.common.bundler import main_flags as bundler_main_flags
5959
from axlearn.cloud.common.bundler import register_bundler
6060
from axlearn.cloud.common.docker import registry_from_repo
61-
from axlearn.cloud.common.utils import canonicalize_to_list
61+
from axlearn.cloud.common.utils import canonicalize_to_list, to_bool
6262
from axlearn.cloud.gcp.cloud_build import get_cloud_build_status
6363
from axlearn.cloud.gcp.config import gcp_settings
6464
from axlearn.cloud.gcp.utils import common_flags
@@ -148,9 +148,7 @@ def from_spec(
148148
cfg.project = cfg.project or gcp_settings("project", required=False, fv=fv)
149149
cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
150150
cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv)
151-
# The value from from_spec is a str and will result in wrong condition.
152-
if isinstance(cfg.is_async, str):
153-
cfg.is_async = cfg.is_async.lower() != "false"
151+
cfg.is_async = to_bool(cfg.is_async)
154152
return cfg
155153

156154
# pylint: disable-next=no-self-use,unused-argument

0 commit comments

Comments
 (0)