Skip to content

Commit bae6ae7

Browse files
committed
add 1D array support for independent data generator
1 parent 3ac236e commit bae6ae7

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

synthia/generators/copula.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _log(self, msg: str) -> None:
7272
print(msg, flush=True)
7373

7474
def fit(self, data: Union[np.ndarray, xr.DataArray, xr.Dataset],
75-
copula: Copula, qrng=False,
75+
copula: Copula,
7676
parameterize_by: Optional[Union[Parameterizer, Dict[int, Parameterizer], Dict[str, Parameterizer]]]=None):
7777
"""tbd
7878

synthia/generators/independent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def fit(self, data: Union[np.ndarray, xr.DataArray, xr.Dataset],
1616
1717
Args:
1818
data (ndarray or DataArray or Dataset): The input data, either a
19-
2D array of shape (sample, feature) or a dataset where all
20-
variables have the shape (sample[, ...]).
19+
1D array, a 2D array of shape (sample, feature)
20+
or a dataset where all variables have the shape (sample[, ...]).
2121
2222
parameterize_by (Parameterizer or mapping, optional): The
2323
following forms are valid:
@@ -30,7 +30,7 @@ def fit(self, data: Union[np.ndarray, xr.DataArray, xr.Dataset],
3030
None
3131
"""
3232

33-
data, self.data_info = to_feature_array(data)
33+
data, self.data_info = to_feature_array(data, allow_1d=True)
3434

3535
self.dtype = data.dtype
3636
self.n_features = data.shape[1]

synthia/util.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def to_unstacked_dataset(arr: np.ndarray, stack_info: StackInfo) -> xr.Dataset:
9595
ds = xr.Dataset(unstacked)
9696
return ds
9797

98-
def to_feature_array(data: Union[np.ndarray, xr.DataArray, xr.Dataset]) -> Tuple[xr.DataArray, dict]:
98+
def to_feature_array(data: Union[np.ndarray, xr.DataArray, xr.Dataset], allow_1d=False) -> Tuple[xr.DataArray, dict]:
9999
# TODO what about dtype?
100100
data_info = {}
101101
if isinstance(data, xr.Dataset):
@@ -111,14 +111,21 @@ def to_feature_array(data: Union[np.ndarray, xr.DataArray, xr.Dataset]) -> Tuple
111111
attrs=data.attrs
112112
)
113113
data = xr.DataArray(data)
114-
assert data.ndim == 2, f'Input array must be 2D, given: {data.ndim}'
114+
if allow_1d and data.ndim == 1:
115+
data_info['is_1d'] = True
116+
data = data.expand_dims(dim='__feature', axis=1)
117+
assert data.ndim == 2, f'Input array must be {"1D/" if allow_1d else ""}2D, given: {data.ndim}D'
115118
data_info['n_features'] = data.shape[1]
116119
return data, data_info
117120

118121
def from_feature_array(data: np.ndarray, data_info: dict) -> Union[np.ndarray, xr.DataArray, xr.Dataset]:
119122
stack_info = data_info.get('stack_info')
120123
if stack_info:
121124
return to_unstacked_dataset(data, stack_info)
125+
is_1d = data_info.get('is_1d')
126+
if is_1d:
127+
assert data.shape[1] == 1
128+
data = data[:,0]
122129
da_info = data_info.get('da_info')
123130
if da_info:
124131
return xr.DataArray(data, **da_info)

tests/test_generators.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def test_independent_dataset_generation():
3535
assert synthetic_data['a'].shape == (n_synthetic_samples, n_features[0])
3636
assert synthetic_data['b'].shape == (n_synthetic_samples, n_features[1])
3737

38+
def test_independent_1d_feature_generation():
39+
n_samples = 200
40+
input_data = np.random.normal(size=n_samples)
41+
42+
generator = syn.IndependentDataGenerator()
43+
44+
generator.fit(input_data)
45+
46+
n_synthetic_samples = 50
47+
synthetic_data = generator.generate(n_samples=n_synthetic_samples)
48+
49+
assert synthetic_data.shape == (n_synthetic_samples,)
50+
3851
def test_independent_feature_generation_with_distribution():
3952
n_samples = 20
4053
n_features = 2

0 commit comments

Comments
 (0)