Skip to content

Commit 9831efc

Browse files
committed
Handle dtypes during direcsum
1 parent 7340d57 commit 9831efc

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

escnn/group/representation.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,11 +562,16 @@ def directsum(reprs: List[escnn.group.Representation],
562562
irreps += r.irreps
563563

564564
size = sum([r.size for r in reprs])
565-
566-
cob = np.zeros((size, size))
567-
cob_inv = np.zeros((size, size))
565+
566+
# Determine the dtype for the change of basis diagonal matrix to avoid unsafe casting.
567+
dtype = np.complex if np.any([np.iscomplexobj(rep.change_of_basis) for rep in reprs]) else np.float
568+
569+
cob = np.zeros((size, size), dtype=dtype)
570+
cob_inv = np.zeros((size, size), dtype=dtype)
568571
p = 0
569572
for r in reprs:
573+
assert np.can_cast(r.change_of_basis.dtype, cob.dtype), \
574+
f"Cannot safely cast {r.change_of_basis.dtype} to {cob.dtype}"
570575
cob[p:p + r.size, p:p + r.size] = r.change_of_basis
571576
cob_inv[p:p + r.size, p:p + r.size] = r.change_of_basis_inv
572577
p += r.size
@@ -580,7 +585,9 @@ def directsum(reprs: List[escnn.group.Representation],
580585

581586
supported_nonlinearities = set.intersection(*[r.supported_nonlinearities for r in reprs])
582587

583-
return Representation(group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv)
588+
return Representation(
589+
group, name, irreps, change_of_basis, supported_nonlinearities, change_of_basis_inv=change_of_basis_inv
590+
)
584591

585592

586593
def disentangle(repr: Representation) -> Tuple[np.ndarray, List[Representation]]:

0 commit comments

Comments
 (0)