|
8 | 8 | from data_algebra.cdata_impl import RecordMap |
9 | 9 | import data_algebra.yaml |
10 | 10 | from data_algebra.cdata_impl import record_map_from_simple_obj |
| 11 | +import data_algebra.util |
11 | 12 |
|
12 | 13 |
|
13 | 14 | def test_cdata_example(): |
@@ -50,8 +51,67 @@ def test_cdata_example(): |
50 | 51 | back = data_algebra.yaml.to_pipeline(yaml_obj) |
51 | 52 |
|
52 | 53 |
|
53 | | -def test_cdata_transport(): |
| 54 | +def test_keras_example(): |
54 | 55 | obj = {'blocks_out': {'record_keys': 'epoch', 'control_table_keys': 'measure', |
55 | 56 | 'control_table': {'measure': ['minus binary cross entropy', 'accuracy'], |
56 | 57 | 'training': ['loss', 'acc'], 'validation': ['val_loss', 'val_acc']}}} |
57 | | - record_map_from_simple_obj(obj) |
| 58 | + record_map = record_map_from_simple_obj(obj) |
| 59 | + data = pandas.DataFrame({ |
| 60 | + 'val_loss': [-0.377, -0.2997], |
| 61 | + 'val_acc': [0.8722, 0.8895], |
| 62 | + 'loss': [-0.5067, -0.3002], |
| 63 | + 'acc': [0.7852, 0.904], |
| 64 | + 'epoch': [1, 2], |
| 65 | + }) |
| 66 | + res = record_map.transform(data) |
| 67 | + expect = pandas.DataFrame({ |
| 68 | + 'epoch': [1, 1, 2, 2], |
| 69 | + 'measure': ['accuracy', 'minus binary cross entropy', 'accuracy', 'minus binary cross entropy'], |
| 70 | + 'training': [0.7852, -0.5067, 0.9040, -0.3002], |
| 71 | + 'validation': [0.8722, -0.3770, 0.8895, -0.2997], |
| 72 | + }) |
| 73 | + assert data_algebra.util.equivalent_frames(res, expect) |
| 74 | + |
| 75 | +def test_cdata_block(): |
| 76 | + data = pandas.DataFrame({ |
| 77 | + 'record_id': [1, 1, 1, 2, 2, 2], |
| 78 | + 'row': ['row1', 'row2', 'row3', 'row1', 'row2', 'row3'], |
| 79 | + 'col1': [1, 4, 7, 11, 14, 17], |
| 80 | + 'col2': [2, 5, 8, 12, 15, 18], |
| 81 | + 'col3': [3, 6, 9, 13, 16, 19], |
| 82 | + }) |
| 83 | + |
| 84 | + record_keys = ['record_id'] |
| 85 | + |
| 86 | + incoming_shape = pandas.DataFrame({ |
| 87 | + 'row': ['row1', 'row2', 'row3'], |
| 88 | + 'col1': ['v11', 'v21', 'v31'], |
| 89 | + 'col2': ['v12', 'v22', 'v32'], |
| 90 | + 'col3': ['v13', 'v23', 'v33'], |
| 91 | + }) |
| 92 | + |
| 93 | + outgoing_shape = pandas.DataFrame({ |
| 94 | + 'column_label': ['rec_col1', 'rec_col2', 'rec_col3'], |
| 95 | + 'c_row1': ['v11', 'v12', 'v13'], |
| 96 | + 'c_row2': ['v21', 'v22', 'v23'], |
| 97 | + 'c_row3': ['v31', 'v32', 'v33'], |
| 98 | + }) |
| 99 | + |
| 100 | + record_map = data_algebra.cdata_impl.RecordMap( |
| 101 | + blocks_in=data_algebra.cdata.RecordSpecification( |
| 102 | + control_table=incoming_shape, |
| 103 | + record_keys=record_keys |
| 104 | + ), |
| 105 | + blocks_out=data_algebra.cdata.RecordSpecification( |
| 106 | + control_table=outgoing_shape, |
| 107 | + record_keys=record_keys |
| 108 | + ), |
| 109 | + ) |
| 110 | + |
| 111 | + res = record_map.transform(data) |
| 112 | + |
| 113 | + inv = record_map.inverse() |
| 114 | + |
| 115 | + back = inv.transform(res) |
| 116 | + |
| 117 | + assert data_algebra.util.equivalent_frames(data, back) |
0 commit comments