@@ -9,7 +9,6 @@ def float_mtx(n_obs, n_vars, NAs=False):
99 mtx [0 , 0 ] = np .nan
1010 return mtx
1111
12-
1312def int_mtx (n_obs , n_vars ):
1413 mtx = np .arange (n_obs * n_vars ).reshape (n_obs , n_vars )
1514 return mtx
@@ -21,6 +20,20 @@ def float_mtx_nd(nr_values, dimensions, NAs=False):
2120 mtx [0 , 0 ] = np .nan
2221 return mtx
2322
23+ def int_mtx_nd (nr_values , dimensions , NAs = False ):
24+ mtx = np .arange (nr_values ).reshape (dimensions )
25+ if NAs : # numpy matrices do no support pd.NA
26+ mtx [0 , 0 ] = np .nan
27+ return mtx
28+
29+ def float_mtx_sparse_nd (nr_values , dimensions , row_major = True , NAs = False ):
30+ mtx = float_mtx_nd (nr_values , dimensions , NAs )
31+ if row_major :
32+ return sp .sparse .csr_matrix (mtx )
33+ else :
34+ return sp .sparse .csc_matrix (mtx )
35+
36+
2437# Possible matrix generators
2538# integer matrices do not support NAs in Python
2639matrix_generators = {
@@ -34,6 +47,20 @@ def float_mtx_nd(nr_values, dimensions, NAs=False):
3447 "integer_csparse" : lambda n_obs , n_vars : sp .sparse .csc_matrix (int_mtx (n_obs , n_vars )),
3548 "integer_rsparse" : lambda n_obs , n_vars : sp .sparse .csr_matrix (int_mtx (n_obs , n_vars )),
3649 "float_matrix_3d" : lambda n_obs , n_vars : float_mtx_nd (n_obs * n_vars * 3 , (n_obs , n_vars , 3 )),
50+ "integer_matrix_3d" : lambda n_obs , n_vars : int_mtx_nd (n_obs * n_vars * 3 , (n_obs , n_vars , 3 )),
51+ }
52+
53+ def string_matrix_nd (nr_values , dimensions ):
54+ return np .array (['a' for _ in range (nr_values )]).reshape (dimensions )
55+
56+ def bool_matrix_nd (nr_values , dimensions ):
57+ return np .array ([True if i % 2 else False for i in range (nr_values )]).reshape (dimensions )
58+
59+ extra_uns_matrix_generators = {
60+ "string_matrix" : lambda n_obs , n_vars : np .array (['a' for _ in range (n_obs * n_vars )]).reshape (n_obs , n_vars ),
61+ "bool_matrix" : lambda n_obs , n_vars : np .array ([True for _ in range (n_obs * n_vars )]).reshape (n_obs , n_vars ),
62+ "string_matrix_3d" : lambda n_obs , n_vars : string_matrix_nd (n_obs * n_vars * 3 , (n_obs , n_vars , 3 )),
63+ "bool_matrix_3d" : lambda n_obs , n_vars : bool_matrix_nd (n_obs * n_vars * 3 , (n_obs , n_vars , 3 )),
3764}
3865
3966generated_matrix_types = np .ndarray | sp .sparse .csc_matrix | sp .sparse .csr_matrix
@@ -58,6 +85,10 @@ def generate_matrix(n_obs: int, n_vars: int, matrix_type: str) -> generated_matr
5885 AssertionError: If the matrix_type is unknown.
5986
6087 """
61- assert matrix_type in matrix_generators , f"Unknown matrix type: { matrix_type } "
88+ assert matrix_type in matrix_generators .keys () or matrix_type in extra_uns_matrix_generators .keys (), f"Unknown matrix type: { matrix_type } "
89+
90+ if matrix_type in matrix_generators :
91+ return matrix_generators [matrix_type ](n_obs , n_vars )
92+ else :
93+ return extra_uns_matrix_generators [matrix_type ](n_obs , n_vars )
6294
63- return matrix_generators [matrix_type ](n_obs , n_vars )
0 commit comments