@@ -317,25 +317,34 @@ def split_multiple_shapes_into_blocks(
317317 strides : Optional [PerMember [PerAxis [int ]]] = None ,
318318 broadcast : bool = False ,
319319) -> Tuple [TotalNumberOfBlocks , Iterable [PerMember [BlockMeta ]]]:
320- assert not (
321- missing := [t for t in block_shapes if t not in shapes ]
322- ), f"block shape specified for unknown tensors: { missing } "
320+ if unknown_blocks := [t for t in block_shapes if t not in shapes ]:
321+ raise ValueError (
322+ f"block shape specified for unknown tensors: { unknown_blocks } ."
323+ )
324+
323325 if not block_shapes :
324326 block_shapes = shapes
325327
326- assert broadcast or not (
327- missing := [t for t in shapes if t not in block_shapes ]
328- ), f"no block shape specified for { missing } (set `broadcast` to True if these tensors should be repeated for each block)"
329- assert not (
330- missing := [t for t in halo if t not in block_shapes ]
331- ), f"`halo` specified for tensors without block shape: { missing } "
328+ if not broadcast and (
329+ missing_blocks := [t for t in shapes if t not in block_shapes ]
330+ ):
331+ raise ValueError (
332+ f"no block shape specified for { missing_blocks } ."
333+ + " Set `broadcast` to True if these tensors should be repeated"
334+ + " as a whole for each block."
335+ )
336+
337+ if extra_halo := [t for t in halo if t not in block_shapes ]:
338+ raise ValueError (
339+ f"`halo` specified for tensors without block shape: { extra_halo } ."
340+ )
332341
333342 if strides is None :
334343 strides = {}
335344
336345 assert not (
337- missing := [t for t in strides if t not in block_shapes ]
338- ), f"`stride` specified for tensors without block shape: { missing } "
346+ unknown_block := [t for t in strides if t not in block_shapes ]
347+ ), f"`stride` specified for tensors without block shape: { unknown_block } "
339348
340349 blocks : Dict [MemberId , Iterable [BlockMeta ]] = {}
341350 n_blocks : Dict [MemberId , TotalNumberOfBlocks ] = {}
@@ -355,8 +364,9 @@ def split_multiple_shapes_into_blocks(
355364 if len (unique_n_blocks ) == 2 and 1 in unique_n_blocks :
356365 if not broadcast :
357366 raise ValueError (
358- f"Mismatch for total number of blocks due to unsplit (single block) tensors: { n_blocks } ."
359- + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors."
367+ "Mismatch for total number of blocks due to unsplit (single block)"
368+ + f" tensors: { n_blocks } . Set `broadcast` to True if you want to"
369+ + " repeat unsplit (single block) tensors."
360370 )
361371
362372 blocks = {
0 commit comments