@@ -562,11 +562,16 @@ def directsum(reprs: List[escnn.group.Representation],
562562 irreps += r .irreps
563563
564564 size = sum ([r .size for r in reprs ])
565-
566- cob = np .zeros ((size , size ))
567- cob_inv = np .zeros ((size , size ))
565+
566+ # Determine the dtype for the change of basis diagonal matrix to avoid unsafe casting.
567+ dtype = np .complex if np .any ([np .iscomplexobj (rep .change_of_basis ) for rep in reprs ]) else np .float
568+
569+ cob = np .zeros ((size , size ), dtype = dtype )
570+ cob_inv = np .zeros ((size , size ), dtype = dtype )
568571 p = 0
569572 for r in reprs :
573+ assert np .can_cast (r .change_of_basis .dtype , cob .dtype ), \
574+ f"Cannot safely cast { r .change_of_basis .dtype } to { cob .dtype } "
570575 cob [p :p + r .size , p :p + r .size ] = r .change_of_basis
571576 cob_inv [p :p + r .size , p :p + r .size ] = r .change_of_basis_inv
572577 p += r .size
@@ -580,7 +585,9 @@ def directsum(reprs: List[escnn.group.Representation],
580585
581586 supported_nonlinearities = set .intersection (* [r .supported_nonlinearities for r in reprs ])
582587
583- return Representation (group , name , irreps , change_of_basis , supported_nonlinearities , change_of_basis_inv = change_of_basis_inv )
588+ return Representation (
589+ group , name , irreps , change_of_basis , supported_nonlinearities , change_of_basis_inv = change_of_basis_inv
590+ )
584591
585592
586593def disentangle (repr : Representation ) -> Tuple [np .ndarray , List [Representation ]]:
0 commit comments