|
6 | 6 | from azure.core.credentials import AzureKeyCredential
|
7 | 7 | from azure.search.documents.aio import SearchClient
|
8 | 8 | from azure.search.documents.indexes.aio import SearchIndexClient
|
| 9 | +from azure.search.documents.indexes.models import ( |
| 10 | + SearchFieldDataType, |
| 11 | + SearchIndex, |
| 12 | + SimpleField, |
| 13 | +) |
9 | 14 | from openai.types.create_embedding_response import Usage
|
10 | 15 |
|
11 | 16 | from prepdocslib.embeddings import AzureOpenAIEmbeddingService
|
@@ -75,20 +80,72 @@ async def mock_list_index_names(self):
|
75 | 80 |
|
76 | 81 | @pytest.mark.asyncio
|
77 | 82 | async def test_create_index_does_exist(monkeypatch, search_info):
|
78 |
| - indexes = [] |
| 83 | + created_indexes = [] |
| 84 | + updated_indexes = [] |
79 | 85 |
|
80 | 86 | async def mock_create_index(self, index):
|
81 |
| - indexes.append(index) |
| 87 | + created_indexes.append(index) |
| 88 | + |
| 89 | + async def mock_list_index_names(self): |
| 90 | + yield "test" |
| 91 | + |
| 92 | + async def mock_get_index(self, *args, **kwargs): |
| 93 | + return SearchIndex( |
| 94 | + name="test", |
| 95 | + fields=[ |
| 96 | + SimpleField( |
| 97 | + name="storageUrl", |
| 98 | + type=SearchFieldDataType.String, |
| 99 | + filterable=True, |
| 100 | + ) |
| 101 | + ], |
| 102 | + ) |
| 103 | + |
| 104 | + async def mock_create_or_update_index(self, index, *args, **kwargs): |
| 105 | + updated_indexes.append(index) |
| 106 | + |
| 107 | + monkeypatch.setattr(SearchIndexClient, "create_index", mock_create_index) |
| 108 | + monkeypatch.setattr(SearchIndexClient, "list_index_names", mock_list_index_names) |
| 109 | + monkeypatch.setattr(SearchIndexClient, "get_index", mock_get_index) |
| 110 | + monkeypatch.setattr(SearchIndexClient, "create_or_update_index", mock_create_or_update_index) |
| 111 | + |
| 112 | + manager = SearchManager(search_info) |
| 113 | + await manager.create_index() |
| 114 | + assert len(created_indexes) == 0, "It should not have created a new index" |
| 115 | + assert len(updated_indexes) == 0, "It should not have updated the existing index" |
| 116 | + |
| 117 | + |
| 118 | +@pytest.mark.asyncio |
| 119 | +async def test_create_index_add_field(monkeypatch, search_info): |
| 120 | + created_indexes = [] |
| 121 | + updated_indexes = [] |
| 122 | + |
| 123 | + async def mock_create_index(self, index): |
| 124 | + created_indexes.append(index) |
82 | 125 |
|
83 | 126 | async def mock_list_index_names(self):
|
84 | 127 | yield "test"
|
85 | 128 |
|
| 129 | + async def mock_get_index(self, *args, **kwargs): |
| 130 | + return SearchIndex( |
| 131 | + name="test", |
| 132 | + fields=[], |
| 133 | + ) |
| 134 | + |
| 135 | + async def mock_create_or_update_index(self, index, *args, **kwargs): |
| 136 | + updated_indexes.append(index) |
| 137 | + |
86 | 138 | monkeypatch.setattr(SearchIndexClient, "create_index", mock_create_index)
|
87 | 139 | monkeypatch.setattr(SearchIndexClient, "list_index_names", mock_list_index_names)
|
| 140 | + monkeypatch.setattr(SearchIndexClient, "get_index", mock_get_index) |
| 141 | + monkeypatch.setattr(SearchIndexClient, "create_or_update_index", mock_create_or_update_index) |
88 | 142 |
|
89 | 143 | manager = SearchManager(search_info)
|
90 | 144 | await manager.create_index()
|
91 |
| - assert len(indexes) == 0, "It should not have created a new index" |
| 145 | + assert len(created_indexes) == 0, "It should not have created a new index" |
| 146 | + assert len(updated_indexes) == 1, "It should have updated the existing index" |
| 147 | + assert len(updated_indexes[0].fields) == 1 |
| 148 | + assert updated_indexes[0].fields[0].name == "storageUrl" |
92 | 149 |
|
93 | 150 |
|
94 | 151 | @pytest.mark.asyncio
|
|
0 commit comments