|
5 | 5 | import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordKeyAttribute; |
6 | 6 | import com.microsoft.semantickernel.data.recordattributes.VectorStoreRecordVectorAttribute; |
7 | 7 |
|
8 | | -import javax.annotation.Nonnull; |
9 | | -import javax.annotation.Nullable; |
10 | 8 | import java.lang.reflect.Field; |
11 | 9 | import java.util.ArrayList; |
12 | 10 | import java.util.Collections; |
13 | 11 | import java.util.HashSet; |
14 | 12 | import java.util.List; |
| 13 | +import java.util.Set; |
15 | 14 | import java.util.stream.Collectors; |
16 | 15 |
|
17 | 16 | /** |
@@ -50,6 +49,50 @@ public List<VectorStoreRecordField> getAllFields() { |
50 | 49 | return fields; |
51 | 50 | } |
52 | 51 |
|
| 52 | + public List<VectorStoreRecordField> getNonVectorFields() { |
| 53 | + List<VectorStoreRecordField> fields = new ArrayList<>(); |
| 54 | + fields.add(keyField); |
| 55 | + fields.addAll(dataFields); |
| 56 | + return fields; |
| 57 | + } |
| 58 | + |
| 59 | + private List<Field> getDeclaredFields(Class<?> recordClass, List<VectorStoreRecordField> fields, String fieldType) { |
| 60 | + List<Field> declaredFields = new ArrayList<>(); |
| 61 | + for (VectorStoreRecordField field : fields) { |
| 62 | + try { |
| 63 | + Field declaredField = recordClass.getDeclaredField(field.getName()); |
| 64 | + declaredFields.add(declaredField); |
| 65 | + } catch (NoSuchFieldException e) { |
| 66 | + throw new IllegalArgumentException( |
| 67 | + String.format("%s field not found in record class: %s", fieldType, field.getName())); |
| 68 | + } |
| 69 | + } |
| 70 | + return declaredFields; |
| 71 | + } |
| 72 | + |
| 73 | + public Field getKeyDeclaredField(Class<?> recordClass) { |
| 74 | + try { |
| 75 | + return recordClass.getDeclaredField(keyField.getName()); |
| 76 | + } catch (NoSuchFieldException e) { |
| 77 | + throw new IllegalArgumentException( |
| 78 | + "Key field not found in record class: " + keyField.getName()); |
| 79 | + } |
| 80 | + } |
| 81 | + |
| 82 | + public List<Field> getDataDeclaredFields(Class<?> recordClass) { |
| 83 | + return getDeclaredFields( |
| 84 | + recordClass, |
| 85 | + dataFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), |
| 86 | + "Data"); |
| 87 | + } |
| 88 | + |
| 89 | + public List<Field> getVectorDeclaredFields(Class<?> recordClass) { |
| 90 | + return getDeclaredFields( |
| 91 | + recordClass, |
| 92 | + vectorFields.stream().map(f -> (VectorStoreRecordField) f).collect(Collectors.toList()), |
| 93 | + "Vector"); |
| 94 | + } |
| 95 | + |
53 | 96 | private VectorStoreRecordDefinition( |
54 | 97 | VectorStoreRecordKeyField keyField, |
55 | 98 | List<VectorStoreRecordDataField> dataFields, |
@@ -148,71 +191,19 @@ public static VectorStoreRecordDefinition fromRecordClass(Class<?> recordClass) |
148 | 191 | return checkFields(keyFields, dataFields, vectorFields); |
149 | 192 | } |
150 | 193 |
|
151 | | - private static String getSupportedTypesString(@Nullable HashSet<Class<?>> types) { |
152 | | - if (types == null || types.isEmpty()) { |
153 | | - return ""; |
154 | | - } |
155 | | - return types.stream().map(Class::getName).collect(Collectors.joining(", ")); |
156 | | - } |
157 | | - |
158 | | - public static void validateSupportedKeyTypes(@Nonnull Class<?> recordClass, |
159 | | - @Nonnull VectorStoreRecordDefinition recordDefinition, |
160 | | - @Nonnull HashSet<Class<?>> supportedTypes) { |
161 | | - String supportedTypesString = getSupportedTypesString(supportedTypes); |
162 | | - |
163 | | - try { |
164 | | - Field declaredField = recordClass.getDeclaredField(recordDefinition.keyField.getName()); |
165 | 194 |
|
| 195 | + public static void validateSupportedTypes(List<Field> declaredFields, Set<Class<?>> supportedTypes) { |
| 196 | + Set<Class<?>> unsupportedTypes = new HashSet<>(); |
| 197 | + for (Field declaredField : declaredFields) { |
166 | 198 | if (!supportedTypes.contains(declaredField.getType())) { |
167 | | - throw new IllegalArgumentException( |
168 | | - "Unsupported key field type: " + declaredField.getType().getName() |
169 | | - + ". Supported types are: " + supportedTypesString); |
170 | | - } |
171 | | - } catch (NoSuchFieldException e) { |
172 | | - throw new IllegalArgumentException( |
173 | | - "Key field not found in record class: " + recordDefinition.keyField.getName()); |
174 | | - } |
175 | | - } |
176 | | - |
177 | | - public static void validateSupportedDataTypes(@Nonnull Class<?> recordClass, |
178 | | - @Nonnull VectorStoreRecordDefinition recordDefinition, |
179 | | - @Nonnull HashSet<Class<?>> supportedTypes) { |
180 | | - String supportedTypesString = getSupportedTypesString(supportedTypes); |
181 | | - |
182 | | - for (VectorStoreRecordDataField field : recordDefinition.dataFields) { |
183 | | - try { |
184 | | - Field declaredField = recordClass.getDeclaredField(field.getName()); |
185 | | - |
186 | | - if (!supportedTypes.contains(declaredField.getType())) { |
187 | | - throw new IllegalArgumentException( |
188 | | - "Unsupported data field type: " + declaredField.getType().getName() |
189 | | - + ". Supported types are: " + supportedTypesString); |
190 | | - } |
191 | | - } catch (NoSuchFieldException e) { |
192 | | - throw new IllegalArgumentException( |
193 | | - "Data field not found in record class: " + field.getName()); |
| 199 | + unsupportedTypes.add(declaredField.getType()); |
194 | 200 | } |
195 | 201 | } |
196 | | - } |
197 | | - |
198 | | - public static void validateSupportedVectorTypes(@Nonnull Class<?> recordClass, |
199 | | - @Nonnull VectorStoreRecordDefinition recordDefinition, |
200 | | - @Nonnull HashSet<Class<?>> supportedTypes) { |
201 | | - String supportedTypesString = getSupportedTypesString(supportedTypes); |
202 | | - |
203 | | - for (VectorStoreRecordVectorField field : recordDefinition.vectorFields) { |
204 | | - try { |
205 | | - Field declaredField = recordClass.getDeclaredField(field.getName()); |
206 | | - |
207 | | - if (!supportedTypes.contains(declaredField.getType())) { |
208 | | - throw new IllegalArgumentException( |
209 | | - "Unsupported vector field type: " + declaredField.getType().getName() |
210 | | - + ". Supported types are: " + supportedTypesString); |
211 | | - } |
212 | | - } catch (NoSuchFieldException e) { |
213 | | - throw new IllegalArgumentException( |
214 | | - "Vector field not found in record class: " + field.getName()); |
215 | | - } |
| 202 | + if (!unsupportedTypes.isEmpty()) { |
| 203 | + throw new IllegalArgumentException( |
| 204 | + String.format("Unsupported field types found in record class: %s. Supported types: %s", |
| 205 | + unsupportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")), |
| 206 | + supportedTypes.stream().map(Class::getName).collect(Collectors.joining(", ")))); |
216 | 207 | } |
217 | 208 | } |
218 | 209 | } |
0 commit comments