Skip to content

Commit 9901d44

Browse files
KarhouTam and pytorchmergebot
authored and committed
[torch/utils][Code Clean] Clean asserts in torch/utils/*.py (pytorch#165410)
Including: - `torch/utils/*.py` Fixes part of pytorch#164878 Pull Request resolved: pytorch#165410 Approved by: https://github.com/albanD, https://github.com/cyyever
1 parent 6096c0f commit 9901d44

21 files changed

+273
-136
lines changed

torch/utils/_appending_byte_serializer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,22 @@ def to_bytes(self) -> bytes:
3838
digest = zlib.crc32(self._data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
3939
4, byteorder="big", signed=False
4040
)
41-
assert len(digest) == CHECKSUM_DIGEST_SIZE
41+
if len(digest) != CHECKSUM_DIGEST_SIZE:
42+
raise AssertionError("Computed checksum digest has unexpected size")
4243
self._data[0:CHECKSUM_DIGEST_SIZE] = digest
4344
return bytes(self._data)
4445

4546

4647
class BytesReader:
4748
def __init__(self, data: bytes) -> None:
4849
# Check for data corruption
49-
assert len(data) >= CHECKSUM_DIGEST_SIZE
50+
if len(data) < CHECKSUM_DIGEST_SIZE:
51+
raise AssertionError("Input data is too short to contain checksum")
5052
digest = zlib.crc32(data[CHECKSUM_DIGEST_SIZE:]).to_bytes(
5153
4, byteorder="big", signed=False
5254
)
53-
assert len(digest) == CHECKSUM_DIGEST_SIZE
55+
if len(digest) != CHECKSUM_DIGEST_SIZE:
56+
raise AssertionError("Computed checksum digest has unexpected size")
5457
if data[0:CHECKSUM_DIGEST_SIZE] != digest:
5558
raise RuntimeError(
5659
"Bytes object is corrupted, checksum does not match. "
@@ -120,7 +123,11 @@ def to_bytes(self) -> bytes:
120123
@staticmethod
121124
def to_list(data: bytes, *, deserialize_fn: Callable[[BytesReader], T]) -> list[T]:
122125
reader = BytesReader(data)
123-
assert reader.read_uint64() == _ENCODING_VERSION
126+
if reader.read_uint64() != _ENCODING_VERSION:
127+
raise AssertionError(
128+
f"Encoding version mismatch in AppendingByteSerializer.to_list, \
129+
got {reader.read_uint64()}"
130+
)
124131

125132
result: list[T] = []
126133
while not reader.is_finished():

torch/utils/_config_module.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,16 @@ def __post_init__(self) -> None:
8585
)
8686

8787
if self.alias is not None:
88-
assert (
89-
self.default is _UNSET_SENTINEL
90-
and self.justknob is None
91-
and self.env_name_default is None
92-
and self.env_name_force is None
93-
), "if alias is set, none of {default, justknob and env var} can be set"
88+
if (
89+
self.default is not _UNSET_SENTINEL
90+
or self.justknob is not None
91+
or self.env_name_default is not None
92+
or self.env_name_force is not None
93+
):
94+
raise AssertionError(
95+
"if alias is set, none of {default, justknob, \
96+
env_name_default and env_name_force} can be set"
97+
)
9498

9599
@staticmethod
96100
def string_or_list_of_string_to_list(
@@ -100,7 +104,8 @@ def string_or_list_of_string_to_list(
100104
return None
101105
if isinstance(val, str):
102106
return [val]
103-
assert isinstance(val, list)
107+
if not isinstance(val, list):
108+
raise AssertionError(f"val is not a list, got {type(val)}")
104109
return val
105110

106111

@@ -193,7 +198,10 @@ def visit(
193198
if dest is module:
194199
delattr(module, key)
195200
elif isinstance(value, type):
196-
assert value.__module__ == module.__name__
201+
if value.__module__ != module.__name__:
202+
raise AssertionError(
203+
f"subconfig class {value} must be defined in module {module.__name__}"
204+
)
197205
# a subconfig with `class Blah:` syntax
198206
proxy = SubConfigProxy(module, f"{name}.")
199207
visit(value, proxy, f"{name}.")
@@ -234,10 +242,8 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
234242
prev_name = ""
235243
maybe_current = token.string.strip()
236244
if COMPILE_IGNORED_MARKER in maybe_current:
237-
assert current_comment == (
238-
"",
239-
-1,
240-
), f"unconsumed {COMPILE_IGNORED_MARKER}"
245+
if current_comment != ("", -1):
246+
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
241247
current_comment = maybe_current, token.start[0]
242248
elif token.type == tokenize.NAME:
243249
# Only accept the first name token, to handle if you have
@@ -254,7 +260,8 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> set[str
254260
assignments.add(prev_name)
255261
current_comment = "", -1 # reset
256262
prev_name = ""
257-
assert current_comment == ("", -1), f"unconsumed {COMPILE_IGNORED_MARKER}"
263+
if current_comment != ("", -1):
264+
raise AssertionError(f"unconsumed {COMPILE_IGNORED_MARKER}")
258265
return assignments
259266

260267

@@ -306,20 +313,22 @@ def __init__(self, config: _Config):
306313

307314
# Ensure justknobs and envvars are allowlisted types
308315
if self.justknob is not None and self.default is not None:
309-
assert isinstance(self.default, bool), (
310-
f"justknobs only support booleans, {self.default} is not a boolean"
311-
)
316+
if not isinstance(self.default, bool):
317+
raise AssertionError(
318+
f"justknobs only support booleans, {self.default} is not a boolean"
319+
)
312320
if self.value_type is not None and (
313321
config.env_name_default is not None or config.env_name_force is not None
314322
):
315-
assert self.value_type in (
323+
if self.value_type not in (
316324
bool,
317325
str,
318326
Optional[bool],
319327
Optional[str],
320-
), (
321-
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
322-
)
328+
):
329+
raise AssertionError(
330+
f"envvar configs only support (optional) booleans or strings, {self.value_type} is neither"
331+
)
323332

324333

325334
class ConfigModule(ModuleType):
@@ -417,7 +426,10 @@ def _get_alias_val(self, entry: _ConfigEntry) -> Any:
417426

418427
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
419428
data = self._get_alias_module_and_name(entry)
420-
assert data is not None
429+
if data is None:
430+
raise AssertionError(
431+
"alias data should not be None when setting alias value"
432+
)
421433
module, constant_name = data
422434
setattr(module, constant_name, val)
423435

@@ -642,19 +654,32 @@ def foo(...):
642654
changes: dict[str, Any]
643655
if arg1 is not None:
644656
if arg2 is not None:
645-
assert isinstance(arg1, str)
657+
if not isinstance(arg1, str):
658+
raise AssertionError(
659+
"first argument must be a string when passing 2 positional args to patch"
660+
)
646661
# patch("key", True) syntax
647662
changes = {arg1: arg2}
648663
else:
649-
assert isinstance(arg1, dict)
664+
if not isinstance(arg1, dict):
665+
raise AssertionError(
666+
"first argument must be a dict when passing a single positional arg to patch"
667+
)
650668
# patch({"key": True}) syntax
651669
changes = arg1
652-
assert not kwargs
670+
if kwargs:
671+
raise AssertionError(
672+
"cannot pass both positional and keyword arguments to patch"
673+
)
653674
else:
654675
# patch(key=True) syntax
655676
changes = kwargs
656-
assert arg2 is None
657-
assert isinstance(changes, dict), f"expected `dict` got {type(changes)}"
677+
if arg2 is not None:
678+
raise AssertionError(
679+
"second positional argument is only valid when first argument is a key string"
680+
)
681+
if not isinstance(changes, dict):
682+
raise AssertionError(f"expected `dict` got {type(changes)}")
658683
prior: dict[str, Any] = {}
659684
config = self
660685

@@ -663,7 +688,10 @@ def __init__(self) -> None:
663688
self.changes = changes
664689

665690
def __enter__(self) -> None:
666-
assert not prior
691+
if prior:
692+
raise AssertionError(
693+
"prior should be empty when entering ConfigPatch"
694+
)
667695
for key in self.changes.keys():
668696
# KeyError on invalid entry
669697
prior[key] = config.__getattr__(key)

torch/utils/_config_typing.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ This file should be imported into any file that uses install_config_module like
2121
Note that the import should happen before the call to install_config_module(), otherwise runtime errors may occur.
2222
"""
2323

24-
assert TYPE_CHECKING, "Do not use at runtime"
24+
if not TYPE_CHECKING: # noqa: PYI002
25+
raise AssertionError("Do not use at runtime") # noqa: W291
2526

2627
def save_config() -> bytes: ...
2728
def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ...

torch/utils/_content_store.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,10 @@ def read_storage(self, h: str, *, device=None) -> torch.UntypedStorage:
217217
weights_only=True,
218218
map_location=device,
219219
)._untyped_storage
220-
assert s is not None
220+
if s is None:
221+
raise AssertionError(
222+
f"expected storage for hash {h} in {os.path.join(self.loc, 'storages')}, got None"
223+
)
221224
if self.storage_cache is not None:
222225
self.storage_cache[device][h] = StorageWeakRef(s)
223226
return s

torch/utils/_contextlib.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,14 @@ def context_decorator(ctx, func):
8686
be a multi-shot context manager that can be directly invoked multiple times)
8787
or a callable that produces a context manager.
8888
"""
89-
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
90-
f"Passed in {ctx} is both callable and also a valid context manager "
91-
"(has __enter__), making it ambiguous which interface to use. If you "
92-
"intended to pass a context manager factory, rewrite your call as "
93-
"context_decorator(lambda: ctx()); if you intended to pass a context "
94-
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
95-
)
89+
if callable(ctx) and hasattr(ctx, "__enter__"):
90+
raise AssertionError(
91+
f"Passed in {ctx} is both callable and also a valid context manager "
92+
"(has __enter__), making it ambiguous which interface to use. If you "
93+
"intended to pass a context manager factory, rewrite your call as "
94+
"context_decorator(lambda: ctx()); if you intended to pass a context "
95+
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
96+
)
9697

9798
if not callable(ctx):
9899

torch/utils/_cxx_pytree.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,10 @@ def _broadcast_to_and_flatten(
931931
treespec: TreeSpec,
932932
is_leaf: Optional[Callable[[PyTree], bool]] = None,
933933
) -> Optional[list[Any]]:
934-
assert _is_pytreespec_instance(treespec)
934+
if not _is_pytreespec_instance(treespec):
935+
raise AssertionError(
936+
f"_broadcast_to_and_flatten: Expected `treespec` to be instance of PyTreeSpec but got {type(treespec)}"
937+
)
935938
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
936939
try:
937940
return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)

torch/utils/_device.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,18 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8787
# or else someone else has popped it!
8888
for _ in range(_len_torch_function_stack() - 1):
8989
mode = _pop_mode()
90-
assert not isinstance(mode, DeviceContext)
90+
if isinstance(mode, DeviceContext):
91+
raise AssertionError(
92+
"Found nested DeviceContext on the mode stack where none expected"
93+
)
9194
cur_stack.append(mode)
9295

9396
if _len_torch_function_stack() > 0:
9497
mode = _pop_mode()
95-
assert isinstance(mode, DeviceContext)
98+
if not isinstance(mode, DeviceContext):
99+
raise AssertionError(
100+
"Expected a DeviceContext at the bottom of the mode stack"
101+
)
96102

97103
for mode in reversed(cur_stack):
98104
_push_mode(mode)

torch/utils/_functools.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def cache_method(
3131

3232
@functools.wraps(f)
3333
def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T:
34-
assert not kwargs
34+
if kwargs:
35+
raise AssertionError("cache_method does not accept keyword arguments")
3536
if not (cache := getattr(self, cache_name, None)):
3637
cache = {}
3738
setattr(self, cache_name, cache)

0 commit comments

Comments (0)