@@ -74,6 +74,7 @@ def generate_flm(
7474 spin : int = 0 ,
7575 reality : bool = False ,
7676 using_torch : bool = False ,
77+ size : tuple [int , ...] | int | None = None ,
7778) -> np .ndarray | torch .Tensor :
7879 r"""
7980 Generate a 2D set of random harmonic coefficients.
@@ -94,29 +95,39 @@ def generate_flm(
9495
9596 using_torch (bool, optional): Desired frontend functionality. Defaults to False.
9697
98+ size (tuple[int, ...] | int | None, optional): Shape of realisations.
99+
97100 Returns:
98101 np.ndarray: Random set of spherical harmonic coefficients.
99102
100103 """
101- flm = np .zeros (samples .flm_shape (L ), dtype = np .complex128 )
104+ # always turn size into a tuple of int
105+ if size is None :
106+ size = ()
107+ elif isinstance (size , int ):
108+ size = (size ,)
109+ elif not (isinstance (size , tuple ) and all (isinstance (_ , int ) for _ in size )):
110+ raise TypeError ("size must be int or tuple of int" )
111+
112+ flm = np .zeros ((* size , * samples .flm_shape (L )), dtype = np .complex128 )
102113 min_el = max (L_lower , abs (spin ))
103114 # m = 0 coefficients are always real
104- flm [min_el :L , L - 1 ] = rng .standard_normal (L - min_el )
115+ flm [..., min_el :L , L - 1 ] = rng .standard_normal (( * size , L - min_el ) )
105116 # Construct arrays of m and el indices for entries in flm corresponding to complex-
106117 # valued coefficients (m > 0)
107118 el_indices , m_indices = complex_el_and_m_indices (L , min_el )
108- len_indices = len (m_indices )
119+ rand_size = ( * size , len (m_indices ) )
109120 # Generate independent complex coefficients for positive m
110- flm [el_indices , L - 1 + m_indices ] = complex_normal (rng , len_indices , var = 2 )
121+ flm [..., el_indices , L - 1 + m_indices ] = complex_normal (rng , rand_size , var = 2 )
111122 if reality :
112123 # Real-valued signal so set complex coefficients for negative m using conjugate
113124 # symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
114- flm [el_indices , L - 1 - m_indices ] = (- 1 ) ** m_indices * (
115- flm [el_indices , L - 1 + m_indices ].conj ()
125+ flm [..., el_indices , L - 1 - m_indices ] = (- 1 ) ** m_indices * (
126+ flm [..., el_indices , L - 1 + m_indices ].conj ()
116127 )
117128 else :
118129 # Non-real signal so generate independent complex coefficients for negative m
119- flm [el_indices , L - 1 - m_indices ] = complex_normal (rng , len_indices , var = 2 )
130+ flm [..., el_indices , L - 1 - m_indices ] = complex_normal (rng , rand_size , var = 2 )
120131 return torch .from_numpy (flm ) if using_torch else flm
121132
122133
0 commit comments