|
16 | 16 | limitations under the License. |
17 | 17 | """ |
18 | 18 |
|
| 19 | +import mock |
19 | 20 | import pytest |
20 | 21 |
|
21 | | -from toolium.utils.ai_utils.accuracy import get_accuracy_and_retries_from_tags |
| 22 | +from toolium.utils.ai_utils.accuracy import (get_accuracy_and_retries_from_tags, get_accuracy_data_suffix_from_tags, |
| 23 | + get_accuracy_data, store_retry_data) |
22 | 24 |
|
23 | 25 |
|
24 | 26 | accuracy_tags_examples = ( |
|
31 | 33 | (['no_accuracy_tag'], None), |
32 | 34 | (['accuracy_85', 'accuracy_95_15'], {'accuracy': 0.85, 'retries': 10}), |
33 | 35 | ([], None), |
| 36 | + (['accuracy_data', 'accuracy_data_50'], None), |
| 37 | + (['accuracy_75_5', 'accuracy_data'], {'accuracy': 0.75, 'retries': 5}) |
34 | 38 | ) |
35 | 39 |
|
36 | 40 |
|
37 | 41 | @pytest.mark.parametrize('tags, expected_accuracy_data', accuracy_tags_examples) |
38 | 42 | def test_get_accuracy_and_retries_from_tags(tags, expected_accuracy_data): |
39 | 43 | accuracy_data = get_accuracy_and_retries_from_tags(tags) |
40 | 44 | assert accuracy_data == expected_accuracy_data |
| 45 | + |
| 46 | + |
| 47 | +accuracy_tags_examples = ( |
| 48 | + (['accuracy'], 8, {'accuracy': 0.9, 'retries': 8}), |
| 49 | + (['accuracy_85'], 8, {'accuracy': 0.85, 'retries': 8}), |
| 50 | + (['accuracy_percent_80'], 8, {'accuracy': 0.8, 'retries': 8}), |
| 51 | + (['accuracy_75_5'], 8, {'accuracy': 0.75, 'retries': 5}), |
| 52 | + (['accuracy_percent_70_retries_3'], 8, {'accuracy': 0.7, 'retries': 3}), |
| 53 | + (['other_tag', 'accuracy_95_15'], 8, {'accuracy': 0.95, 'retries': 15}), |
| 54 | + (['no_accuracy_tag'], 8, None), |
| 55 | + (['accuracy_85', 'accuracy_95_15'], 8, {'accuracy': 0.85, 'retries': 8}), |
| 56 | + ([], 8, None), |
| 57 | + (['accuracy_data', 'accuracy_data_50'], 8, None), |
| 58 | + (['accuracy_75_5', 'accuracy_data'], 8, {'accuracy': 0.75, 'retries': 5}) |
| 59 | +) |
| 60 | + |
| 61 | + |
| 62 | +@pytest.mark.parametrize('tags, data_length, expected_accuracy_data', accuracy_tags_examples) |
| 63 | +def test_get_accuracy_and_retries_from_tags_with_data_length(tags, data_length, expected_accuracy_data): |
| 64 | + accuracy_data = get_accuracy_and_retries_from_tags(tags, accuracy_data_len=data_length) |
| 65 | + assert accuracy_data == expected_accuracy_data |
| 66 | + |
| 67 | + |
| 68 | +accuracy_data_suffix_examples = ( |
| 69 | + (['accuracy_data'], ''), |
| 70 | + (['accuracy_data_balance'], '_balance'), |
| 71 | + (['accuracy_data_balance_50'], '_balance_50'), |
| 72 | + (['other_tag', 'accuracy_data_transactions'], '_transactions'), |
| 73 | + (['no_accuracy_data_tag'], ''), |
| 74 | + (['accuracy', 'accuracy_85', 'accuracy_percent_70_retries_3'], ''), |
| 75 | + ([], '') |
| 76 | +) |
| 77 | + |
| 78 | + |
| 79 | +@pytest.mark.parametrize('tags, expected_data_suffix', accuracy_data_suffix_examples) |
| 80 | +def test_get_accuracy_data_suffix_from_tags(tags, expected_data_suffix): |
| 81 | + data_suffix = get_accuracy_data_suffix_from_tags(tags) |
| 82 | + assert data_suffix == expected_data_suffix |
| 83 | + |
| 84 | + |
| 85 | +@pytest.fixture |
| 86 | +def context(): |
| 87 | + context = mock.MagicMock() |
| 88 | + context.storage = { |
| 89 | + 'accuracy_data': [{'question': 'Q1', 'answer': 'A1'}, |
| 90 | + {'question': 'Q2', 'answer': 'A2'}], |
| 91 | + 'accuracy_data_balance': [{'question': 'Q1 balance', 'answer': 'A1'}, |
| 92 | + {'question': 'Q2 balance', 'answer': 'A2'}], |
| 93 | + 'accuracy_data_wrong': "This is not a list" |
| 94 | + } |
| 95 | + return context |
| 96 | + |
| 97 | + |
| 98 | +def test_get_accuracy_data_default_suffix(context): |
| 99 | + data = get_accuracy_data(context, data_key_suffix='') |
| 100 | + assert data == [{'question': 'Q1', 'answer': 'A1'}, |
| 101 | + {'question': 'Q2', 'answer': 'A2'}] |
| 102 | + |
| 103 | + |
| 104 | +def test_get_accuracy_data_with_suffix(context): |
| 105 | + data = get_accuracy_data(context, data_key_suffix='_balance') |
| 106 | + assert data == [{'question': 'Q1 balance', 'answer': 'A1'}, |
| 107 | + {'question': 'Q2 balance', 'answer': 'A2'}] |
| 108 | + |
| 109 | + |
| 110 | +def test_get_accuracy_data_with_nonexistent_suffix(context): |
| 111 | + data = get_accuracy_data(context, data_key_suffix='_nonexistent') |
| 112 | + assert data == [] |
| 113 | + |
| 114 | + |
| 115 | +def test_get_accuracy_data_with_wrong_type(context): |
| 116 | + with pytest.raises(AssertionError) as exc: |
| 117 | + get_accuracy_data(context, data_key_suffix='_wrong') |
| 118 | + assert str(exc.value) == 'Expected accuracy_data_wrong must be a list: This is not a list' |
| 119 | + |
| 120 | + |
| 121 | +def test_store_retry_data_default_suffix(context): |
| 122 | + store_retry_data(context, retry=1, data_key_suffix='') |
| 123 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q1', 'answer': 'A1'} |
| 124 | + assert context.storage['accuracy_retry_index'] == 1 |
| 125 | + store_retry_data(context, retry=2, data_key_suffix='') |
| 126 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q2', 'answer': 'A2'} |
| 127 | + assert context.storage['accuracy_retry_index'] == 2 |
| 128 | + store_retry_data(context, retry=3, data_key_suffix='') |
| 129 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q1', 'answer': 'A1'} |
| 130 | + assert context.storage['accuracy_retry_index'] == 3 |
| 131 | + |
| 132 | + |
| 133 | +def test_store_retry_data_with_suffix(context): |
| 134 | + store_retry_data(context, retry=1, data_key_suffix='_balance') |
| 135 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q1 balance', 'answer': 'A1'} |
| 136 | + assert context.storage['accuracy_retry_index'] == 1 |
| 137 | + store_retry_data(context, retry=2, data_key_suffix='_balance') |
| 138 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q2 balance', 'answer': 'A2'} |
| 139 | + assert context.storage['accuracy_retry_index'] == 2 |
| 140 | + store_retry_data(context, retry=3, data_key_suffix='_balance') |
| 141 | + assert context.storage['accuracy_retry_data'] == {'question': 'Q1 balance', 'answer': 'A1'} |
| 142 | + assert context.storage['accuracy_retry_index'] == 3 |
| 143 | + |
| 144 | + |
| 145 | +def test_store_retry_data_with_nonexistent_suffix(context): |
| 146 | + store_retry_data(context, retry=1, data_key_suffix='_nonexistent') |
| 147 | + assert context.storage['accuracy_retry_data'] is None |
| 148 | + assert context.storage['accuracy_retry_index'] == 1 |
0 commit comments