Skip to content

Commit 0f24e63

Browse files
authored
Fixes swapping of data and feature dimension to work in the general case. (#214)
Previous implementation was broken as using transpose assumes that `data_list` is a 2D array. However, in certain cases (when all the feature values array lengths are the same) the `data_list` can be a 3D array as the call to `data_list = np.array(list(dataset.as_numpy_iterator()), dtype=object)` merges inner np arrays and converts `data_list` into one big 3D array.
1 parent ebbaa2f commit 0f24e63

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

compiler_opt/tools/sparse_bucket_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def main(_) -> None:
170170
parser_fn = create_tfrecord_parser_fn(sequence_features)
171171
dataset = dataset.map(parser_fn, num_parallel_calls=tf.data.AUTOTUNE)
172172
data_list = np.array(list(dataset.as_numpy_iterator()), dtype=object)
173-
data_list = np.transpose(data_list, [1, 0])
173+
data_list = data_list.swapaxes(0, 1)
174174

175175
with mp.Pool(FLAGS.parallelism) as pool:
176176
feature_names = list(sorted(sequence_features))

0 commit comments

Comments
 (0)