Skip to content

Commit 3a4bd7f

Browse files
authored
Merge pull request #112 from flaviodsr/extension_badger_store_02
2 parents 73a4618 + d04c506 commit 3a4bd7f

File tree

4 files changed

+1227
-5
lines changed

4 files changed

+1227
-5
lines changed

cmd/fuseml_core/wire.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ var storeSet = wire.NewSet(
3939
wire.Bind(new(domain.RunnableStore), new(*core.RunnableStore)),
4040
badger.NewWorkflowStore,
4141
wire.Bind(new(domain.WorkflowStore), new(*badger.WorkflowStore)),
42-
core.NewExtensionStore,
43-
wire.Bind(new(domain.ExtensionStore), new(*core.ExtensionStore)),
42+
badger.NewExtensionStore,
43+
wire.Bind(new(domain.ExtensionStore), new(*badger.ExtensionStore)),
4444
)
4545

4646
var managerSet = wire.NewSet(

cmd/fuseml_core/wire_gen.go

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/core/store/badger/extension.go

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
package badger
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/fuseml/fuseml-core/pkg/domain"
8+
"github.com/timshannon/badgerhold/v3"
9+
)
10+
11+
// ExtensionStore is a wrapper around a badgerhold.Store that implements the domain.ExtensionStore interface.
12+
type ExtensionStore struct {
13+
store *badgerhold.Store
14+
}
15+
16+
// NewExtensionStore creates a new ExtensionStore.
17+
func NewExtensionStore(store *badgerhold.Store) *ExtensionStore {
18+
return &ExtensionStore{store: store}
19+
}
20+
21+
// AddExtension adds a new extension to the store.
22+
func (es *ExtensionStore) AddExtension(ctx context.Context, extension *domain.Extension) (*domain.Extension, error) {
23+
extension.EnsureID(ctx, es)
24+
extension.SetCreated(ctx)
25+
26+
err := es.store.Insert(extension.ID, extension)
27+
if err != nil {
28+
return nil, domain.NewErrExtensionExists(extension.ID)
29+
}
30+
return extension, nil
31+
}
32+
33+
// GetExtension retrieves an extension by its ID.
34+
func (es *ExtensionStore) GetExtension(ctx context.Context, extensionID string) (*domain.Extension, error) {
35+
extension := &domain.Extension{}
36+
err := es.store.Get(extensionID, extension)
37+
if err != nil {
38+
return nil, domain.NewErrExtensionNotFound(extensionID)
39+
}
40+
return extension, nil
41+
}
42+
43+
// ListExtensions retrieves all stored extensions.
44+
func (es *ExtensionStore) ListExtensions(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.Extension) {
45+
result = []*domain.Extension{}
46+
47+
// TODO: Replace with a badgerhold query.
48+
if query != nil {
49+
if query.ExtensionID != "" {
50+
fullExtension, err := es.GetExtension(ctx, query.ExtensionID)
51+
if err == nil {
52+
matchingExtension := fullExtension.GetExtensionIfMatch(query)
53+
if matchingExtension != nil {
54+
result = append(result, matchingExtension)
55+
}
56+
}
57+
return
58+
}
59+
60+
allExtensions := []*domain.Extension{}
61+
es.store.Find(&allExtensions, nil)
62+
63+
for _, extension := range allExtensions {
64+
matchingExtension := extension.GetExtensionIfMatch(query)
65+
if matchingExtension != nil {
66+
result = append(result, matchingExtension)
67+
}
68+
}
69+
return
70+
}
71+
72+
es.store.Find(&result, nil)
73+
return
74+
}
75+
76+
// UpdateExtension updates an existing extension.
77+
func (es *ExtensionStore) UpdateExtension(ctx context.Context, newExtension *domain.Extension) error {
78+
extension, err := es.GetExtension(ctx, newExtension.ID)
79+
if err != nil {
80+
return err
81+
}
82+
newExtension.Created = extension.Created
83+
newExtension.Updated = time.Now()
84+
85+
for _, newExtService := range newExtension.ListServices() {
86+
_, err := extension.GetService(newExtService.ID)
87+
if err != nil {
88+
// If the service is new, set the creation time
89+
newExtService.SetCreated(newExtension.Updated)
90+
}
91+
}
92+
93+
err = es.store.Update(newExtension.ID, newExtension)
94+
if err != nil {
95+
return domain.NewErrExtensionNotFound(newExtension.ID)
96+
}
97+
return nil
98+
}
99+
100+
// DeleteExtension deletes an extension from the store.
101+
func (es *ExtensionStore) DeleteExtension(ctx context.Context, extensionID string) error {
102+
extension, err := es.GetExtension(ctx, extensionID)
103+
if err != nil {
104+
return err
105+
}
106+
return es.store.Delete(extension.ID, extension)
107+
}
108+
109+
// AddExtensionService adds a new extension service to an extension.
110+
func (es *ExtensionStore) AddExtensionService(ctx context.Context, extensionID string, service *domain.ExtensionService) (*domain.ExtensionService, error) {
111+
extension, err := es.GetExtension(ctx, extensionID)
112+
if err != nil {
113+
return nil, err
114+
}
115+
svc, err := extension.AddService(service)
116+
if err != nil {
117+
return nil, err
118+
}
119+
err = es.UpdateExtension(ctx, extension)
120+
if err != nil {
121+
return nil, err
122+
}
123+
return svc, nil
124+
}
125+
126+
// GetExtensionService retrieves an extension service by its ID.
127+
func (es *ExtensionStore) GetExtensionService(ctx context.Context, extensionID string, serviceID string) (*domain.ExtensionService, error) {
128+
extension, err := es.GetExtension(ctx, extensionID)
129+
if err != nil {
130+
return nil, err
131+
}
132+
return extension.GetService(serviceID)
133+
}
134+
135+
// ListExtensionServices retrieves all services belonging to an extension.
136+
func (es *ExtensionStore) ListExtensionServices(ctx context.Context, extensionID string) ([]*domain.ExtensionService, error) {
137+
extension, err := es.GetExtension(ctx, extensionID)
138+
if err != nil {
139+
return nil, err
140+
}
141+
return extension.ListServices(), nil
142+
}
143+
144+
// UpdateExtensionService updates a service belonging to an extension.
145+
func (es *ExtensionStore) UpdateExtensionService(ctx context.Context, extensionID string, newService *domain.ExtensionService) error {
146+
extension, err := es.GetExtension(ctx, extensionID)
147+
if err != nil {
148+
return err
149+
}
150+
err = extension.UpdateService(newService)
151+
if err != nil {
152+
return err
153+
}
154+
return es.UpdateExtension(ctx, extension)
155+
}
156+
157+
// DeleteExtensionService deletes an extension service from an extension.
158+
func (es *ExtensionStore) DeleteExtensionService(ctx context.Context, extensionID string, serviceID string) error {
159+
extension, err := es.GetExtension(ctx, extensionID)
160+
if err != nil {
161+
return err
162+
}
163+
err = extension.DeleteService(serviceID)
164+
if err != nil {
165+
return err
166+
}
167+
return es.UpdateExtension(ctx, extension)
168+
}
169+
170+
// AddExtensionServiceEndpoint adds a new endpoint to an extension service.
171+
func (es *ExtensionStore) AddExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpoint *domain.ExtensionServiceEndpoint) (*domain.ExtensionServiceEndpoint, error) {
172+
extension, err := es.GetExtension(ctx, extensionID)
173+
if err != nil {
174+
return nil, err
175+
}
176+
svc, err := extension.GetService(serviceID)
177+
if err != nil {
178+
return nil, err
179+
}
180+
endpoint, err = svc.AddEndpoint(endpoint)
181+
if err != nil {
182+
return nil, err
183+
}
184+
err = es.UpdateExtension(ctx, extension)
185+
if err != nil {
186+
return nil, err
187+
}
188+
return endpoint, nil
189+
}
190+
191+
// GetExtensionServiceEndpoint retrieves an extension endpoint by its ID.
192+
func (es *ExtensionStore) GetExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) (*domain.ExtensionServiceEndpoint, error) {
193+
extension, err := es.GetExtension(ctx, extensionID)
194+
if err != nil {
195+
return nil, err
196+
}
197+
svc, err := extension.GetService(serviceID)
198+
if err != nil {
199+
return nil, err
200+
}
201+
return svc.GetEndpoint(endpointID)
202+
}
203+
204+
// ListExtensionServiceEndpoints retrieves all endpoints belonging to an extension service.
205+
func (es *ExtensionStore) ListExtensionServiceEndpoints(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceEndpoint, error) {
206+
extension, err := es.GetExtension(ctx, extensionID)
207+
if err != nil {
208+
return nil, err
209+
}
210+
svc, err := extension.GetService(serviceID)
211+
if err != nil {
212+
return nil, err
213+
}
214+
return svc.ListEndpoints(), nil
215+
}
216+
217+
// UpdateExtensionServiceEndpoint updates an endpoint belonging to an extension service.
218+
func (es *ExtensionStore) UpdateExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, newEndpoint *domain.ExtensionServiceEndpoint) error {
219+
extension, err := es.GetExtension(ctx, extensionID)
220+
if err != nil {
221+
return err
222+
}
223+
svc, err := extension.GetService(serviceID)
224+
if err != nil {
225+
return err
226+
}
227+
err = svc.UpdateEndpoint(newEndpoint)
228+
if err != nil {
229+
return err
230+
}
231+
return es.UpdateExtension(ctx, extension)
232+
}
233+
234+
// DeleteExtensionServiceEndpoint deletes an extension endpoint from an extension service.
235+
func (es *ExtensionStore) DeleteExtensionServiceEndpoint(ctx context.Context, extensionID string, serviceID string, endpointID string) error {
236+
extension, err := es.GetExtension(ctx, extensionID)
237+
if err != nil {
238+
return err
239+
}
240+
svc, err := extension.GetService(serviceID)
241+
if err != nil {
242+
return err
243+
}
244+
err = svc.DeleteEndpoint(endpointID)
245+
if err != nil {
246+
return err
247+
}
248+
return es.UpdateExtension(ctx, extension)
249+
}
250+
251+
// AddExtensionServiceCredentials adds a new credential to an extension service.
252+
func (es *ExtensionStore) AddExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentials *domain.ExtensionServiceCredentials) (*domain.ExtensionServiceCredentials, error) {
253+
extension, err := es.GetExtension(ctx, extensionID)
254+
if err != nil {
255+
return nil, err
256+
}
257+
svc, err := extension.GetService(serviceID)
258+
if err != nil {
259+
return nil, err
260+
}
261+
credentials, err = svc.AddCredentials(credentials)
262+
if err != nil {
263+
return nil, err
264+
}
265+
err = es.UpdateExtension(ctx, extension)
266+
if err != nil {
267+
return nil, err
268+
}
269+
return credentials, nil
270+
}
271+
272+
// GetExtensionServiceCredentials retrieves an extension credential by its ID.
273+
func (es *ExtensionStore) GetExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) (*domain.ExtensionServiceCredentials, error) {
274+
extension, err := es.GetExtension(ctx, extensionID)
275+
if err != nil {
276+
return nil, err
277+
}
278+
svc, err := extension.GetService(serviceID)
279+
if err != nil {
280+
return nil, err
281+
}
282+
return svc.GetCredentials(credentialsID)
283+
}
284+
285+
// ListExtensionServiceCredentials retrieves all credentials belonging to an extension service.
286+
func (es *ExtensionStore) ListExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string) ([]*domain.ExtensionServiceCredentials, error) {
287+
extension, err := es.GetExtension(ctx, extensionID)
288+
if err != nil {
289+
return nil, err
290+
}
291+
svc, err := extension.GetService(serviceID)
292+
if err != nil {
293+
return nil, err
294+
}
295+
return svc.ListCredentials(), nil
296+
}
297+
298+
// UpdateExtensionServiceCredentials updates an extension credential.
299+
func (es *ExtensionStore) UpdateExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, newCredentials *domain.ExtensionServiceCredentials) error {
300+
extension, err := es.GetExtension(ctx, extensionID)
301+
if err != nil {
302+
return err
303+
}
304+
svc, err := extension.GetService(serviceID)
305+
if err != nil {
306+
return err
307+
}
308+
err = svc.UpdateCredentials(newCredentials)
309+
if err != nil {
310+
return err
311+
}
312+
return es.UpdateExtension(ctx, extension)
313+
}
314+
315+
// DeleteExtensionServiceCredentials deletes an extension credential from an extension service.
316+
func (es *ExtensionStore) DeleteExtensionServiceCredentials(ctx context.Context, extensionID string, serviceID string, credentialsID string) error {
317+
extension, err := es.GetExtension(ctx, extensionID)
318+
if err != nil {
319+
return err
320+
}
321+
svc, err := extension.GetService(serviceID)
322+
if err != nil {
323+
return err
324+
}
325+
err = svc.DeleteCredentials(credentialsID)
326+
if err != nil {
327+
return err
328+
}
329+
return es.UpdateExtension(ctx, extension)
330+
}
331+
332+
// GetExtensionAccessDescriptors retrieves access descriptors belonging to an extension that matches the query.
333+
func (es *ExtensionStore) GetExtensionAccessDescriptors(ctx context.Context, query *domain.ExtensionQuery) (result []*domain.ExtensionAccessDescriptor, err error) {
334+
result = make([]*domain.ExtensionAccessDescriptor, 0)
335+
336+
for _, extension := range es.ListExtensions(ctx, query) {
337+
result = append(result, extension.GetAccessDescriptors()...)
338+
}
339+
return result, nil
340+
}

0 commit comments

Comments
 (0)