|
| 1 | +package badger |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "time" |
| 6 | + |
| 7 | + "github.com/fuseml/fuseml-core/pkg/domain" |
| 8 | + "github.com/timshannon/badgerhold/v3" |
| 9 | +) |
| 10 | + |
| 11 | +// ExtensionStore is a wrapper around a badgerhold.Store that implements the domain.ExtensionStore interface. |
| 12 | +type ExtensionStore struct { |
| 13 | + store *badgerhold.Store |
| 14 | +} |
| 15 | + |
| 16 | +// NewExtensionStore creates a new ExtensionStore. |
| 17 | +func NewExtensionStore(store *badgerhold.Store) *ExtensionStore { |
| 18 | + return &ExtensionStore{store: store} |
| 19 | +} |
| 20 | + |
| 21 | +// AddExtension adds a new extension to the store. |
| 22 | +func (es *ExtensionStore) AddExtension(ctx context.Context, extension *domain.Extension) (*domain.Extension, error) { |
| 23 | + extension.EnsureID(ctx, es) |
| 24 | + extension.SetCreated(ctx) |
| 25 | + |
| 26 | + err := es.store.Insert(extension.ID, extension) |
| 27 | + if err != nil { |
| 28 | + return nil, domain.NewErrExtensionExists(extension.ID) |
| 29 | + } |
| 30 | + return extension, nil |
| 31 | +} |
| 32 | + |
| 33 | +// GetExtension retrieves an extension by its ID. |
| 34 | +func (es *ExtensionStore) GetExtension(ctx context.Context, extensionID string) (*domain.Extension, error) { |
| 35 | + extension := &domain.Extension{} |
| 36 | + err := es.store.Get(extensionID, extension) |
| 37 | + if err != nil { |
| 38 | + return nil, domain.NewErrExtensionNotFound(extensionID) |
| 39 | + } |
| 40 | + return extension, nil |
| 41 | +} |
| 42 | + |
| 43 | +// ListExtensions retrieves all stored extensions. |
| 44 | +func (es *ExtensionStore) ListExtensions(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.Extension) { |
| 45 | + result = []*domain.Extension{} |
| 46 | + |
| 47 | + // TODO: Replace with a badgerhold query. |
| 48 | + if query != nil { |
| 49 | + if query.ExtensionID != "" { |
| 50 | + fullExtension, err := es.GetExtension(ctx, query.ExtensionID) |
| 51 | + if err == nil { |
| 52 | + matchingExtension := fullExtension.GetExtensionIfMatch(query) |
| 53 | + if matchingExtension != nil { |
| 54 | + result = append(result, matchingExtension) |
| 55 | + } |
| 56 | + } |
| 57 | + return |
| 58 | + } |
| 59 | + |
| 60 | + allExtensions := []*domain.Extension{} |
| 61 | + es.store.Find(&allExtensions, nil) |
| 62 | + |
| 63 | + for _, extension := range allExtensions { |
| 64 | + matchingExtension := extension.GetExtensionIfMatch(query) |
| 65 | + if matchingExtension != nil { |
| 66 | + result = append(result, matchingExtension) |
| 67 | + } |
| 68 | + } |
| 69 | + return |
| 70 | + } |
| 71 | + |
| 72 | + es.store.Find(&result, nil) |
| 73 | + return |
| 74 | +} |
| 75 | + |
| 76 | +// UpdateExtension updates an existing extension. |
| 77 | +func (es *ExtensionStore) UpdateExtension(ctx context.Context, newExtension *domain.Extension) error { |
| 78 | + extension, err := es.GetExtension(ctx, newExtension.ID) |
| 79 | + if err != nil { |
| 80 | + return err |
| 81 | + } |
| 82 | + newExtension.Created = extension.Created |
| 83 | + newExtension.Updated = time.Now() |
| 84 | + |
| 85 | + for _, newExtService := range newExtension.ListServices() { |
| 86 | + _, err := extension.GetService(newExtService.ID) |
| 87 | + if err != nil { |
| 88 | + // If the service is new, set the creation time |
| 89 | + newExtService.SetCreated(newExtension.Updated) |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + err = es.store.Update(newExtension.ID, newExtension) |
| 94 | + if err != nil { |
| 95 | + return domain.NewErrExtensionNotFound(newExtension.ID) |
| 96 | + } |
| 97 | + return nil |
| 98 | +} |
| 99 | + |
| 100 | +// DeleteExtension deletes an extension from the store. |
| 101 | +func (es *ExtensionStore) DeleteExtension(ctx context.Context, extensionID string) error { |
| 102 | + extension, err := es.GetExtension(ctx, extensionID) |
| 103 | + if err != nil { |
| 104 | + return err |
| 105 | + } |
| 106 | + return es.store.Delete(extension.ID, extension) |
| 107 | +} |
| 108 | + |
| 109 | +// AddExtensionService adds a new extension service to an extension. |
| 110 | +func (es *ExtensionStore) AddExtensionService(ctx context.Context, extensionID string, service *domain.ExtensionService) (*domain.ExtensionService, error) { |
| 111 | + extension, err := es.GetExtension(ctx, extensionID) |
| 112 | + if err != nil { |
| 113 | + return nil, err |
| 114 | + } |
| 115 | + svc, err := extension.AddService(service) |
| 116 | + if err != nil { |
| 117 | + return nil, err |
| 118 | + } |
| 119 | + err = es.UpdateExtension(ctx, extension) |
| 120 | + if err != nil { |
| 121 | + return nil, err |
| 122 | + } |
| 123 | + return svc, nil |
| 124 | +} |
| 125 | + |
| 126 | +// GetExtensionService retrieves an extension service by its ID. |
| 127 | +func (es *ExtensionStore) GetExtensionService(ctx context.Context, extensionID string, serviceID string) (*domain.ExtensionService, error) { |
| 128 | + extension, err := es.GetExtension(ctx, extensionID) |
| 129 | + if err != nil { |
| 130 | + return nil, err |
| 131 | + } |
| 132 | + return extension.GetService(serviceID) |
| 133 | +} |
| 134 | + |
| 135 | +// ListExtensionServices retrieves all services belonging to an extension. |
| 136 | +func (es *ExtensionStore) ListExtensionServices(ctx context.Context, extensionID string) ([]*domain.ExtensionService, error) { |
| 137 | + extension, err := es.GetExtension(ctx, extensionID) |
| 138 | + if err != nil { |
| 139 | + return nil, err |
| 140 | + } |
| 141 | + return extension.ListServices(), nil |
| 142 | +} |
| 143 | + |
| 144 | +// UpdateExtensionService updates a service belonging to an extension. |
| 145 | +func (es *ExtensionStore) UpdateExtensionService(ctx context.Context, extensionID string, newService *domain.ExtensionService) error { |
| 146 | + extension, err := es.GetExtension(ctx, extensionID) |
| 147 | + if err != nil { |
| 148 | + return err |
| 149 | + } |
| 150 | + err = extension.UpdateService(newService) |
| 151 | + if err != nil { |
| 152 | + return err |
| 153 | + } |
| 154 | + return es.UpdateExtension(ctx, extension) |
| 155 | +} |
| 156 | + |
| 157 | +// DeleteExtensionService deletes an extension service from an extension. |
| 158 | +func (es *ExtensionStore) DeleteExtensionService(ctx context.Context, extensionID string, serviceID string) error { |
| 159 | + extension, err := es.GetExtension(ctx, extensionID) |
| 160 | + if err != nil { |
| 161 | + return err |
| 162 | + } |
| 163 | + err = extension.DeleteService(serviceID) |
| 164 | + if err != nil { |
| 165 | + return err |
| 166 | + } |
| 167 | + return es.UpdateExtension(ctx, extension) |
| 168 | +} |
| 169 | + |
| 170 | +// AddExtensionServiceEndpoint adds a new endpoint to an extension service. |
| 171 | +func (es *ExtensionStore) AddExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpoint *domain.ExtensionServiceEndpoint) (*domain.ExtensionServiceEndpoint, error) { |
| 172 | + extension, err := es.GetExtension(ctx, extensionID) |
| 173 | + if err != nil { |
| 174 | + return nil, err |
| 175 | + } |
| 176 | + svc, err := extension.GetService(serviceID) |
| 177 | + if err != nil { |
| 178 | + return nil, err |
| 179 | + } |
| 180 | + endpoint, err = svc.AddEndpoint(endpoint) |
| 181 | + if err != nil { |
| 182 | + return nil, err |
| 183 | + } |
| 184 | + err = es.UpdateExtension(ctx, extension) |
| 185 | + if err != nil { |
| 186 | + return nil, err |
| 187 | + } |
| 188 | + return endpoint, nil |
| 189 | +} |
| 190 | + |
| 191 | +// GetExtensionServiceEndpoint retrieves an extension endpoint by its ID. |
| 192 | +func (es *ExtensionStore) GetExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) (*domain.ExtensionServiceEndpoint, error) { |
| 193 | + extension, err := es.GetExtension(ctx, extensionID) |
| 194 | + if err != nil { |
| 195 | + return nil, err |
| 196 | + } |
| 197 | + svc, err := extension.GetService(serviceID) |
| 198 | + if err != nil { |
| 199 | + return nil, err |
| 200 | + } |
| 201 | + return svc.GetEndpoint(endpointID) |
| 202 | +} |
| 203 | + |
| 204 | +// ListExtensionServiceEndpoints retrieves all endpoints belonging to an extension service. |
| 205 | +func (es *ExtensionStore) ListExtensionServiceEndpoints(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceEndpoint, error) { |
| 206 | + extension, err := es.GetExtension(ctx, extensionID) |
| 207 | + if err != nil { |
| 208 | + return nil, err |
| 209 | + } |
| 210 | + svc, err := extension.GetService(serviceID) |
| 211 | + if err != nil { |
| 212 | + return nil, err |
| 213 | + } |
| 214 | + return svc.ListEndpoints(), nil |
| 215 | +} |
| 216 | + |
| 217 | +// UpdateExtensionServiceEndpoint updates an endpoint belonging to an extension service. |
| 218 | +func (es *ExtensionStore) UpdateExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, newEndpoint *domain.ExtensionServiceEndpoint) error { |
| 219 | + extension, err := es.GetExtension(ctx, extensionID) |
| 220 | + if err != nil { |
| 221 | + return err |
| 222 | + } |
| 223 | + svc, err := extension.GetService(serviceID) |
| 224 | + if err != nil { |
| 225 | + return err |
| 226 | + } |
| 227 | + err = svc.UpdateEndpoint(newEndpoint) |
| 228 | + if err != nil { |
| 229 | + return err |
| 230 | + } |
| 231 | + return es.UpdateExtension(ctx, extension) |
| 232 | +} |
| 233 | + |
| 234 | +// DeleteExtensionServiceEndpoint deletes an extension endpoint from an extension service. |
| 235 | +func (es *ExtensionStore) DeleteExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) error { |
| 236 | + extension, err := es.GetExtension(ctx, extensionID) |
| 237 | + if err != nil { |
| 238 | + return err |
| 239 | + } |
| 240 | + svc, err := extension.GetService(serviceID) |
| 241 | + if err != nil { |
| 242 | + return err |
| 243 | + } |
| 244 | + err = svc.DeleteEndpoint(endpointID) |
| 245 | + if err != nil { |
| 246 | + return err |
| 247 | + } |
| 248 | + return es.UpdateExtension(ctx, extension) |
| 249 | +} |
| 250 | + |
| 251 | +// AddExtensionServiceCredentials adds a new credential to an extension service. |
| 252 | +func (es *ExtensionStore) AddExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentials *domain.ExtensionServiceCredentials) (*domain.ExtensionServiceCredentials, error) { |
| 253 | + extension, err := es.GetExtension(ctx, extensionID) |
| 254 | + if err != nil { |
| 255 | + return nil, err |
| 256 | + } |
| 257 | + svc, err := extension.GetService(serviceID) |
| 258 | + if err != nil { |
| 259 | + return nil, err |
| 260 | + } |
| 261 | + credentials, err = svc.AddCredentials(credentials) |
| 262 | + if err != nil { |
| 263 | + return nil, err |
| 264 | + } |
| 265 | + err = es.UpdateExtension(ctx, extension) |
| 266 | + if err != nil { |
| 267 | + return nil, err |
| 268 | + } |
| 269 | + return credentials, nil |
| 270 | +} |
| 271 | + |
| 272 | +// GetExtensionServiceCredentials retrieves an extension credential by its ID. |
| 273 | +func (es *ExtensionStore) GetExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) (*domain.ExtensionServiceCredentials, error) { |
| 274 | + extension, err := es.GetExtension(ctx, extensionID) |
| 275 | + if err != nil { |
| 276 | + return nil, err |
| 277 | + } |
| 278 | + svc, err := extension.GetService(serviceID) |
| 279 | + if err != nil { |
| 280 | + return nil, err |
| 281 | + } |
| 282 | + return svc.GetCredentials(credentialsID) |
| 283 | +} |
| 284 | + |
| 285 | +// ListExtensionServiceCredentials retrieves all credentials belonging to an extension service. |
| 286 | +func (es *ExtensionStore) ListExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceCredentials, error) { |
| 287 | + extension, err := es.GetExtension(ctx, extensionID) |
| 288 | + if err != nil { |
| 289 | + return nil, err |
| 290 | + } |
| 291 | + svc, err := extension.GetService(serviceID) |
| 292 | + if err != nil { |
| 293 | + return nil, err |
| 294 | + } |
| 295 | + return svc.ListCredentials(), nil |
| 296 | +} |
| 297 | + |
| 298 | +// UpdateExtensionServiceCredentials updates an extension credential. |
| 299 | +func (es *ExtensionStore) UpdateExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, newCredentials *domain.ExtensionServiceCredentials) error { |
| 300 | + extension, err := es.GetExtension(ctx, extensionID) |
| 301 | + if err != nil { |
| 302 | + return err |
| 303 | + } |
| 304 | + svc, err := extension.GetService(serviceID) |
| 305 | + if err != nil { |
| 306 | + return err |
| 307 | + } |
| 308 | + err = svc.UpdateCredentials(newCredentials) |
| 309 | + if err != nil { |
| 310 | + return err |
| 311 | + } |
| 312 | + return es.UpdateExtension(ctx, extension) |
| 313 | +} |
| 314 | + |
| 315 | +// DeleteExtensionServiceCredentials deletes an extension credential from an extension service. |
| 316 | +func (es *ExtensionStore) DeleteExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) error { |
| 317 | + extension, err := es.GetExtension(ctx, extensionID) |
| 318 | + if err != nil { |
| 319 | + return err |
| 320 | + } |
| 321 | + svc, err := extension.GetService(serviceID) |
| 322 | + if err != nil { |
| 323 | + return err |
| 324 | + } |
| 325 | + err = svc.DeleteCredentials(credentialsID) |
| 326 | + if err != nil { |
| 327 | + return err |
| 328 | + } |
| 329 | + return es.UpdateExtension(ctx, extension) |
| 330 | +} |
| 331 | + |
| 332 | +// GetExtensionAccessDescriptors retrieves access descriptors belonging to an extension that matches the query. |
| 333 | +func (es *ExtensionStore) GetExtensionAccessDescriptors(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.ExtensionAccessDescriptor, err error) { |
| 334 | + result = make([]*domain.ExtensionAccessDescriptor, 0) |
| 335 | + |
| 336 | + for _, extension := range es.ListExtensions(ctx, query) { |
| 337 | + result = append(result, extension.GetAccessDescriptors()...) |
| 338 | + } |
| 339 | + return result, nil |
| 340 | +} |
0 commit comments