diff --git a/Makefile b/Makefile index c2e66855..c66a1a66 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test lint cover mock_wire mock_gen swag migrate_local +.PHONY: test lint cover wire mock_gen swag migrate_local test: go test ./... @@ -11,7 +11,7 @@ cover: go tool cover -html=cover.out -o cover.html open cover.html -mock_wire: +wire: @echo "Running wire for component mocks..." @go run -mod=mod github.com/google/wire/cmd/wire opencsg.com/csghub-server/component/... @if [ $$? -eq 0 ]; then \ @@ -22,6 +22,25 @@ mock_wire: else \ echo "Wire failed, skipping renaming."; \ fi + @echo "Running wire for api/router..." + @go run -mod=mod github.com/google/wire/cmd/wire gen --header_file=wire/ce_header opencsg.com/csghub-server/api/router/... + @mv api/router/wire_gen.go api/router/wire_gen_ce.go + + @if [ -f api/router/wire_ee.go ]; then \ + echo "Running wire for ee..."; \ + go run -mod=mod github.com/google/wire/cmd/wire gen -tags=ee --header_file=wire/ee_header opencsg.com/csghub-server/api/router/...; \ + mv api/router/wire_gen.go api/router/wire_gen_ee.go; \ + else \ + echo "wire_ee.go not exists, skipping ee generation..."; \ + fi + + @if [ -f api/router/wire_saas.go ]; then \ + echo "Running wire for saas..."; \ + go run -mod=mod github.com/google/wire/cmd/wire gen -tags=saas --header_file=wire/saas_header opencsg.com/csghub-server/api/router/...; \ + mv api/router/wire_gen.go api/router/wire_gen_saas.go; \ + else \ + echo "wire_saas.go not exists, skipping saas generation..."; \ + fi mock_gen: mockery diff --git a/_mocks/opencsg.com/csghub-server/builder/store/cache/mock_RedisClient.go b/_mocks/opencsg.com/csghub-server/builder/store/cache/mock_RedisClient.go new file mode 100644 index 00000000..4f503fb4 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/builder/store/cache/mock_RedisClient.go @@ -0,0 +1,629 @@ +// Code generated by mockery v2.49.1. DO NOT EDIT. + +package cache + +import ( + context "context" + + redis "github.com/redis/go-redis/v9" + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// MockRedisClient is an autogenerated mock type for the RedisClient type +type MockRedisClient struct { + mock.Mock +} + +type MockRedisClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MockRedisClient) EXPECT() *MockRedisClient_Expecter { + return &MockRedisClient_Expecter{mock: &_m.Mock} +} + +// BZPopMax provides a mock function with given fields: ctx, key +func (_m *MockRedisClient) BZPopMax(ctx context.Context, key string) (*redis.ZWithKey, error) { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for BZPopMax") + } + + var r0 *redis.ZWithKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*redis.ZWithKey, error)); ok { + return rf(ctx, key) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *redis.ZWithKey); ok { + r0 = rf(ctx, key) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*redis.ZWithKey) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRedisClient_BZPopMax_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BZPopMax' +type MockRedisClient_BZPopMax_Call struct { + *mock.Call +} + +// BZPopMax is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *MockRedisClient_Expecter) BZPopMax(ctx interface{}, key interface{}) *MockRedisClient_BZPopMax_Call { + return &MockRedisClient_BZPopMax_Call{Call: _e.mock.On("BZPopMax", ctx, key)} +} + +func (_c *MockRedisClient_BZPopMax_Call) Run(run func(ctx context.Context, key string)) *MockRedisClient_BZPopMax_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockRedisClient_BZPopMax_Call) Return(_a0 *redis.ZWithKey, _a1 error) *MockRedisClient_BZPopMax_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRedisClient_BZPopMax_Call) RunAndReturn(run func(context.Context, string) (*redis.ZWithKey, error)) *MockRedisClient_BZPopMax_Call { + _c.Call.Return(run) + return _c +} + +// Del provides a mock function with given fields: ctx, keys +func (_m *MockRedisClient) Del(ctx context.Context, keys ...string) error { + _va := make([]interface{}, len(keys)) + for _i := range keys { + _va[_i] = keys[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for Del") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, ...string) error); ok { + r0 = rf(ctx, keys...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_Del_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Del' +type MockRedisClient_Del_Call struct { + *mock.Call +} + +// Del is a helper method to define mock.On call +// - ctx context.Context +// - keys ...string +func (_e *MockRedisClient_Expecter) Del(ctx interface{}, keys ...interface{}) *MockRedisClient_Del_Call { + return &MockRedisClient_Del_Call{Call: _e.mock.On("Del", + append([]interface{}{ctx}, keys...)...)} +} + +func (_c *MockRedisClient_Del_Call) Run(run func(ctx context.Context, keys ...string)) *MockRedisClient_Del_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]string, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(string) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockRedisClient_Del_Call) Return(_a0 error) *MockRedisClient_Del_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_Del_Call) RunAndReturn(run func(context.Context, ...string) error) *MockRedisClient_Del_Call { + _c.Call.Return(run) + return _c +} + +// FlushAll provides a mock function with given fields: ctx +func (_m *MockRedisClient) FlushAll(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for FlushAll") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_FlushAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FlushAll' +type MockRedisClient_FlushAll_Call struct { + *mock.Call +} + +// FlushAll is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockRedisClient_Expecter) FlushAll(ctx interface{}) *MockRedisClient_FlushAll_Call { + return &MockRedisClient_FlushAll_Call{Call: _e.mock.On("FlushAll", ctx)} +} + +func (_c *MockRedisClient_FlushAll_Call) Run(run func(ctx context.Context)) *MockRedisClient_FlushAll_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockRedisClient_FlushAll_Call) Return(_a0 error) *MockRedisClient_FlushAll_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_FlushAll_Call) RunAndReturn(run func(context.Context) error) *MockRedisClient_FlushAll_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, key +func (_m *MockRedisClient) Get(ctx context.Context, key string) (string, error) { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return rf(ctx, key) + } + if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRedisClient_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type MockRedisClient_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *MockRedisClient_Expecter) Get(ctx interface{}, key interface{}) *MockRedisClient_Get_Call { + return &MockRedisClient_Get_Call{Call: _e.mock.On("Get", ctx, key)} +} + +func (_c *MockRedisClient_Get_Call) Run(run func(ctx context.Context, key string)) *MockRedisClient_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockRedisClient_Get_Call) Return(_a0 string, _a1 error) *MockRedisClient_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRedisClient_Get_Call) RunAndReturn(run func(context.Context, string) (string, error)) *MockRedisClient_Get_Call { + _c.Call.Return(run) + return _c +} + +// RunWhileLocked provides a mock function with given fields: ctx, resourceName, expiration, fn +func (_m *MockRedisClient) RunWhileLocked(ctx context.Context, resourceName string, expiration time.Duration, fn func(context.Context) error) error { + ret := _m.Called(ctx, resourceName, expiration, fn) + + if len(ret) == 0 { + panic("no return value specified for RunWhileLocked") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, time.Duration, func(context.Context) error) error); ok { + r0 = rf(ctx, resourceName, expiration, fn) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_RunWhileLocked_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RunWhileLocked' +type MockRedisClient_RunWhileLocked_Call struct { + *mock.Call +} + +// RunWhileLocked is a helper method to define mock.On call +// - ctx context.Context +// - resourceName string +// - expiration time.Duration +// - fn func(context.Context) error +func (_e *MockRedisClient_Expecter) RunWhileLocked(ctx interface{}, resourceName interface{}, expiration interface{}, fn interface{}) *MockRedisClient_RunWhileLocked_Call { + return &MockRedisClient_RunWhileLocked_Call{Call: _e.mock.On("RunWhileLocked", ctx, resourceName, expiration, fn)} +} + +func (_c *MockRedisClient_RunWhileLocked_Call) Run(run func(ctx context.Context, resourceName string, expiration time.Duration, fn func(context.Context) error)) *MockRedisClient_RunWhileLocked_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(time.Duration), args[3].(func(context.Context) error)) + }) + return _c +} + +func (_c *MockRedisClient_RunWhileLocked_Call) Return(_a0 error) *MockRedisClient_RunWhileLocked_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_RunWhileLocked_Call) RunAndReturn(run func(context.Context, string, time.Duration, func(context.Context) error) error) *MockRedisClient_RunWhileLocked_Call { + _c.Call.Return(run) + return _c +} + +// SAdd provides a mock function with given fields: ctx, key, members +func (_m *MockRedisClient) SAdd(ctx context.Context, key string, members ...interface{}) error { + var _ca []interface{} + _ca = append(_ca, ctx, key) + _ca = append(_ca, members...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for SAdd") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) error); ok { + r0 = rf(ctx, key, members...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_SAdd_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SAdd' +type MockRedisClient_SAdd_Call struct { + *mock.Call +} + +// SAdd is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - members ...interface{} +func (_e *MockRedisClient_Expecter) SAdd(ctx interface{}, key interface{}, members ...interface{}) *MockRedisClient_SAdd_Call { + return &MockRedisClient_SAdd_Call{Call: _e.mock.On("SAdd", + append([]interface{}{ctx, key}, members...)...)} +} + +func (_c *MockRedisClient_SAdd_Call) Run(run func(ctx context.Context, key string, members ...interface{})) *MockRedisClient_SAdd_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]interface{}, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(interface{}) + } + } + run(args[0].(context.Context), args[1].(string), variadicArgs...) + }) + return _c +} + +func (_c *MockRedisClient_SAdd_Call) Return(_a0 error) *MockRedisClient_SAdd_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_SAdd_Call) RunAndReturn(run func(context.Context, string, ...interface{}) error) *MockRedisClient_SAdd_Call { + _c.Call.Return(run) + return _c +} + +// SCard provides a mock function with given fields: ctx, key +func (_m *MockRedisClient) SCard(ctx context.Context, key string) (int64, error) { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for SCard") + } + + var r0 int64 + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (int64, error)); ok { + return rf(ctx, key) + } + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Get(0).(int64) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, key) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRedisClient_SCard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SCard' +type MockRedisClient_SCard_Call struct { + *mock.Call +} + +// SCard is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *MockRedisClient_Expecter) SCard(ctx interface{}, key interface{}) *MockRedisClient_SCard_Call { + return &MockRedisClient_SCard_Call{Call: _e.mock.On("SCard", ctx, key)} +} + +func (_c *MockRedisClient_SCard_Call) Run(run func(ctx context.Context, key string)) *MockRedisClient_SCard_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockRedisClient_SCard_Call) Return(_a0 int64, _a1 error) *MockRedisClient_SCard_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRedisClient_SCard_Call) RunAndReturn(run func(context.Context, string) (int64, error)) *MockRedisClient_SCard_Call { + _c.Call.Return(run) + return _c +} + +// SIsMember provides a mock function with given fields: ctx, key, member +func (_m *MockRedisClient) SIsMember(ctx context.Context, key string, member interface{}) (bool, error) { + ret := _m.Called(ctx, key, member) + + if len(ret) == 0 { + panic("no return value specified for SIsMember") + } + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) (bool, error)); ok { + return rf(ctx, key, member) + } + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) bool); ok { + r0 = rf(ctx, key, member) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, key, member) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRedisClient_SIsMember_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SIsMember' +type MockRedisClient_SIsMember_Call struct { + *mock.Call +} + +// SIsMember is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - member interface{} +func (_e *MockRedisClient_Expecter) SIsMember(ctx interface{}, key interface{}, member interface{}) *MockRedisClient_SIsMember_Call { + return &MockRedisClient_SIsMember_Call{Call: _e.mock.On("SIsMember", ctx, key, member)} +} + +func (_c *MockRedisClient_SIsMember_Call) Run(run func(ctx context.Context, key string, member interface{})) *MockRedisClient_SIsMember_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(interface{})) + }) + return _c +} + +func (_c *MockRedisClient_SIsMember_Call) Return(_a0 bool, _a1 error) *MockRedisClient_SIsMember_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRedisClient_SIsMember_Call) RunAndReturn(run func(context.Context, string, interface{}) (bool, error)) *MockRedisClient_SIsMember_Call { + _c.Call.Return(run) + return _c +} + +// Set provides a mock function with given fields: ctx, key, value +func (_m *MockRedisClient) Set(ctx context.Context, key string, value string) error { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Set") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_Set_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Set' +type MockRedisClient_Set_Call struct { + *mock.Call +} + +// Set is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - value string +func (_e *MockRedisClient_Expecter) Set(ctx interface{}, key interface{}, value interface{}) *MockRedisClient_Set_Call { + return &MockRedisClient_Set_Call{Call: _e.mock.On("Set", ctx, key, value)} +} + +func (_c *MockRedisClient_Set_Call) Run(run func(ctx context.Context, key string, value string)) *MockRedisClient_Set_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *MockRedisClient_Set_Call) Return(_a0 error) *MockRedisClient_Set_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_Set_Call) RunAndReturn(run func(context.Context, string, string) error) *MockRedisClient_Set_Call { + _c.Call.Return(run) + return _c +} + +// WaitLockToRun provides a mock function with given fields: ctx, resourceName, expiration, fn +func (_m *MockRedisClient) WaitLockToRun(ctx context.Context, resourceName string, expiration time.Duration, fn func(context.Context) error) error { + ret := _m.Called(ctx, resourceName, expiration, fn) + + if len(ret) == 0 { + panic("no return value specified for WaitLockToRun") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, time.Duration, func(context.Context) error) error); ok { + r0 = rf(ctx, resourceName, expiration, fn) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_WaitLockToRun_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WaitLockToRun' +type MockRedisClient_WaitLockToRun_Call struct { + *mock.Call +} + +// WaitLockToRun is a helper method to define mock.On call +// - ctx context.Context +// - resourceName string +// - expiration time.Duration +// - fn func(context.Context) error +func (_e *MockRedisClient_Expecter) WaitLockToRun(ctx interface{}, resourceName interface{}, expiration interface{}, fn interface{}) *MockRedisClient_WaitLockToRun_Call { + return &MockRedisClient_WaitLockToRun_Call{Call: _e.mock.On("WaitLockToRun", ctx, resourceName, expiration, fn)} +} + +func (_c *MockRedisClient_WaitLockToRun_Call) Run(run func(ctx context.Context, resourceName string, expiration time.Duration, fn func(context.Context) error)) *MockRedisClient_WaitLockToRun_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(time.Duration), args[3].(func(context.Context) error)) + }) + return _c +} + +func (_c *MockRedisClient_WaitLockToRun_Call) Return(_a0 error) *MockRedisClient_WaitLockToRun_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_WaitLockToRun_Call) RunAndReturn(run func(context.Context, string, time.Duration, func(context.Context) error) error) *MockRedisClient_WaitLockToRun_Call { + _c.Call.Return(run) + return _c +} + +// ZAdd provides a mock function with given fields: ctx, key, z +func (_m *MockRedisClient) ZAdd(ctx context.Context, key string, z redis.Z) error { + ret := _m.Called(ctx, key, z) + + if len(ret) == 0 { + panic("no return value specified for ZAdd") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, redis.Z) error); ok { + r0 = rf(ctx, key, z) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRedisClient_ZAdd_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ZAdd' +type MockRedisClient_ZAdd_Call struct { + *mock.Call +} + +// ZAdd is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - z redis.Z +func (_e *MockRedisClient_Expecter) ZAdd(ctx interface{}, key interface{}, z interface{}) *MockRedisClient_ZAdd_Call { + return &MockRedisClient_ZAdd_Call{Call: _e.mock.On("ZAdd", ctx, key, z)} +} + +func (_c *MockRedisClient_ZAdd_Call) Run(run func(ctx context.Context, key string, z redis.Z)) *MockRedisClient_ZAdd_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(redis.Z)) + }) + return _c +} + +func (_c *MockRedisClient_ZAdd_Call) Return(_a0 error) *MockRedisClient_ZAdd_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRedisClient_ZAdd_Call) RunAndReturn(run func(context.Context, string, redis.Z) error) *MockRedisClient_ZAdd_Call { + _c.Call.Return(run) + return _c +} + +// NewMockRedisClient creates a new instance of MockRedisClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockRedisClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRedisClient { + mock := &MockRedisClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_BroadcastStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_BroadcastStore.go index b0fd1043..1691c063 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_BroadcastStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_BroadcastStore.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.51.0. DO NOT EDIT. +// Code generated by mockery v2.49.1. DO NOT EDIT. package database @@ -118,7 +118,7 @@ type MockBroadcastStore_Get_Call struct { // Get is a helper method to define mock.On call // - ctx context.Context // - id int64 -func (_e *MockBroadcastStore_Expecter) Get(ctx interface{}, id int64) *MockBroadcastStore_Get_Call { +func (_e *MockBroadcastStore_Expecter) Get(ctx interface{}, id interface{}) *MockBroadcastStore_Get_Call { return &MockBroadcastStore_Get_Call{Call: _e.mock.On("Get", ctx, id)} } @@ -165,7 +165,7 @@ type MockBroadcastStore_Save_Call struct { // Save is a helper method to define mock.On call // - ctx context.Context // - broadcast database.Broadcast -func (_e *MockBroadcastStore_Expecter) Save(ctx interface{}, broadcast database.Broadcast) *MockBroadcastStore_Save_Call { +func (_e *MockBroadcastStore_Expecter) Save(ctx interface{}, broadcast interface{}) *MockBroadcastStore_Save_Call { return &MockBroadcastStore_Save_Call{Call: _e.mock.On("Save", ctx, broadcast)} } diff --git a/_mocks/opencsg.com/csghub-server/component/mock_BroadcastComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_BroadcastComponent.go index f5725633..52668ea0 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_BroadcastComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_BroadcastComponent.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.51.0. DO NOT EDIT. +// Code generated by mockery v2.49.1. DO NOT EDIT. package component diff --git a/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go index 0a8f398a..2fb956c5 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_MirrorComponent.go @@ -175,6 +175,68 @@ func (_c *MockMirrorComponent_CreatePushMirrorForFinishedMirrorTask_Call) RunAnd return _c } +// FindWithMapping provides a mock function with given fields: ctx, repoType, namespace, name, mapping +func (_m *MockMirrorComponent) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace string, name string, mapping types.Mapping) (*database.Repository, error) { + ret := _m.Called(ctx, repoType, namespace, name, mapping) + + if len(ret) == 0 { + panic("no return value specified for FindWithMapping") + } + + var r0 *database.Repository + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.RepositoryType, string, string, types.Mapping) (*database.Repository, error)); ok { + return rf(ctx, repoType, namespace, name, mapping) + } + if rf, ok := ret.Get(0).(func(context.Context, types.RepositoryType, string, string, types.Mapping) *database.Repository); ok { + r0 = rf(ctx, repoType, namespace, name, mapping) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.Repository) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.RepositoryType, string, string, types.Mapping) error); ok { + r1 = rf(ctx, repoType, namespace, name, mapping) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockMirrorComponent_FindWithMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FindWithMapping' +type MockMirrorComponent_FindWithMapping_Call struct { + *mock.Call +} + +// FindWithMapping is a helper method to define mock.On call +// - ctx context.Context +// - repoType types.RepositoryType +// - namespace string +// - name string +// - mapping types.Mapping +func (_e *MockMirrorComponent_Expecter) FindWithMapping(ctx interface{}, repoType interface{}, namespace interface{}, name interface{}, mapping interface{}) *MockMirrorComponent_FindWithMapping_Call { + return &MockMirrorComponent_FindWithMapping_Call{Call: _e.mock.On("FindWithMapping", ctx, repoType, namespace, name, mapping)} +} + +func (_c *MockMirrorComponent_FindWithMapping_Call) Run(run func(ctx context.Context, repoType types.RepositoryType, namespace string, name string, mapping types.Mapping)) *MockMirrorComponent_FindWithMapping_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.RepositoryType), args[2].(string), args[3].(string), args[4].(types.Mapping)) + }) + return _c +} + +func (_c *MockMirrorComponent_FindWithMapping_Call) Return(_a0 *database.Repository, _a1 error) *MockMirrorComponent_FindWithMapping_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockMirrorComponent_FindWithMapping_Call) RunAndReturn(run func(context.Context, types.RepositoryType, string, string, types.Mapping) (*database.Repository, error)) *MockMirrorComponent_FindWithMapping_Call { + _c.Call.Return(run) + return _c +} + // Index provides a mock function with given fields: ctx, currentUser, per, page, search func (_m *MockMirrorComponent) Index(ctx context.Context, currentUser string, per int, page int, search string) ([]types.Mirror, int, error) { ret := _m.Called(ctx, currentUser, per, page, search) diff --git a/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go index 69e029e0..7297b050 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_UserComponent.go @@ -382,6 +382,124 @@ func (_c *MockUserComponent_Evaluations_Call) RunAndReturn(run func(context.Cont return _c } +// FindByAccessToken provides a mock function with given fields: ctx, token +func (_m *MockUserComponent) FindByAccessToken(ctx context.Context, token string) (*database.User, error) { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for FindByAccessToken") + } + + var r0 *database.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*database.User, error)); ok { + return rf(ctx, token) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *database.User); ok { + r0 = rf(ctx, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.User) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, token) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockUserComponent_FindByAccessToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FindByAccessToken' +type MockUserComponent_FindByAccessToken_Call struct { + *mock.Call +} + +// FindByAccessToken is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *MockUserComponent_Expecter) FindByAccessToken(ctx interface{}, token interface{}) *MockUserComponent_FindByAccessToken_Call { + return &MockUserComponent_FindByAccessToken_Call{Call: _e.mock.On("FindByAccessToken", ctx, token)} +} + +func (_c *MockUserComponent_FindByAccessToken_Call) Run(run func(ctx context.Context, token string)) *MockUserComponent_FindByAccessToken_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockUserComponent_FindByAccessToken_Call) Return(_a0 *database.User, _a1 error) *MockUserComponent_FindByAccessToken_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockUserComponent_FindByAccessToken_Call) RunAndReturn(run func(context.Context, string) (*database.User, error)) *MockUserComponent_FindByAccessToken_Call { + _c.Call.Return(run) + return _c +} + +// FindByGitAccessToken provides a mock function with given fields: ctx, token +func (_m *MockUserComponent) FindByGitAccessToken(ctx context.Context, token string) (*database.User, error) { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for FindByGitAccessToken") + } + + var r0 *database.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*database.User, error)); ok { + return rf(ctx, token) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *database.User); ok { + r0 = rf(ctx, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*database.User) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, token) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockUserComponent_FindByGitAccessToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FindByGitAccessToken' +type MockUserComponent_FindByGitAccessToken_Call struct { + *mock.Call +} + +// FindByGitAccessToken is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *MockUserComponent_Expecter) FindByGitAccessToken(ctx interface{}, token interface{}) *MockUserComponent_FindByGitAccessToken_Call { + return &MockUserComponent_FindByGitAccessToken_Call{Call: _e.mock.On("FindByGitAccessToken", ctx, token)} +} + +func (_c *MockUserComponent_FindByGitAccessToken_Call) Run(run func(ctx context.Context, token string)) *MockUserComponent_FindByGitAccessToken_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockUserComponent_FindByGitAccessToken_Call) Return(_a0 *database.User, _a1 error) *MockUserComponent_FindByGitAccessToken_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockUserComponent_FindByGitAccessToken_Call) RunAndReturn(run func(context.Context, string) (*database.User, error)) *MockUserComponent_FindByGitAccessToken_Call { + _c.Call.Return(run) + return _c +} + // GetUserByName provides a mock function with given fields: ctx, userName func (_m *MockUserComponent) GetUserByName(ctx context.Context, userName string) (*database.User, error) { ret := _m.Called(ctx, userName) diff --git a/accounting/router/api.go b/accounting/router/api.go index ebca7115..2f2ba936 100644 --- a/accounting/router/api.go +++ b/accounting/router/api.go @@ -12,9 +12,10 @@ import ( func NewAccountRouter(config *config.Config) (*gin.Engine, error) { r := gin.New() r.Use(gin.Recovery()) + middleware := middleware.NewMiddleware(config) - r.Use(middleware.Log(config)) - r.Use(middleware.Authenticator(config)) + r.Use(middleware.Log()) + r.Use(middleware.Authenticator()) // metering meterHandler, err := handler.NewMeteringHandler() diff --git a/api/apitest/server.go b/api/apitest/server.go new file mode 100644 index 00000000..7976b1c8 --- /dev/null +++ b/api/apitest/server.go @@ -0,0 +1,140 @@ +package apitest + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/api/middleware" + "opencsg.com/csghub-server/api/router" + "opencsg.com/csghub-server/builder/store/database" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type ResponseHelper struct { + response *httptest.ResponseRecorder +} + +func (h *ResponseHelper) Response() *httptest.ResponseRecorder { + return h.response +} + +func (h *ResponseHelper) ResponseEq(t *testing.T, code int, msg string, expected any) { + var r = struct { + Msg string `json:"msg"` + Data any `json:"data,omitempty"` + }{ + Msg: msg, + Data: expected, + } + b, err := json.Marshal(r) + require.NoError(t, err) + require.Equal(t, code, h.response.Code, h.response.Body.String()) + require.JSONEq(t, string(b), h.response.Body.String()) + +} + +func (h *ResponseHelper) ResponseEqSimple(t *testing.T, code int, expected any) { + b, err := json.Marshal(expected) + require.NoError(t, err) + require.Equal(t, code, h.response.Code, h.response.Body.String()) + require.JSONEq(t, string(b), h.response.Body.String()) + +} + +type TestServer struct { + server *router.ServerImpl +} + +func NewTestServer(t *testing.T, option func(s *router.ServerImpl)) *TestServer { + gin.SetMode(gin.ReleaseMode) + mu := mockcomponent.NewMockUserComponent(t) + mu.EXPECT().FindByAccessToken(mock.Anything, "u:p").Return(&database.User{ + Username: "u", + }, nil).Maybe() + mm := mockcomponent.NewMockMirrorComponent(t) + mm.EXPECT().FindWithMapping( + mock.Anything, mock.Anything, "u", "r", types.HFMapping, + ).Return(&database.Repository{ + Path: "u/r", + }, nil).Maybe() + config := &config.Config{} + config.GitServer.Type = types.GitServerTypeGitaly + now := time.Now() + md := middleware.NewMiddlewareDI(config, mu, mm, nil) + server := &router.ServerImpl{ + BaseServer: &router.BaseServer{ + Middleware: md, + Config: config, + }, + } + option(server) + err := server.RegisterRoutes(false) + if err != nil { + panic(err) + } + fmt.Println("====x", time.Since(now)) + return &TestServer{server: server} +} + +func (ts *TestServer) NewRequest(method, url string, body any) (*http.Request, error) { + d, err := json.Marshal(body) + if err != nil { + return nil, err + } + return http.NewRequest(method, url, strings.NewReader(string(d))) +} + +func (ts *TestServer) NewGetRequest(url string) (*http.Request, error) { + return http.NewRequest(http.MethodGet, url, nil) +} + +func (ts *TestServer) NewPostRequest(url string, body any) (*http.Request, error) { + d, err := json.Marshal(body) + if err != nil { + return nil, err + } + return http.NewRequest(http.MethodPost, url, strings.NewReader(string(d))) +} + +func (ts *TestServer) NewPutRequest(url string, body any) (*http.Request, error) { + d, err := json.Marshal(body) + if err != nil { + return nil, err + } + return http.NewRequest(http.MethodPut, url, strings.NewReader(string(d))) +} + +func (ts *TestServer) NewDeleteRequest(url string) (*http.Request, error) { + return http.NewRequest(http.MethodDelete, url, nil) +} + +func (ts *TestServer) AuthRequest(req *http.Request) *http.Request { + req.Header.Add("Authorization", "Bearer u:p") + return req +} + +func (ts *TestServer) Send(req *http.Request) *ResponseHelper { + w := httptest.NewRecorder() + ts.server.Engine.ServeHTTP(w, req) + return &ResponseHelper{response: w} +} + +func (ts *TestServer) AuthAndSend(t *testing.T, req *http.Request) *ResponseHelper { + r := ts.Send(req) + require.Equal(t, 401, r.response.Code) + + ts.AuthRequest(req) + w := httptest.NewRecorder() + ts.server.Engine.ServeHTTP(w, req) + return &ResponseHelper{response: w} +} diff --git a/api/handler/discussion.go b/api/handler/discussion.go index dc85e9da..9c0e7c53 100644 --- a/api/handler/discussion.go +++ b/api/handler/discussion.go @@ -19,6 +19,13 @@ type DiscussionHandler struct { sensitive component.SensitiveComponent } +func NewDiscussionHandlerDI(discussion component.DiscussionComponent, sensitive component.SensitiveComponent) *DiscussionHandler { + return &DiscussionHandler{ + discussion: discussion, + sensitive: sensitive, + } +} + func NewDiscussionHandler(cfg *config.Config) (*DiscussionHandler, error) { c := component.NewDiscussionComponent() sc, err := component.NewSensitiveComponent(cfg) diff --git a/api/handler/discussion_test.go b/api/handler/discussion_test.go index 4df96e0d..81c4d375 100644 --- a/api/handler/discussion_test.go +++ b/api/handler/discussion_test.go @@ -1,17 +1,21 @@ -package handler +package handler_test import ( "testing" - "github.com/gin-gonic/gin" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" mockcomponent "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/component" + "opencsg.com/csghub-server/api/apitest" + "opencsg.com/csghub-server/api/handler" + "opencsg.com/csghub-server/api/router" "opencsg.com/csghub-server/builder/testutil" "opencsg.com/csghub-server/common/types" ) type DiscussionTester struct { *testutil.GinTester - handler *DiscussionHandler + handler *handler.DiscussionHandler mocks struct { discussion *mockcomponent.MockDiscussionComponent sensitive *mockcomponent.MockSensitiveComponent @@ -23,31 +27,30 @@ func NewDiscussionTester(t *testing.T) *DiscussionTester { tester.mocks.discussion = mockcomponent.NewMockDiscussionComponent(t) tester.mocks.sensitive = mockcomponent.NewMockSensitiveComponent(t) - tester.handler = &DiscussionHandler{ - discussion: tester.mocks.discussion, - sensitive: tester.mocks.sensitive, - } - tester.WithParam("namespace", "u") - tester.WithParam("name", "r") + tester.handler = handler.NewDiscussionHandlerDI( + tester.mocks.discussion, + tester.mocks.sensitive, + ) return tester } -func (t *DiscussionTester) WithHandleFunc(fn func(h *DiscussionHandler) gin.HandlerFunc) *DiscussionTester { - t.Handler(fn(t.handler)) - return t -} - func TestDiscussionHandler_CreateRepoDiscussion(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.CreateRepoDiscussion + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewPostRequest( + "/api/v1/models/u/r/discussions", &types.CreateRepoDiscussionRequest{ + Title: "foo", + }, + ) + require.NoError(t, err) tester.mocks.sensitive.EXPECT().CheckRequestV2( - tester.Ctx(), &types.CreateRepoDiscussionRequest{Title: "foo"}, + mock.Anything, &types.CreateRepoDiscussionRequest{Title: "foo"}, ).Return(true, nil) tester.mocks.discussion.EXPECT().CreateRepoDiscussion( - tester.Ctx(), types.CreateRepoDiscussionRequest{ + mock.Anything, types.CreateRepoDiscussionRequest{ CurrentUser: "u", Namespace: "u", Name: "r", @@ -55,159 +58,184 @@ func TestDiscussionHandler_CreateRepoDiscussion(t *testing.T) { Title: "foo", }, ).Return(&types.CreateDiscussionResponse{ID: 123}, nil) - tester.WithParam("repo_type", "models").WithBody(t, &types.CreateRepoDiscussionRequest{ - Title: "foo", - }).Execute() - - tester.ResponseEq(t, 200, tester.OKText, &types.CreateDiscussionResponse{ID: 123}) + require.Nil(t, err) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, &types.CreateDiscussionResponse{ID: 123}) } func TestDiscussionHandler_UpdateDiscussion(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.UpdateDiscussion + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewPutRequest( + "/api/v1/discussions/1", types.UpdateDiscussionRequest{ + CurrentUser: "u", + ID: 1, + Title: "foo", + }, + ) + require.NoError(t, err) tester.mocks.sensitive.EXPECT().CheckRequestV2( - tester.Ctx(), &types.UpdateDiscussionRequest{Title: "foo"}, + mock.Anything, &types.UpdateDiscussionRequest{Title: "foo"}, ).Return(true, nil) tester.mocks.discussion.EXPECT().UpdateDiscussion( - tester.Ctx(), types.UpdateDiscussionRequest{ + mock.Anything, types.UpdateDiscussionRequest{ CurrentUser: "u", ID: 1, Title: "foo", }, ).Return(nil) - tester.WithParam("id", "1").WithBody(t, &types.UpdateDiscussionRequest{ - Title: "foo", - }).Execute() - - tester.ResponseEq(t, 200, tester.OKText, nil) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, nil) } func TestDiscussionHandler_DeleteDiscussion(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.DeleteDiscussion + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewDeleteRequest( + "/api/v1/discussions/1", + ) + require.NoError(t, err) tester.mocks.discussion.EXPECT().DeleteDiscussion( - tester.Ctx(), "u", int64(1), + mock.Anything, "u", int64(1), ).Return(nil) - tester.WithParam("id", "1").Execute() - - tester.ResponseEq(t, 200, tester.OKText, nil) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, nil) } func TestDiscussionHandler_ShowDiscussion(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.ShowDiscussion + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) + req, err := server.NewGetRequest( + "/api/v1/discussions/1", + ) + require.NoError(t, err) tester.mocks.discussion.EXPECT().GetDiscussion( - tester.Ctx(), int64(1), + mock.Anything, int64(1), ).Return(&types.ShowDiscussionResponse{Title: "foo"}, nil) - tester.WithParam("id", "1").Execute() - - tester.ResponseEq(t, 200, tester.OKText, &types.ShowDiscussionResponse{Title: "foo"}) + resp := server.Send(req) + resp.ResponseEq(t, 200, tester.OKText, &types.ShowDiscussionResponse{Title: "foo"}) } func TestDiscussionHandler_ListRepoDiscussions(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.ListRepoDiscussions + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) + req, err := server.NewGetRequest( + "/api/v1/models/u/r/discussions", + ) + require.NoError(t, err) tester.mocks.discussion.EXPECT().ListRepoDiscussions( - tester.Ctx(), types.ListRepoDiscussionRequest{ - CurrentUser: "u", - RepoType: types.ModelRepo, - Namespace: "u", - Name: "r", + mock.Anything, types.ListRepoDiscussionRequest{ + RepoType: types.ModelRepo, + Namespace: "u", + Name: "r", }, ).Return(&types.ListRepoDiscussionResponse{Discussions: []*types.CreateDiscussionResponse{ {ID: 1}, }}, nil) - tester.WithUser().WithParam("repo_type", "models").Execute() - - tester.ResponseEq(t, 200, tester.OKText, &types.ListRepoDiscussionResponse{ + resp := server.Send(req) + resp.ResponseEq(t, 200, tester.OKText, &types.ListRepoDiscussionResponse{ Discussions: []*types.CreateDiscussionResponse{{ID: 1}}, }) - } func TestDiscussionHandler_CreateDiscussionComment(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.CreateDiscussionComment + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewPostRequest( + "/api/v1/discussions/1/comments", + &types.CreateCommentRequest{Content: "foo"}, + ) + require.NoError(t, err) tester.mocks.sensitive.EXPECT().CheckRequestV2( - tester.Ctx(), &types.CreateCommentRequest{Content: "foo"}, + mock.Anything, &types.CreateCommentRequest{Content: "foo"}, ).Return(true, nil) tester.mocks.discussion.EXPECT().CreateDiscussionComment( - tester.Ctx(), types.CreateCommentRequest{ + mock.Anything, types.CreateCommentRequest{ CurrentUser: "u", Content: "foo", CommentableID: 1, }, ).Return(&types.CreateCommentResponse{ID: 1}, nil) - tester.WithParam("id", "1").WithParam("repo_type", "models").WithBody( - t, &types.CreateCommentRequest{Content: "foo"}, - ).Execute() - - tester.ResponseEq(t, 200, tester.OKText, &types.CreateCommentResponse{ID: 1}) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, &types.CreateCommentResponse{ID: 1}) } func TestDiscussionHandler_UpdateComment(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.UpdateComment + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewPutRequest( + "/api/v1/discussions/1/comments/2", + &types.UpdateCommentRequest{Content: "foo"}, + ) + require.NoError(t, err) tester.mocks.sensitive.EXPECT().CheckRequestV2( - tester.Ctx(), &types.UpdateCommentRequest{Content: "foo"}, + mock.Anything, &types.UpdateCommentRequest{Content: "foo"}, ).Return(true, nil) tester.mocks.discussion.EXPECT().UpdateComment( - tester.Ctx(), "u", int64(1), "foo", + mock.Anything, "u", int64(1), "foo", ).Return(nil) - tester.WithParam("id", "1").WithBody( - t, &types.UpdateCommentRequest{Content: "foo"}, - ).Execute() - - tester.ResponseEq(t, 200, tester.OKText, nil) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, nil) } func TestDiscussionHandler_DeleteComment(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.DeleteComment + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) - tester.RequireUser(t) + req, err := server.NewDeleteRequest( + "/api/v1/discussions/1/comments/2", + ) + require.NoError(t, err) tester.mocks.discussion.EXPECT().DeleteComment( - tester.Ctx(), "u", int64(1), + mock.Anything, "u", int64(1), ).Return(nil) - tester.WithParam("id", "1").Execute() - tester.ResponseEq(t, 200, tester.OKText, nil) + resp := server.AuthAndSend(t, req) + resp.ResponseEq(t, 200, tester.OKText, nil) } func TestDiscussionHandler_ListDiscussionComments(t *testing.T) { - tester := NewDiscussionTester(t).WithHandleFunc(func(h *DiscussionHandler) gin.HandlerFunc { - return h.ListDiscussionComments + tester := NewDiscussionTester(t) + server := apitest.NewTestServer(t, func(s *router.ServerImpl) { + s.DiscussionHandler = tester.handler }) + req, err := server.NewGetRequest( + "/api/v1/discussions/1/comments", + ) + require.NoError(t, err) tester.mocks.discussion.EXPECT().ListDiscussionComments( - tester.Ctx(), int64(1), + mock.Anything, int64(1), ).Return([]*types.DiscussionResponse_Comment{{Content: "foo"}}, nil) - tester.WithUser().WithParam("id", "1").Execute() - tester.ResponseEq(t, 200, tester.OKText, []*types.DiscussionResponse_Comment{{Content: "foo"}}) + resp := server.Send(req) + resp.ResponseEq(t, 200, tester.OKText, []*types.DiscussionResponse_Comment{{Content: "foo"}}) } diff --git a/api/handler/repo.go b/api/handler/repo.go index 254cf60e..3b7218dc 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -35,6 +35,13 @@ func NewRepoHandler(config *config.Config) (*RepoHandler, error) { }, nil } +func NewRepoHandlerDI(repo component.RepoComponent, deployStatusCheckInterval time.Duration) *RepoHandler { + return &RepoHandler{ + c: repo, + deployStatusCheckInterval: deployStatusCheckInterval, + } +} + type RepoHandler struct { c component.RepoComponent deployStatusCheckInterval time.Duration diff --git a/api/middleware/access_token.go b/api/middleware/access_token.go index 68929aa2..2b17e97b 100644 --- a/api/middleware/access_token.go +++ b/api/middleware/access_token.go @@ -7,11 +7,9 @@ import ( "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/builder/store/database" ) -func GetUserFromAccessToken() gin.HandlerFunc { - userStore := database.NewUserStore() +func (m *Middleware) GetUserFromAccessToken() gin.HandlerFunc { return func(c *gin.Context) { // Get Auzhorization token authHeader := c.Request.Header.Get("Authorization") @@ -19,7 +17,7 @@ func GetUserFromAccessToken() gin.HandlerFunc { if authHeader != "" { // Get token token := strings.TrimPrefix(authHeader, "Bearer ") - user, err := userStore.FindByAccessToken(context.Background(), token) + user, err := m.userComponent.FindByAccessToken(context.Background(), token) if err != nil { slog.Debug("Can not find user by access token", slog.String("token", token)) c.Next() diff --git a/api/middleware/authenticator.go b/api/middleware/authenticator.go index 1497a7fe..9182df38 100644 --- a/api/middleware/authenticator.go +++ b/api/middleware/authenticator.go @@ -12,14 +12,12 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/store/database" - "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) // BuildJwtSession create and save session with jwt from query string -func BuildJwtSession(jwtSignKey string) gin.HandlerFunc { +func (m *Middleware) BuildJwtSession(jwtSignKey string) gin.HandlerFunc { return func(c *gin.Context) { token := c.Query("jwt") // If no JWT provided, continue with the next middleware @@ -47,7 +45,7 @@ func BuildJwtSession(jwtSignKey string) gin.HandlerFunc { } // AuthSession verify user login by session, ans save user name into context if login -func AuthSession() gin.HandlerFunc { +func (m *Middleware) AuthSession() gin.HandlerFunc { return func(c *gin.Context) { session := sessions.Default(c) userName := session.Get(httpbase.CurrentUserCtxVar) @@ -60,26 +58,9 @@ func AuthSession() gin.HandlerFunc { } } -func Authenticator(config *config.Config) gin.HandlerFunc { - //TODO:change to component - userStore := database.NewUserStore() +func (m *Middleware) Authenticator() gin.HandlerFunc { return func(c *gin.Context) { - sessionObj, sessionExists := c.Get(sessions.DefaultKey) - if sessionExists && sessionObj != nil { - session := sessions.Default(c) - sessionUserName := session.Get(httpbase.CurrentUserCtxVar) - if sessionUserName != nil { - slog.Debug("get username from session", slog.Any("session username", sessionUserName.(string))) - if len(sessionUserName.(string)) > 0 { - httpbase.SetCurrentUser(c, sessionUserName.(string)) - httpbase.SetAuthType(c, httpbase.AuthTypeJwt) - c.Next() - return - } - } - } - - apiToken := config.APIToken + apiToken := m.config.APIToken // Get Auzhorization token authHeader := c.Request.Header.Get("Authorization") @@ -107,7 +88,7 @@ func Authenticator(config *config.Config) gin.HandlerFunc { } if strings.Contains(token, ".") { - claims, err := parseJWTToken(config.JWT.SigningKey, token) + claims, err := parseJWTToken(m.config.JWT.SigningKey, token) if err == nil { httpbase.SetCurrentUser(c, claims.CurrentUser) httpbase.SetAuthType(c, httpbase.AuthTypeJwt) @@ -117,7 +98,7 @@ func Authenticator(config *config.Config) gin.HandlerFunc { } } else { //TODO:use cache to check access token - user, _ := userStore.FindByAccessToken(context.Background(), token) + user, _ := m.userComponent.FindByAccessToken(context.Background(), token) if user != nil { httpbase.SetCurrentUser(c, user.Username) httpbase.SetAccessToken(c, token) @@ -156,9 +137,9 @@ func parseJWTToken(signKey, tokenString string) (*types.JWTClaims, error) { return nil, fmt.Errorf("JWT token claims not match: %+v", *token) } -func OnlyAPIKeyAuthenticator(config *config.Config) gin.HandlerFunc { +func (m *Middleware) NeedAPIKey() gin.HandlerFunc { return func(c *gin.Context) { - apiToken := config.APIToken + apiToken := m.config.APIToken // Get Authorization token authHeader := c.Request.Header.Get("Authorization") @@ -188,7 +169,7 @@ func OnlyAPIKeyAuthenticator(config *config.Config) gin.HandlerFunc { } } -func MustLogin() gin.HandlerFunc { +func (m *Middleware) MustLogin() gin.HandlerFunc { return func(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) if currentUser == "" { @@ -199,10 +180,7 @@ func MustLogin() gin.HandlerFunc { } } -func NeedAdmin(config *config.Config) gin.HandlerFunc { - userSvcClient := rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) - +func (m *Middleware) NeedAdmin() gin.HandlerFunc { return func(ctx *gin.Context) { currentUser := httpbase.GetCurrentUser(ctx) if currentUser == "" { @@ -211,7 +189,7 @@ func NeedAdmin(config *config.Config) gin.HandlerFunc { return } - user, err := userSvcClient.GetUserInfo(ctx, currentUser, currentUser) + user, err := m.userServiceClient.GetUserInfo(ctx, currentUser, currentUser) if err != nil { httpbase.ServerError(ctx, fmt.Errorf("failed to find user, cause:%w", err)) @@ -232,12 +210,3 @@ func NeedAdmin(config *config.Config) gin.HandlerFunc { ctx.Next() } } - -type AuthenticatorCollection struct { - // only can be accessed by api key - NeedAPIKey gin.HandlerFunc - // user need to login first - NeedLogin gin.HandlerFunc - //user must be admin role to access - NeedAdmin gin.HandlerFunc -} diff --git a/api/middleware/git_http_param.go b/api/middleware/git_http_param.go index 8ee8a49f..e5e0181c 100644 --- a/api/middleware/git_http_param.go +++ b/api/middleware/git_http_param.go @@ -11,12 +11,11 @@ import ( "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/builder/store/database" ) const gitSuffix = ".git" -func GitHTTPParamMiddleware() gin.HandlerFunc { +func (m *Middleware) GitHTTPParamMiddleware() gin.HandlerFunc { return func(c *gin.Context) { name := c.Param("name") namespace := c.Param("namespace") @@ -49,7 +48,7 @@ func GitHTTPParamMiddleware() gin.HandlerFunc { } } -func ContentEncoding() gin.HandlerFunc { +func (m *Middleware) ContentEncoding() gin.HandlerFunc { return func(c *gin.Context) { var ( body io.ReadCloser @@ -80,8 +79,7 @@ func ContentEncoding() gin.HandlerFunc { } } -func GetCurrentUserFromHeader() gin.HandlerFunc { - userStore := database.NewUserStore() +func (m *Middleware) GetCurrentUserFromHeader() gin.HandlerFunc { return func(c *gin.Context) { authHeader := c.Request.Header.Get("Authorization") if authHeader != "" && !strings.HasPrefix(authHeader, "X-OPENCSG-Sync-Token") { @@ -96,7 +94,7 @@ func GetCurrentUserFromHeader() gin.HandlerFunc { username := strings.Split(string(authInfo), ":")[0] password := strings.Split(string(authInfo), ":")[1] - user, err := userStore.FindByGitAccessToken(context.Background(), password) + user, err := m.userComponent.FindByGitAccessToken(context.Background(), password) if err != nil { c.Header("WWW-Authenticate", "Basic realm=opencsg-git") c.PureJSON(http.StatusUnauthorized, nil) diff --git a/api/middleware/gitlab_shell_jwt.go b/api/middleware/gitlab_shell_jwt.go index 13d421f8..008559cc 100644 --- a/api/middleware/gitlab_shell_jwt.go +++ b/api/middleware/gitlab_shell_jwt.go @@ -8,7 +8,6 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" - "opencsg.com/csghub-server/common/config" ) const apiSecretHeaderName = "Gitlab-Shell-Api-Request" @@ -27,10 +26,10 @@ func parseGitlabShellJWTToken(signKey, tokenString string) (bool, error) { return true, nil } -func CheckGitlabShellJWTToken(config *config.Config) gin.HandlerFunc { +func (m *Middleware) CheckGitlabShellJWTToken() gin.HandlerFunc { return func(c *gin.Context) { tokenString := c.Request.Header.Get(apiSecretHeaderName) - pass, err := parseGitlabShellJWTToken(config.GitalyServer.JWTSecret, tokenString) + pass, err := parseGitlabShellJWTToken(m.config.GitalyServer.JWTSecret, tokenString) if err != nil { slog.Debug("fail to parse gitlab-shell jwt token", slog.String("token_get", tokenString), slog.Any("error", err)) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) diff --git a/api/middleware/log.go b/api/middleware/log.go index 6f4231c6..bcf56e10 100644 --- a/api/middleware/log.go +++ b/api/middleware/log.go @@ -9,17 +9,16 @@ import ( slogmulti "github.com/samber/slog-multi" "go.opentelemetry.io/contrib/bridges/otelslog" "opencsg.com/csghub-server/api/httpbase" - "opencsg.com/csghub-server/common/config" ) -func Log(config *config.Config) gin.HandlerFunc { +func (m *Middleware) Log() gin.HandlerFunc { handlers := []slog.Handler{ slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ AddSource: false, Level: slog.LevelInfo, }), } - if config.Instrumentation.OTLPEndpoint != "" && config.Instrumentation.OTLPLogging { + if m.config.Instrumentation.OTLPEndpoint != "" && m.config.Instrumentation.OTLPLogging { handlers = append(handlers, otelslog.NewHandler("csghub-server")) } diff --git a/api/middleware/middleware.go b/api/middleware/middleware.go new file mode 100644 index 00000000..4cdd1d61 --- /dev/null +++ b/api/middleware/middleware.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/component" +) + +type Middleware struct { + config *config.Config + userComponent component.UserComponent + mirrorComponent component.MirrorComponent + userServiceClient rpc.UserSvcClient +} + +func NewMiddleware(config *config.Config) *Middleware { + userComponent, err := component.NewUserComponent(config) + if err != nil { + panic(err) + } + mirrorComponent, err := component.NewMirrorComponent(config) + if err != nil { + panic(err) + } + userServiceClient := rpc.NewUserSvcHttpClient(config) + return NewMiddlewareDI(config, userComponent, mirrorComponent, userServiceClient) +} + +func NewMiddlewareDI(config *config.Config, userComponent component.UserComponent, mirrorComponent component.MirrorComponent, userServiceClient rpc.UserSvcClient) *Middleware { + return &Middleware{ + config: config, + userComponent: userComponent, + mirrorComponent: mirrorComponent, + userServiceClient: userServiceClient, + } +} diff --git a/api/middleware/repo.go b/api/middleware/repo.go index 0278854e..c0a5bfda 100644 --- a/api/middleware/repo.go +++ b/api/middleware/repo.go @@ -5,12 +5,11 @@ import ( "strings" "github.com/gin-gonic/gin" - "opencsg.com/csghub-server/builder/store/database" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/common/utils/common" ) -func RepoType(t types.RepositoryType) gin.HandlerFunc { +func (m *Middleware) RepoType(t types.RepositoryType) gin.HandlerFunc { return func(ctx *gin.Context) { slog.Debug("middleware RepoType called", "repo_type", t) common.SetRepoTypeContext(ctx, t) @@ -18,23 +17,22 @@ func RepoType(t types.RepositoryType) gin.HandlerFunc { } } -func RepoMapping(repo_type types.RepositoryType) gin.HandlerFunc { - mirrorStore := database.NewMirrorStore() +func (m *Middleware) RepoMapping(repoType types.RepositoryType) gin.HandlerFunc { return func(ctx *gin.Context) { slog.Debug("middleware RepoMapping called") - common.SetRepoTypeContext(ctx, repo_type) + common.SetRepoTypeContext(ctx, repoType) namespace := ctx.Param("namespace") name := ctx.Param("name") branch := ctx.Param("branch") if branch == "" { branch = ctx.Param("ref") } - mapping := GetMapping(ctx) + mapping := getMapping(ctx) if mapping == types.CSGHubMapping { ctx.Next() return } - repo, err := mirrorStore.FindWithMapping(ctx, repo_type, namespace, name, mapping) + repo, err := m.mirrorComponent.FindWithMapping(ctx, repoType, namespace, name, mapping) //if found mirror, that means this is a synced source, otherwise it's may a user-upload repo if err == nil { namespace, name = repo.NamespaceAndName() @@ -51,7 +49,7 @@ func RepoMapping(repo_type types.RepositoryType) gin.HandlerFunc { } } -func GetMapping(ctx *gin.Context) types.Mapping { +func getMapping(ctx *gin.Context) types.Mapping { fullPath := ctx.FullPath() if strings.HasPrefix(fullPath, "/hf/") { return types.HFMapping diff --git a/api/router/api.go b/api/router/api.go index aef3e7cd..cfb2c609 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -9,6 +9,7 @@ import ( cache "github.com/chenyahui/gin-cache" "github.com/chenyahui/gin-cache/persist" "github.com/gin-contrib/cors" + "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" swaggerFiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" @@ -18,18 +19,170 @@ import ( "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/api/middleware" "opencsg.com/csghub-server/builder/instrumentation" + "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/builder/temporal" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" "opencsg.com/csghub-server/mirror" ) +type UserProxyHandler handler.InternalServiceProxyHandler +type DatasetViewerPeoxyHandler handler.InternalServiceProxyHandler + +func newUserProxyHandler(config *config.Config) (*UserProxyHandler, error) { + + h, err := handler.NewInternalServiceProxyHandler(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port)) + return (*UserProxyHandler)(h), err +} + +func newDatasetViewerProxyHandler(config *config.Config) (*DatasetViewerPeoxyHandler, error) { + dataViewerAddr := fmt.Sprintf("%s:%d", config.DataViewer.Host, config.DataViewer.Port) + h, err := handler.NewInternalServiceProxyHandler(dataViewerAddr) + return (*DatasetViewerPeoxyHandler)(h), err +} + +func newMemoryStore() *persist.MemoryStore { + return persist.NewMemoryStore(1 * time.Minute) +} + +type BaseServer struct { + Config *config.Config + Middleware *middleware.Middleware + GitHTTPHandler *handler.GitHTTPHandler + UserHandler *handler.UserHandler + OrgHandler *handler.OrganizationHandler + RepoCommonHandler *handler.RepoHandler + ModelHandler *handler.ModelHandler + DsHandler *handler.DatasetHandler + MirrorHandler *handler.MirrorHandler + HfdsHandler *handler.HFDatasetHandler + ListHandler *handler.ListHandler + EvaluationHandler *handler.EvaluationHandler + CodeHandler *handler.CodeHandler + SpaceHandler *handler.SpaceHandler + SpaceResourceHandler *handler.SpaceResourceHandler + SpaceSdkHandler *handler.SpaceSdkHandler + UserProxyHandler *handler.InternalServiceProxyHandler + SshKeyHandler *handler.SSHKeyHandler + TagCtrl *handler.TagsHandler + CallbackCtrl *callback.GitCallbackHandler + SensitiveCtrl *handler.SensitiveHandler + MsHandler *handler.MirrorSourceHandler + CollectionHandler *handler.CollectionHandler + ClusterHandler *handler.ClusterHandler + EventHandler *handler.EventHandler + BroadcastHandler *handler.BroadcastHandler + RuntimeArchHandler *handler.RuntimeArchitectureHandler + SyncHandler *handler.SyncHandler + SyncClientSettingHandler *handler.SyncClientSettingHandler + MeteringHandler *handler.AccountingHandler + RecomHandler *handler.RecomHandler + TelemetryHandler *handler.TelemetryHandler + InternalHandler *handler.InternalHandler + DiscussionHandler *handler.DiscussionHandler + PromptHandler *handler.PromptHandler + DsViewerHandler *handler.InternalServiceProxyHandler + PaymentProxyHandler *handler.InternalServiceProxyHandler + MemoryStore *persist.MemoryStore + UserServiceClient rpc.UserSvcClient + Engine *gin.Engine +} + +func NewBaseServer( + config *config.Config, + middleware *middleware.Middleware, + gitHTTPHandler *handler.GitHTTPHandler, + userHandler *handler.UserHandler, + orgHandler *handler.OrganizationHandler, + repoCommonHandler *handler.RepoHandler, + modelHandler *handler.ModelHandler, + dsHandler *handler.DatasetHandler, + mirrorHandler *handler.MirrorHandler, + hfdsHandler *handler.HFDatasetHandler, + listHandler *handler.ListHandler, + evaluationHandler *handler.EvaluationHandler, + codeHandler *handler.CodeHandler, + spaceHandler *handler.SpaceHandler, + spaceResourceHandler *handler.SpaceResourceHandler, + spaceSdkHandler *handler.SpaceSdkHandler, + userProxyHandler *UserProxyHandler, + datasetViewerProxyHandler *DatasetViewerPeoxyHandler, + sshKeyHandler *handler.SSHKeyHandler, + tagCtrl *handler.TagsHandler, + callbackCtrl *callback.GitCallbackHandler, + sensitiveCtrl *handler.SensitiveHandler, + msHandler *handler.MirrorSourceHandler, + collectionHandler *handler.CollectionHandler, + clusterHandler *handler.ClusterHandler, + eventHandler *handler.EventHandler, + broadcastHandler *handler.BroadcastHandler, + runtimeArchHandler *handler.RuntimeArchitectureHandler, + syncHandler *handler.SyncHandler, + syncClientSettingHandler *handler.SyncClientSettingHandler, + meteringHandler *handler.AccountingHandler, + recomHandler *handler.RecomHandler, + telemetryHandler *handler.TelemetryHandler, + internalHandler *handler.InternalHandler, + discussionHandler *handler.DiscussionHandler, + promptHandler *handler.PromptHandler, + memoryStore *persist.MemoryStore, + userServiceClient rpc.UserSvcClient, + +) (*BaseServer, error) { + server := &BaseServer{ + Config: config, + Middleware: middleware, + GitHTTPHandler: gitHTTPHandler, + UserHandler: userHandler, + OrgHandler: orgHandler, + RepoCommonHandler: repoCommonHandler, + ModelHandler: modelHandler, + DsHandler: dsHandler, + MirrorHandler: mirrorHandler, + HfdsHandler: hfdsHandler, + ListHandler: listHandler, + EvaluationHandler: evaluationHandler, + CodeHandler: codeHandler, + SpaceHandler: spaceHandler, + SpaceResourceHandler: spaceResourceHandler, + SpaceSdkHandler: spaceSdkHandler, + UserProxyHandler: (*handler.InternalServiceProxyHandler)(userProxyHandler), + SshKeyHandler: sshKeyHandler, + TagCtrl: tagCtrl, + CallbackCtrl: callbackCtrl, + SensitiveCtrl: sensitiveCtrl, + MsHandler: msHandler, + CollectionHandler: collectionHandler, + ClusterHandler: clusterHandler, + EventHandler: eventHandler, + BroadcastHandler: broadcastHandler, + RuntimeArchHandler: runtimeArchHandler, + SyncHandler: syncHandler, + SyncClientSettingHandler: syncClientSettingHandler, + MeteringHandler: meteringHandler, + RecomHandler: recomHandler, + TelemetryHandler: telemetryHandler, + InternalHandler: internalHandler, + DiscussionHandler: discussionHandler, + PromptHandler: promptHandler, + MemoryStore: memoryStore, + UserServiceClient: userServiceClient, + DsViewerHandler: (*handler.InternalServiceProxyHandler)(datasetViewerProxyHandler), + } + return server, nil +} + func RunServer(config *config.Config, enableSwagger bool) { stopOtel, err := instrumentation.SetupOTelSDK(context.Background(), config, "csghub-api") if err != nil { panic(err) } - r, err := NewRouter(config, enableSwagger) + slog.Info("init gin http router") + srv, err := InitializeServer(config) + if err != nil { + panic(err) + } + err = srv.RegisterRoutes(enableSwagger) if err != nil { panic(err) } @@ -38,7 +191,7 @@ func RunServer(config *config.Config, enableSwagger bool) { httpbase.GraceServerOpt{ Port: config.APIServer.Port, }, - r, + srv.Engine, ) // Initialize mirror service mirrorService, err := mirror.NewMirrorPriorityQueue(config) @@ -53,12 +206,16 @@ func RunServer(config *config.Config, enableSwagger bool) { server.Run() _ = stopOtel(context.Background()) temporal.Stop() +} +func (s *BaseServer) GetEngine() *gin.Engine { + return s.Engine } -func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) { +func (s *BaseServer) RegisterRoutes(enableSwagger bool) error { r := gin.New() - if config.Instrumentation.OTLPEndpoint != "" { + s.Engine = r + if s.Config.Instrumentation.OTLPEndpoint != "" { r.Use(otelgin.Middleware("csghub-server")) } @@ -69,402 +226,253 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) { AllowAllOrigins: true, })) r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) + r.Use(s.Middleware.Log()) + + //add router for golang pprof + debugGroup := r.Group("/debug", s.Middleware.NeedAPIKey()) + pprof.RouteRegister(debugGroup, "pprof") - gitHTTPHandler, err := handler.NewGitHTTPHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating git http handler:%w", err) - } gitHTTP := r.Group("/:repo_type/:namespace/:name") - gitHTTP.Use(middleware.GitHTTPParamMiddleware()) - gitHTTP.Use(middleware.GetCurrentUserFromHeader()) + gitHTTP.Use(s.Middleware.GitHTTPParamMiddleware()) + gitHTTP.Use(s.Middleware.GetCurrentUserFromHeader()) { - gitHTTP.GET("/info/refs", gitHTTPHandler.InfoRefs) - gitHTTP.POST("/git-upload-pack", middleware.ContentEncoding(), gitHTTPHandler.GitUploadPack) - gitHTTP.POST("/git-receive-pack", middleware.ContentEncoding(), gitHTTPHandler.GitReceivePack) + gitHTTP.GET("/info/refs", s.GitHTTPHandler.InfoRefs) + gitHTTP.POST("/git-upload-pack", s.Middleware.ContentEncoding(), s.GitHTTPHandler.GitUploadPack) + gitHTTP.POST("/git-receive-pack", s.Middleware.ContentEncoding(), s.GitHTTPHandler.GitReceivePack) + lfsGroup := gitHTTP.Group("/info/lfs") { objectsGroup := lfsGroup.Group("/objects") { - objectsGroup.POST("/batch", gitHTTPHandler.LfsBatch) - objectsGroup.PUT("/:oid/:size", gitHTTPHandler.LfsUpload) - lfsGroup.GET("/:oid", gitHTTPHandler.LfsDownload) + objectsGroup.POST("/batch", s.GitHTTPHandler.LfsBatch) + objectsGroup.PUT("/:oid/:size", s.GitHTTPHandler.LfsUpload) + lfsGroup.GET("/:oid", s.GitHTTPHandler.LfsDownload) } - lfsGroup.POST("/verify", gitHTTPHandler.LfsVerify) + lfsGroup.POST("/verify", s.GitHTTPHandler.LfsVerify) locksGroup := lfsGroup.Group("/locks") { - locksGroup.GET("", gitHTTPHandler.ListLocks) - locksGroup.POST("", gitHTTPHandler.CreateLock) - locksGroup.POST("/verify", gitHTTPHandler.VerifyLock) - locksGroup.POST("/:lid/unlock", gitHTTPHandler.UnLock) + locksGroup.GET("", s.GitHTTPHandler.ListLocks) + locksGroup.POST("", s.GitHTTPHandler.CreateLock) + locksGroup.POST("/verify", s.GitHTTPHandler.VerifyLock) + locksGroup.POST("/:lid/unlock", s.GitHTTPHandler.UnLock) } } } - - r.Use(middleware.Authenticator(config)) - - authCollection := middleware.AuthenticatorCollection{} - authCollection.NeedAPIKey = middleware.OnlyAPIKeyAuthenticator(config) - authCollection.NeedAdmin = middleware.NeedAdmin(config) + r.Use(s.Middleware.Authenticator()) if enableSwagger { r.GET("/api/v1/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) } - // User routes - userHandler, err := handler.NewUserHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating user controller:%w", err) - } - orgHandler, err := handler.NewOrganizationHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating organization controller:%w", err) - } - - repoCommonHandler, err := handler.NewRepoHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating repo common handler: %w", err) - } - modelHandler, err := handler.NewModelHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating model controller:%w", err) - } - dsHandler, err := handler.NewDatasetHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating dataset handler:%w", err) - } - - // Mirror - mirrorHandler, err := handler.NewMirrorHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating mirror controller:%w", err) - } - - hfdsHandler, err := handler.NewHFDatasetHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating HF dataset handler: %w", err) - } //create routes for hf - createMappingRoutes(r, "/hf", hfdsHandler, repoCommonHandler, modelHandler, userHandler) + createMappingRoutes(r, "/hf", s.HfdsHandler, s.RepoCommonHandler, s.ModelHandler, s.UserHandler, s.Middleware) //create routes for ms - createMappingRoutes(r, "/ms", hfdsHandler, repoCommonHandler, modelHandler, userHandler) + createMappingRoutes(r, "/ms", s.HfdsHandler, s.RepoCommonHandler, s.ModelHandler, s.UserHandler, s.Middleware) //create routes for csg - createMappingRoutes(r, "/csg", hfdsHandler, repoCommonHandler, modelHandler, userHandler) + createMappingRoutes(r, "/csg", s.HfdsHandler, s.RepoCommonHandler, s.ModelHandler, s.UserHandler, s.Middleware) apiGroup := r.Group("/api/v1") - // TODO:use middleware to handle common response - // - memoryStore := persist.NewMemoryStore(1 * time.Minute) // List trending models and datasets routes - listHandler, err := handler.NewListHandler(config) - if err != nil { - return nil, fmt.Errorf("error creatring list handler: %v", err) - } { - apiGroup.POST("/list/models_by_path", cache.CacheByRequestURI(memoryStore, 1*time.Minute), listHandler.ListModelsByPath) - apiGroup.POST("/list/datasets_by_path", cache.CacheByRequestURI(memoryStore, 1*time.Minute), listHandler.ListDatasetsByPath) - apiGroup.POST("/list/spaces_by_path", cache.CacheByRequestURI(memoryStore, 1*time.Minute), listHandler.ListSpacesByPath) + apiGroup.POST("/list/models_by_path", cache.CacheByRequestURI(s.MemoryStore, 1*time.Minute), s.ListHandler.ListModelsByPath) + apiGroup.POST("/list/datasets_by_path", cache.CacheByRequestURI(s.MemoryStore, 1*time.Minute), s.ListHandler.ListDatasetsByPath) + apiGroup.POST("/list/spaces_by_path", cache.CacheByRequestURI(s.MemoryStore, 1*time.Minute), s.ListHandler.ListSpacesByPath) } //evaluation handler - evaluationHandler, err := handler.NewEvaluationHandler(config) - if err != nil { - return nil, fmt.Errorf("error creatring evaluation handler: %v", err) - } - - createEvaluationRoutes(apiGroup, evaluationHandler) + createEvaluationRoutes(apiGroup, s.EvaluationHandler) // Model routes - createModelRoutes(config, apiGroup, authCollection.NeedAPIKey, modelHandler, repoCommonHandler) + createModelRoutes(s.Config, apiGroup, s.Middleware, s.ModelHandler, s.RepoCommonHandler) // Dataset routes - createDatasetRoutes(config, apiGroup, dsHandler, repoCommonHandler) + createDatasetRoutes(s.Config, apiGroup, s.DsHandler, s.RepoCommonHandler, s.Middleware) - codeHandler, err := handler.NewCodeHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating code handler:%w", err) - } // Code routes - createCodeRoutes(config, apiGroup, codeHandler, repoCommonHandler) + createCodeRoutes(s.Config, apiGroup, s.CodeHandler, s.RepoCommonHandler, s.Middleware) - spaceHandler, err := handler.NewSpaceHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating space handler:%w", err) - } // space routers - createSpaceRoutes(config, apiGroup, spaceHandler, repoCommonHandler) - - spaceResourceHandler, err := handler.NewSpaceResourceHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating space resource handler:%w", err) - } + createSpaceRoutes(s.Config, apiGroup, s.SpaceHandler, s.RepoCommonHandler, s.Middleware) spaceResource := apiGroup.Group("space_resources") { - spaceResource.GET("", spaceResourceHandler.Index) - spaceResource.POST("", authCollection.NeedAdmin, spaceResourceHandler.Create) - spaceResource.PUT("/:id", authCollection.NeedAdmin, spaceResourceHandler.Update) - spaceResource.DELETE("/:id", authCollection.NeedAdmin, spaceResourceHandler.Delete) - } - - spaceSdkHandler, err := handler.NewSpaceSdkHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating space sdk handler:%w", err) + spaceResource.GET("", s.SpaceResourceHandler.Index) + spaceResource.POST("", s.Middleware.NeedAdmin(), s.SpaceResourceHandler.Create) + spaceResource.PUT("/:id", s.Middleware.NeedAdmin(), s.SpaceResourceHandler.Update) + spaceResource.DELETE("/:id", s.Middleware.NeedAdmin(), s.SpaceResourceHandler.Delete) } spaceSdk := apiGroup.Group("space_sdks") { - spaceSdk.GET("", spaceSdkHandler.Index) - spaceSdk.POST("", authCollection.NeedAPIKey, spaceSdkHandler.Create) - spaceSdk.PUT("/:id", authCollection.NeedAPIKey, spaceSdkHandler.Update) - spaceSdk.DELETE("/:id", authCollection.NeedAPIKey, spaceSdkHandler.Delete) + spaceSdk.GET("", s.SpaceSdkHandler.Index) + spaceSdk.POST("", s.Middleware.NeedAPIKey(), s.SpaceSdkHandler.Create) + spaceSdk.PUT("/:id", s.Middleware.NeedAPIKey(), s.SpaceSdkHandler.Update) + spaceSdk.DELETE("/:id", s.Middleware.NeedAPIKey(), s.SpaceSdkHandler.Delete) } - userProxyHandler, err := handler.NewInternalServiceProxyHandler(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port)) - if err != nil { - return nil, fmt.Errorf("error creating user proxy handler:%w", err) - } - - createUserRoutes(apiGroup, authCollection.NeedAPIKey, userProxyHandler, userHandler) - + createUserRoutes(apiGroup, s.Middleware, s.UserProxyHandler, s.UserHandler) tokenGroup := apiGroup.Group("token") { - tokenGroup.POST("/:app/:token_name", userProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) - tokenGroup.PUT("/:app/:token_name", userProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) - tokenGroup.DELETE("/:app/:token_name", userProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) + tokenGroup.POST("/:app/:token_name", s.UserProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) + tokenGroup.PUT("/:app/:token_name", s.UserProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) + tokenGroup.DELETE("/:app/:token_name", s.UserProxyHandler.ProxyToApi("/api/v1/token/%s/%s", "app", "token_name")) // check token info - tokenGroup.GET("/:token_value", authCollection.NeedAPIKey, userProxyHandler.ProxyToApi("/api/v1/token/%s", "token_value")) + tokenGroup.GET("/:token_value", s.Middleware.NeedAPIKey(), s.UserProxyHandler.ProxyToApi("/api/v1/token/%s", "token_value")) } - sshKeyHandler, err := handler.NewSSHKeyHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating user controller:%w", err) - } { - apiGroup.GET("/user/:username/ssh_keys", sshKeyHandler.Index) - apiGroup.POST("/user/:username/ssh_keys", sshKeyHandler.Create) - apiGroup.DELETE("/user/:username/ssh_key/:name", sshKeyHandler.Delete) + apiGroup.GET("/user/:username/ssh_keys", s.SshKeyHandler.Index) + apiGroup.POST("/user/:username/ssh_keys", s.SshKeyHandler.Create) + apiGroup.DELETE("/user/:username/ssh_key/:name", s.SshKeyHandler.Delete) } { - apiGroup.GET("/organizations", userProxyHandler.Proxy) - apiGroup.POST("/organizations", userProxyHandler.Proxy) - apiGroup.GET("/organization/:namespace", userProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) - apiGroup.PUT("/organization/:namespace", userProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) - apiGroup.DELETE("/organization/:namespace", userProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) + apiGroup.GET("/organizations", s.UserProxyHandler.Proxy) + apiGroup.POST("/organizations", s.UserProxyHandler.Proxy) + apiGroup.GET("/organization/:namespace", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) + apiGroup.PUT("/organization/:namespace", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) + apiGroup.DELETE("/organization/:namespace", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s", "namespace")) // Organization assets - apiGroup.GET("/organization/:namespace/models", orgHandler.Models) - apiGroup.GET("/organization/:namespace/datasets", orgHandler.Datasets) - apiGroup.GET("/organization/:namespace/codes", orgHandler.Codes) - apiGroup.GET("/organization/:namespace/spaces", orgHandler.Spaces) - apiGroup.GET("/organization/:namespace/collections", orgHandler.Collections) - apiGroup.GET("/organization/:namespace/prompts", orgHandler.Prompts) + apiGroup.GET("/organization/:namespace/models", s.OrgHandler.Models) + apiGroup.GET("/organization/:namespace/datasets", s.OrgHandler.Datasets) + apiGroup.GET("/organization/:namespace/codes", s.OrgHandler.Codes) + apiGroup.GET("/organization/:namespace/spaces", s.OrgHandler.Spaces) + apiGroup.GET("/organization/:namespace/collections", s.OrgHandler.Collections) + apiGroup.GET("/organization/:namespace/prompts", s.OrgHandler.Prompts) } { - apiGroup.GET("/organization/:namespace/members", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members", "namespace")) - apiGroup.POST("/organization/:namespace/members", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members", "namespace")) - apiGroup.GET("/organization/:namespace/members/:username", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) - apiGroup.PUT("/organization/:namespace/members/:username", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) - apiGroup.DELETE("/organization/:namespace/members/:username", userProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) + apiGroup.GET("/organization/:namespace/members", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s/members", "namespace")) + apiGroup.POST("/organization/:namespace/members", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s/members", "namespace")) + apiGroup.GET("/organization/:namespace/members/:username", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) + apiGroup.PUT("/organization/:namespace/members/:username", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) + apiGroup.DELETE("/organization/:namespace/members/:username", s.UserProxyHandler.ProxyToApi("/api/v1/organization/%s/members/%s", "namespace", "username")) } // Tag - tagCtrl, err := handler.NewTagHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating tag controller:%w", err) - } - createTagsRoutes(apiGroup, tagCtrl) + createTagsRoutes(apiGroup, s.TagCtrl) // JWT token - apiGroup.POST("/jwt/token", authCollection.NeedAPIKey, userProxyHandler.Proxy) - apiGroup.GET("/jwt/:token", authCollection.NeedAPIKey, userProxyHandler.ProxyToApi("/api/v1/jwt/%s", "token")) - apiGroup.GET("/users", userProxyHandler.Proxy) + apiGroup.POST("/jwt/token", s.Middleware.NeedAPIKey(), s.UserProxyHandler.Proxy) + apiGroup.GET("/jwt/:token", s.Middleware.NeedAPIKey(), s.UserProxyHandler.ProxyToApi("/api/v1/jwt/%s", "token")) + apiGroup.GET("/users", s.UserProxyHandler.Proxy) // callback - callbackCtrl, err := callback.NewGitCallbackHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating callback controller:%w", err) - } - apiGroup.POST("/callback/git", callbackCtrl.Handle) - apiGroup.GET("/callback/casdoor", userProxyHandler.Proxy) + apiGroup.POST("/callback/git", s.CallbackCtrl.Handle) + apiGroup.GET("/callback/casdoor", s.UserProxyHandler.Proxy) // Sensive check - if config.SensitiveCheck.Enable { - sensitiveCtrl, err := handler.NewSensitiveHandler(config) + if s.Config.SensitiveCheck.Enable { + sensitiveCtrl, err := handler.NewSensitiveHandler(s.Config) if err != nil { - return nil, fmt.Errorf("error creating sensitive handler:%w", err) + return fmt.Errorf("error creating sensitive handler:%w", err) } apiGroup.POST("/sensitive/text", sensitiveCtrl.Text) apiGroup.POST("/sensitive/image", sensitiveCtrl.Image) } // MirrorSource - msHandler, err := handler.NewMirrorSourceHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating mirror source controller:%w", err) - } + apiGroup.GET("/mirrors", s.MirrorHandler.Index) - apiGroup.GET("/mirrors", mirrorHandler.Index) mirror := apiGroup.Group("/mirror") { - mirror.GET("/sources", msHandler.Index) - mirror.POST("/sources", msHandler.Create) - mirror.PUT("/sources/:id", msHandler.Update) - mirror.DELETE("/sources/:id", msHandler.Delete) - mirror.GET("/sources/:id", msHandler.Get) - mirror.POST("/repo", mirrorHandler.CreateMirrorRepo) - mirror.GET("/repos", mirrorHandler.Repos) + mirror.GET("/sources", s.MsHandler.Index) + mirror.POST("/sources", s.MsHandler.Create) + mirror.PUT("/sources/:id", s.MsHandler.Update) + mirror.DELETE("/sources/:id", s.MsHandler.Delete) + mirror.GET("/sources/:id", s.MsHandler.Get) + mirror.POST("/repo", s.MirrorHandler.CreateMirrorRepo) + mirror.GET("/repos", s.MirrorHandler.Repos) } - collectionHandler, err := handler.NewCollectionHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating collection handler:%w", err) - } collections := apiGroup.Group("/collections") { // list all collection - collections.GET("", collectionHandler.Index) - collections.POST("", collectionHandler.Create) - collections.GET("/:id", collectionHandler.GetCollection) - collections.PUT("/:id", collectionHandler.UpdateCollection) - collections.DELETE("/:id", collectionHandler.DeleteCollection) - collections.POST("/:id/repos", collectionHandler.AddRepoToCollection) - collections.DELETE("/:id/repos", collectionHandler.RemoveRepoFromCollection) + collections.GET("", s.CollectionHandler.Index) + collections.POST("", s.CollectionHandler.Create) + collections.GET("/:id", s.CollectionHandler.GetCollection) + collections.PUT("/:id", s.CollectionHandler.UpdateCollection) + collections.DELETE("/:id", s.CollectionHandler.DeleteCollection) + collections.POST("/:id/repos", s.CollectionHandler.AddRepoToCollection) + collections.DELETE("/:id/repos", s.CollectionHandler.RemoveRepoFromCollection) } // cluster infos - clusterHandler, err := handler.NewClusterHandler(config) - if err != nil { - return nil, fmt.Errorf("fail to creating cluster handler: %w", err) - } cluster := apiGroup.Group("/cluster") { - cluster.GET("", clusterHandler.Index) - cluster.GET("/:id", clusterHandler.GetClusterById) - cluster.PUT("/:id", authCollection.NeedAPIKey, clusterHandler.Update) + cluster.GET("", s.ClusterHandler.Index) + cluster.GET("/:id", s.ClusterHandler.GetClusterById) + cluster.PUT("/:id", s.Middleware.NeedAPIKey(), s.ClusterHandler.Update) } - eventHandler, err := handler.NewEventHandler() - if err != nil { - return nil, fmt.Errorf("error creating event handler:%w", err) - } event := apiGroup.Group("/events") - event.POST("", eventHandler.Create) + event.POST("", s.EventHandler.Create) // routes for broadcast - broadcastHandler, err := handler.NewBroadcastHandler() - if err != nil { - return nil, fmt.Errorf("error creating broadcast handler:%w", err) - } broadcast := apiGroup.Group("/broadcasts") adminBroadcast := apiGroup.Group("/admin/broadcasts") - adminBroadcast.Use(authCollection.NeedAdmin) - - adminBroadcast.POST("", broadcastHandler.Create) - adminBroadcast.PUT("/:id", broadcastHandler.Update) - adminBroadcast.GET("", broadcastHandler.Index) - adminBroadcast.GET("/:id", broadcastHandler.Show) - broadcast.GET("/:id", broadcastHandler.Show) - broadcast.GET("/active", broadcastHandler.Active) + adminBroadcast.Use(s.Middleware.NeedAdmin()) + + broadcast.GET("/active", s.BroadcastHandler.Active) + adminBroadcast.POST("", s.BroadcastHandler.Create) + adminBroadcast.PUT("/:id", s.BroadcastHandler.Update) + adminBroadcast.GET("", s.BroadcastHandler.Index) + adminBroadcast.GET("/:id", s.BroadcastHandler.Show) + broadcast.GET("/:id", s.BroadcastHandler.Show) // end routes for broadcast + createRuntimeFrameworkRoutes( + apiGroup, s.Middleware, s.ModelHandler, s.RuntimeArchHandler, s.RepoCommonHandler, + ) - runtimeArchHandler, err := handler.NewRuntimeArchitectureHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating runtime framework architecture handler:%w", err) - } - - createRuntimeFrameworkRoutes(apiGroup, authCollection.NeedAPIKey, modelHandler, runtimeArchHandler, repoCommonHandler) - - syncHandler, err := handler.NewSyncHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating sync handler:%w", err) - } - syncClientSettingHandler, err := handler.NewSyncClientSettingHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating sync client setting handler:%w", err) - } syncGroup := apiGroup.Group("sync") { - syncGroup.GET("/version/latest", syncHandler.Latest) + syncGroup.GET("/version/latest", s.SyncHandler.Latest) // syncGroup.GET("/version/oldest", syncHandler.Oldest) - syncGroup.GET("/client_setting", syncClientSettingHandler.Show) - syncGroup.POST("/client_setting", syncClientSettingHandler.Create) + syncGroup.GET("/client_setting", s.SyncClientSettingHandler.Show) + syncGroup.POST("/client_setting", s.SyncClientSettingHandler.Create) } - accountingHandler, err := handler.NewAccountingHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating accounting handler setting handler:%w", err) - } - - createAccountRoutes(apiGroup, authCollection.NeedAPIKey, accountingHandler) - - recomHandler, err := handler.NewRecomHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating recomHandler,%w", err) - } recomGroup := apiGroup.Group("/recom") { - recomGroup.POST("opweight", authCollection.NeedAPIKey, recomHandler.SetOpWeight) + recomGroup.POST("opweight", s.Middleware.NeedAdmin(), s.RecomHandler.SetOpWeight) } // telemetry - telemetryHandler, err := handler.NewTelemetryHandler() - if err != nil { - return nil, fmt.Errorf("error creating telemetry handler:%w", err) - } teleGroup := apiGroup.Group("/telemetry") - teleGroup.POST("/usage", telemetryHandler.Usage) + teleGroup.POST("/usage", s.TelemetryHandler.Usage) // internal API for gitaly to check request permissions - internalHandler, err := handler.NewInternalHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating internalHandler,%w", err) - } - needGitlabShellJWTToken := middleware.CheckGitlabShellJWTToken(config) - r.GET("/api/v4/internal/authorized_keys", needGitlabShellJWTToken, internalHandler.GetAuthorizedKeys) - r.POST("/api/v4/internal/allowed", needGitlabShellJWTToken, internalHandler.SSHAllowed) - r.POST("/api/v4/internal/pre_receive", needGitlabShellJWTToken, internalHandler.PreReceive) - r.POST("api/v4/internal/lfs_authenticate", needGitlabShellJWTToken, internalHandler.LfsAuthenticate) - r.POST("/api/v4/internal/post_receive", needGitlabShellJWTToken, internalHandler.PostReceive) + needGitlabShellJWTToken := s.Middleware.CheckGitlabShellJWTToken() + r.GET("/api/v4/internal/authorized_keys", needGitlabShellJWTToken, s.InternalHandler.GetAuthorizedKeys) + r.POST("/api/v4/internal/allowed", needGitlabShellJWTToken, s.InternalHandler.SSHAllowed) + r.POST("/api/v4/internal/pre_receive", needGitlabShellJWTToken, s.InternalHandler.PreReceive) + r.POST("api/v4/internal/lfs_authenticate", needGitlabShellJWTToken, s.InternalHandler.LfsAuthenticate) + r.POST("/api/v4/internal/post_receive", needGitlabShellJWTToken, s.InternalHandler.PostReceive) internalGroup := apiGroup.Group("/internal") { - internalGroup.POST("/allowed", needGitlabShellJWTToken, internalHandler.Allowed) - internalGroup.POST("/pre_receive", needGitlabShellJWTToken, internalHandler.PreReceive) - internalGroup.POST("/post_receive", needGitlabShellJWTToken, internalHandler.PostReceive) - } - - discussionHandler, err := handler.NewDiscussionHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating discussion handler:%w", err) + internalGroup.POST("/allowed", needGitlabShellJWTToken, s.InternalHandler.Allowed) + internalGroup.POST("/pre_receive", needGitlabShellJWTToken, s.InternalHandler.PreReceive) + internalGroup.POST("/post_receive", needGitlabShellJWTToken, s.InternalHandler.PostReceive) } - createDiscussionRoutes(apiGroup, authCollection.NeedAPIKey, discussionHandler) + createDiscussionRoutes(apiGroup, s.Middleware.NeedAPIKey(), s.DiscussionHandler) // prompt - promptHandler, err := handler.NewPromptHandler(config) - if err != nil { - return nil, fmt.Errorf("error creating prompt handler,%w", err) - } - createPromptRoutes(apiGroup, promptHandler) - - dataViewerAddr := fmt.Sprintf("%s:%d", config.DataViewer.Host, config.DataViewer.Port) - dsViewerHandler, err := handler.NewInternalServiceProxyHandler(dataViewerAddr) - if err != nil { - return nil, fmt.Errorf("error creating dataset viewer proxy:%w", err) - } - createDataViewerRoutes(apiGroup, dsViewerHandler) + createPromptRoutes(apiGroup, s.PromptHandler) + // Dataset viewer proxy + createDataViewerRoutes(apiGroup, s.DsViewerHandler) // space template - templateHandler, err := handler.NewSpaceTemplateHandler(config) + templateHandler, err := handler.NewSpaceTemplateHandler(s.Config) if err != nil { - return nil, fmt.Errorf("error creating space template proxy:%w", err) + return fmt.Errorf("error creating space template proxy:%w", err) } - createSpaceTemplateRoutes(apiGroup, authCollection, templateHandler) - return r, nil + createSpaceTemplateRoutes(apiGroup, s.Middleware, templateHandler) + + return nil } func createEvaluationRoutes(apiGroup *gin.RouterGroup, evaluationHandler *handler.EvaluationHandler) { @@ -477,7 +485,7 @@ func createEvaluationRoutes(apiGroup *gin.RouterGroup, evaluationHandler *handle } } -func createModelRoutes(config *config.Config, apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc, modelHandler *handler.ModelHandler, repoCommonHandler *handler.RepoHandler) { +func createModelRoutes(config *config.Config, apiGroup *gin.RouterGroup, middleware *middleware.Middleware, modelHandler *handler.ModelHandler, repoCommonHandler *handler.RepoHandler) { // Models routes modelsGroup := apiGroup.Group("/models") { @@ -568,7 +576,7 @@ func createModelRoutes(config *config.Config, apiGroup *gin.RouterGroup, needAPI } } -func createDatasetRoutes(config *config.Config, apiGroup *gin.RouterGroup, dsHandler *handler.DatasetHandler, repoCommonHandler *handler.RepoHandler) { +func createDatasetRoutes(config *config.Config, apiGroup *gin.RouterGroup, dsHandler *handler.DatasetHandler, repoCommonHandler *handler.RepoHandler, middleware *middleware.Middleware) { datasetsGroup := apiGroup.Group("/datasets") { datasetsGroup.POST("", dsHandler.Create) @@ -610,7 +618,7 @@ func createDatasetRoutes(config *config.Config, apiGroup *gin.RouterGroup, dsHan } } -func createCodeRoutes(config *config.Config, apiGroup *gin.RouterGroup, codeHandler *handler.CodeHandler, repoCommonHandler *handler.RepoHandler) { +func createCodeRoutes(config *config.Config, apiGroup *gin.RouterGroup, codeHandler *handler.CodeHandler, repoCommonHandler *handler.RepoHandler, middleware *middleware.Middleware) { codesGroup := apiGroup.Group("/codes") { codesGroup.POST("", codeHandler.Create) @@ -651,7 +659,7 @@ func createCodeRoutes(config *config.Config, apiGroup *gin.RouterGroup, codeHand } } -func createSpaceRoutes(config *config.Config, apiGroup *gin.RouterGroup, spaceHandler *handler.SpaceHandler, repoCommonHandler *handler.RepoHandler) { +func createSpaceRoutes(config *config.Config, apiGroup *gin.RouterGroup, spaceHandler *handler.SpaceHandler, repoCommonHandler *handler.RepoHandler, middleware *middleware.Middleware) { spaces := apiGroup.Group("/spaces") { // list all spaces @@ -710,7 +718,7 @@ func createSpaceRoutes(config *config.Config, apiGroup *gin.RouterGroup, spaceHa } } -func createUserRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc, userProxyHandler *handler.InternalServiceProxyHandler, userHandler *handler.UserHandler) { +func createUserRoutes(apiGroup *gin.RouterGroup, middleware *middleware.Middleware, userProxyHandler *handler.InternalServiceProxyHandler, userHandler *handler.UserHandler) { // depricated { apiGroup.POST("/users", userProxyHandler.ProxyToApi("/api/v1/user")) @@ -754,10 +762,11 @@ func createUserRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc, use apiGroup.GET("/user/:username/tokens", userProxyHandler.ProxyToApi("/api/v1/user/%s/tokens", "username")) // serverless list - apiGroup.GET("/user/:username/run/serverless", needAPIKey, userHandler.GetRunServerless) + apiGroup.GET("/user/:username/run/serverless", middleware.NeedAPIKey(), userHandler.GetRunServerless) } -func createRuntimeFrameworkRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc, modelHandler *handler.ModelHandler, runtimeArchHandler *handler.RuntimeArchitectureHandler, repoCommonHandler *handler.RepoHandler) { +func createRuntimeFrameworkRoutes(apiGroup *gin.RouterGroup, middleware *middleware.Middleware, modelHandler *handler.ModelHandler, runtimeArchHandler *handler.RuntimeArchitectureHandler, repoCommonHandler *handler.RepoHandler) { + needAPIKey := middleware.NeedAPIKey() runtimeFramework := apiGroup.Group("/runtime_framework") { runtimeFramework.GET("/:id/models", modelHandler.ListByRuntimeFrameworkID) @@ -776,17 +785,7 @@ func createRuntimeFrameworkRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.Hand } } -func createAccountRoutes(apiGroup *gin.RouterGroup, needAPIKey gin.HandlerFunc, accountingHandler *handler.AccountingHandler) { - accountingGroup := apiGroup.Group("/accounting") - { - meterGroup := accountingGroup.Group("/metering") - { - meterGroup.GET("/:id/statements", accountingHandler.QueryMeteringStatementByUserID) - } - } -} - -func createMappingRoutes(r *gin.Engine, group string, hfdsHandler *handler.HFDatasetHandler, repoCommonHandler *handler.RepoHandler, modelHandler *handler.ModelHandler, userHandler *handler.UserHandler) { +func createMappingRoutes(r *gin.Engine, group string, hfdsHandler *handler.HFDatasetHandler, repoCommonHandler *handler.RepoHandler, modelHandler *handler.ModelHandler, userHandler *handler.UserHandler, middleware *middleware.Middleware) { // Huggingface SDK routes hfGroup := r.Group(group) { @@ -887,13 +886,13 @@ func createDataViewerRoutes(apiGroup *gin.RouterGroup, dsViewerHandler *handler. } } -func createSpaceTemplateRoutes(apiGroup *gin.RouterGroup, authCollection middleware.AuthenticatorCollection, templateHandler *handler.SpaceTemplateHandler) { +func createSpaceTemplateRoutes(apiGroup *gin.RouterGroup, middleware *middleware.Middleware, templateHandler *handler.SpaceTemplateHandler) { spaceTemplateGrp := apiGroup.Group("/space_templates") { - spaceTemplateGrp.GET("", authCollection.NeedAdmin, templateHandler.Index) - spaceTemplateGrp.POST("", authCollection.NeedAdmin, templateHandler.Create) - spaceTemplateGrp.PUT("/:id", authCollection.NeedAdmin, templateHandler.Update) - spaceTemplateGrp.DELETE("/:id", authCollection.NeedAdmin, templateHandler.Delete) + spaceTemplateGrp.GET("", middleware.NeedAdmin(), templateHandler.Index) + spaceTemplateGrp.POST("", middleware.NeedAdmin(), templateHandler.Create) + spaceTemplateGrp.PUT("/:id", middleware.NeedAdmin(), templateHandler.Update) + spaceTemplateGrp.DELETE("/:id", middleware.NeedAdmin(), templateHandler.Delete) spaceTemplateGrp.GET("/:type", templateHandler.List) } } diff --git a/api/router/api_ce.go b/api/router/api_ce.go new file mode 100644 index 00000000..7e98d313 --- /dev/null +++ b/api/router/api_ce.go @@ -0,0 +1,17 @@ +//go:build !ee && !saas + +package router + +type ServerImpl struct { + *BaseServer +} + +func NewServer(base *BaseServer) *ServerImpl { + return &ServerImpl{ + BaseServer: base, + } +} + +func (s *ServerImpl) RegisterRoutes(enableSwagger bool) error { + return s.BaseServer.RegisterRoutes(enableSwagger) +} diff --git a/api/router/rproxy.go b/api/router/rproxy.go index 733702e9..7893c957 100644 --- a/api/router/rproxy.go +++ b/api/router/rproxy.go @@ -22,7 +22,8 @@ func NewRProxyRouter(config *config.Config) (*gin.Engine, error) { AllowAllOrigins: true, })) r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) + middleware := middleware.NewMiddleware(config) + r.Use(middleware.Log()) store := cookie.NewStore([]byte(config.Space.SessionSecretKey)) store.Options(sessions.Options{ // SameSite: http.SameSiteNoneMode, // support 3rd part @@ -34,7 +35,7 @@ func NewRProxyRouter(config *config.Config) (*gin.Engine, error) { //to access space with jwt token in query string r.Use(middleware.BuildJwtSession(config.JWT.SigningKey)) //to access model,fintune with any kind of tokens in auth header - r.Use(middleware.Authenticator(config)) + r.Use(middleware.Authenticator()) handler, err := handler.NewRProxyHandler(config) if err != nil { diff --git a/api/router/wire.go b/api/router/wire.go new file mode 100644 index 00000000..4cf28b16 --- /dev/null +++ b/api/router/wire.go @@ -0,0 +1,53 @@ +//go:build wireinject +// +build wireinject + +package router + +import ( + "github.com/google/wire" + "opencsg.com/csghub-server/api/handler" + "opencsg.com/csghub-server/api/handler/callback" + "opencsg.com/csghub-server/api/middleware" + "opencsg.com/csghub-server/builder/rpc" +) + +var BaseServerSet = wire.NewSet( + handler.NewGitHTTPHandler, + handler.NewUserHandler, + handler.NewOrganizationHandler, + handler.NewRepoHandler, + handler.NewModelHandler, + handler.NewDatasetHandler, + handler.NewMirrorHandler, + handler.NewHFDatasetHandler, + handler.NewListHandler, + handler.NewEvaluationHandler, + handler.NewCodeHandler, + handler.NewSpaceHandler, + handler.NewSpaceResourceHandler, + handler.NewSpaceSdkHandler, + handler.NewSSHKeyHandler, + handler.NewTagHandler, + handler.NewSensitiveHandler, + handler.NewMirrorSourceHandler, + handler.NewCollectionHandler, + handler.NewClusterHandler, + handler.NewEventHandler, + handler.NewBroadcastHandler, + handler.NewRuntimeArchitectureHandler, + handler.NewSyncHandler, + handler.NewSyncClientSettingHandler, + handler.NewAccountingHandler, + handler.NewRecomHandler, + handler.NewTelemetryHandler, + handler.NewPromptHandler, + handler.NewInternalHandler, + handler.NewDiscussionHandler, + newMemoryStore, + newUserProxyHandler, + newDatasetViewerProxyHandler, + rpc.NewUserSvcHttpClient, + callback.NewGitCallbackHandler, + NewBaseServer, + middleware.NewMiddleware, +) diff --git a/api/router/wire_ce.go b/api/router/wire_ce.go new file mode 100644 index 00000000..1f64c3f9 --- /dev/null +++ b/api/router/wire_ce.go @@ -0,0 +1,17 @@ +//go:build wireinject && !ee && !saas +// +build wireinject,!ee,!saas + +package router + +import ( + "github.com/google/wire" + "opencsg.com/csghub-server/common/config" +) + +func InitializeServer(config *config.Config) (*ServerImpl, error) { + wire.Build( + BaseServerSet, + NewServer, + ) + return &ServerImpl{}, nil +} diff --git a/api/router/wire_gen_ce.go b/api/router/wire_gen_ce.go new file mode 100644 index 00000000..04bfe486 --- /dev/null +++ b/api/router/wire_gen_ce.go @@ -0,0 +1,166 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:build !wireinject && !ee && !saas +// +build !wireinject,!ee,!saas + +//go:generate go run -mod=mod github.com/google/wire/cmd/wire + +package router + +import ( + "opencsg.com/csghub-server/api/handler" + "opencsg.com/csghub-server/api/handler/callback" + "opencsg.com/csghub-server/api/middleware" + "opencsg.com/csghub-server/builder/rpc" + "opencsg.com/csghub-server/common/config" +) + +// Injectors from wire_ce.go: + +func InitializeServer(config2 *config.Config) (*ServerImpl, error) { + middlewareMiddleware := middleware.NewMiddleware(config2) + gitHTTPHandler, err := handler.NewGitHTTPHandler(config2) + if err != nil { + return nil, err + } + userHandler, err := handler.NewUserHandler(config2) + if err != nil { + return nil, err + } + organizationHandler, err := handler.NewOrganizationHandler(config2) + if err != nil { + return nil, err + } + repoHandler, err := handler.NewRepoHandler(config2) + if err != nil { + return nil, err + } + modelHandler, err := handler.NewModelHandler(config2) + if err != nil { + return nil, err + } + datasetHandler, err := handler.NewDatasetHandler(config2) + if err != nil { + return nil, err + } + mirrorHandler, err := handler.NewMirrorHandler(config2) + if err != nil { + return nil, err + } + hfDatasetHandler, err := handler.NewHFDatasetHandler(config2) + if err != nil { + return nil, err + } + listHandler, err := handler.NewListHandler(config2) + if err != nil { + return nil, err + } + evaluationHandler, err := handler.NewEvaluationHandler(config2) + if err != nil { + return nil, err + } + codeHandler, err := handler.NewCodeHandler(config2) + if err != nil { + return nil, err + } + spaceHandler, err := handler.NewSpaceHandler(config2) + if err != nil { + return nil, err + } + spaceResourceHandler, err := handler.NewSpaceResourceHandler(config2) + if err != nil { + return nil, err + } + spaceSdkHandler, err := handler.NewSpaceSdkHandler(config2) + if err != nil { + return nil, err + } + userProxyHandler, err := newUserProxyHandler(config2) + if err != nil { + return nil, err + } + datasetViewerPeoxyHandler, err := newDatasetViewerProxyHandler(config2) + if err != nil { + return nil, err + } + sshKeyHandler, err := handler.NewSSHKeyHandler(config2) + if err != nil { + return nil, err + } + tagsHandler, err := handler.NewTagHandler(config2) + if err != nil { + return nil, err + } + gitCallbackHandler, err := callback.NewGitCallbackHandler(config2) + if err != nil { + return nil, err + } + sensitiveHandler, err := handler.NewSensitiveHandler(config2) + if err != nil { + return nil, err + } + mirrorSourceHandler, err := handler.NewMirrorSourceHandler(config2) + if err != nil { + return nil, err + } + collectionHandler, err := handler.NewCollectionHandler(config2) + if err != nil { + return nil, err + } + clusterHandler, err := handler.NewClusterHandler(config2) + if err != nil { + return nil, err + } + eventHandler, err := handler.NewEventHandler() + if err != nil { + return nil, err + } + broadcastHandler, err := handler.NewBroadcastHandler() + if err != nil { + return nil, err + } + runtimeArchitectureHandler, err := handler.NewRuntimeArchitectureHandler(config2) + if err != nil { + return nil, err + } + syncHandler, err := handler.NewSyncHandler(config2) + if err != nil { + return nil, err + } + syncClientSettingHandler, err := handler.NewSyncClientSettingHandler(config2) + if err != nil { + return nil, err + } + accountingHandler, err := handler.NewAccountingHandler(config2) + if err != nil { + return nil, err + } + recomHandler, err := handler.NewRecomHandler(config2) + if err != nil { + return nil, err + } + telemetryHandler, err := handler.NewTelemetryHandler() + if err != nil { + return nil, err + } + internalHandler, err := handler.NewInternalHandler(config2) + if err != nil { + return nil, err + } + discussionHandler, err := handler.NewDiscussionHandler(config2) + if err != nil { + return nil, err + } + promptHandler, err := handler.NewPromptHandler(config2) + if err != nil { + return nil, err + } + memoryStore := newMemoryStore() + userSvcClient := rpc.NewUserSvcHttpClient(config2) + baseServer, err := NewBaseServer(config2, middlewareMiddleware, gitHTTPHandler, userHandler, organizationHandler, repoHandler, modelHandler, datasetHandler, mirrorHandler, hfDatasetHandler, listHandler, evaluationHandler, codeHandler, spaceHandler, spaceResourceHandler, spaceSdkHandler, userProxyHandler, datasetViewerPeoxyHandler, sshKeyHandler, tagsHandler, gitCallbackHandler, sensitiveHandler, mirrorSourceHandler, collectionHandler, clusterHandler, eventHandler, broadcastHandler, runtimeArchitectureHandler, syncHandler, syncClientSettingHandler, accountingHandler, recomHandler, telemetryHandler, internalHandler, discussionHandler, promptHandler, memoryStore, userSvcClient) + if err != nil { + return nil, err + } + serverImpl := NewServer(baseServer) + return serverImpl, nil +} diff --git a/builder/rpc/user_svc_client.go b/builder/rpc/user_svc_client.go index 33545996..b33f45b1 100644 --- a/builder/rpc/user_svc_client.go +++ b/builder/rpc/user_svc_client.go @@ -6,6 +6,7 @@ import ( "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/builder/git/membership" + "opencsg.com/csghub-server/common/config" ) type UserSvcClient interface { @@ -14,15 +15,14 @@ type UserSvcClient interface { GetUserInfo(ctx context.Context, userName, visitorName string) (*User, error) } -//go:generate mockgen -destination=mocks/client.go -package=mocks . Client - type UserSvcHttpClient struct { hc *HttpClient } -func NewUserSvcHttpClient(endpoint string, opts ...RequestOption) UserSvcClient { +func NewUserSvcHttpClient(config *config.Config) UserSvcClient { + endpoint := fmt.Sprintf("%s:%d", config.User.Host, config.User.Port) return &UserSvcHttpClient{ - hc: NewHttpClient(endpoint, opts...), + hc: NewHttpClient(endpoint, AuthWithApiKey(config.APIToken)), } } diff --git a/component/broadcast_test.go b/component/broadcast_test.go index 79c7754e..e8bce249 100644 --- a/component/broadcast_test.go +++ b/component/broadcast_test.go @@ -15,7 +15,7 @@ func TestBroadcastComponent_GetBroadcast(t *testing.T) { broadcast := database.Broadcast{ID: 1, Content: "test", BcType: "banner", Theme: "light", Status: "active"} - cc.mocks.stores.BroadcastMock().EXPECT().Get(ctx, 1).Return( + cc.mocks.stores.BroadcastMock().EXPECT().Get(ctx, int64(1)).Return( &broadcast, nil, ) data, err := cc.GetBroadcast(ctx, 1) @@ -66,7 +66,7 @@ func TestBroadcastComponent_UpdateBroadcast(t *testing.T) { dbBroadcast := database.Broadcast{ID: 1, Content: "test", BcType: "banner", Theme: "light", Status: "active"} broadcast := types.Broadcast{ID: 1, Content: "test", BcType: "banner", Theme: "light", Status: "active"} - cc.mocks.stores.BroadcastMock().EXPECT().Get(ctx, 1).Return(&dbBroadcast, nil) + cc.mocks.stores.BroadcastMock().EXPECT().Get(ctx, int64(1)).Return(&dbBroadcast, nil) cc.mocks.stores.BroadcastMock().EXPECT().Update(ctx, dbBroadcast).Return(&dbBroadcast, nil) data, _ := cc.UpdateBroadcast(ctx, broadcast) diff --git a/component/code.go b/component/code.go index 365e77b4..69169500 100644 --- a/component/code.go +++ b/component/code.go @@ -43,8 +43,7 @@ func NewCodeComponent(config *config.Config) (CodeComponent, error) { c.gitServer = gs c.config = config c.userLikesStore = database.NewUserLikesStore() - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + c.userSvcClient = rpc.NewUserSvcHttpClient(config) return c, nil } diff --git a/component/collection.go b/component/collection.go index 9d3a8243..866685d8 100644 --- a/component/collection.go +++ b/component/collection.go @@ -36,8 +36,7 @@ func NewCollectionComponent(config *config.Config) (CollectionComponent, error) cc.userStore = database.NewUserStore() cc.orgStore = database.NewOrgStore() cc.userLikesStore = database.NewUserLikesStore() - cc.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + cc.userSvcClient = rpc.NewUserSvcHttpClient(config) spaceComponent, err := NewSpaceComponent(config) if err != nil { return nil, err diff --git a/component/dataset.go b/component/dataset.go index 70357607..a46b6593 100644 --- a/component/dataset.go +++ b/component/dataset.go @@ -54,8 +54,7 @@ func NewDatasetComponent(config *config.Config) (DatasetComponent, error) { if err != nil { return nil, fmt.Errorf("failed to create git server, error: %w", err) } - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + c.userSvcClient = rpc.NewUserSvcHttpClient(config) c.gitServer = gs c.config = config return c, nil diff --git a/component/mirror.go b/component/mirror.go index 267f05ad..9cd074b2 100644 --- a/component/mirror.go +++ b/component/mirror.go @@ -50,6 +50,7 @@ type MirrorComponent interface { Repos(ctx context.Context, currentUser string, per, page int) ([]types.MirrorRepo, int, error) Index(ctx context.Context, currentUser string, per, page int, search string) ([]types.Mirror, int, error) Statistics(ctx context.Context, currentUser string) ([]types.MirrorStatusCount, error) + FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*database.Repository, error) } func NewMirrorComponent(config *config.Config) (MirrorComponent, error) { @@ -690,3 +691,7 @@ func (c *mirrorComponentImpl) Statistics(ctx context.Context, currentUser string return scs, nil } + +func (c *mirrorComponentImpl) FindWithMapping(ctx context.Context, repoType types.RepositoryType, namespace, name string, mapping types.Mapping) (*database.Repository, error) { + return c.mirrorStore.FindWithMapping(ctx, repoType, namespace, name, mapping) +} diff --git a/component/model.go b/component/model.go index 815c69cf..7d67b3c2 100644 --- a/component/model.go +++ b/component/model.go @@ -115,8 +115,7 @@ func NewModelComponent(config *config.Config) (ModelComponent, error) { if err != nil { return nil, err } - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + c.userSvcClient = rpc.NewUserSvcHttpClient(config) c.recomStore = database.NewRecomStore() return c, nil } diff --git a/component/prompt.go b/component/prompt.go index deb07435..d4ad4ceb 100644 --- a/component/prompt.go +++ b/component/prompt.go @@ -82,8 +82,7 @@ func NewPromptComponent(cfg *config.Config) (PromptComponent, error) { if err != nil { return nil, fmt.Errorf("failed to create git server,cause:%w", err) } - usc := rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", cfg.User.Host, cfg.User.Port), - rpc.AuthWithApiKey(cfg.APIToken)) + usc := rpc.NewUserSvcHttpClient(cfg) return &promptComponentImpl{ config: cfg, userStore: database.NewUserStore(), diff --git a/component/repo.go b/component/repo.go index 01c729b6..dc904fd5 100644 --- a/component/repo.go +++ b/component/repo.go @@ -208,8 +208,7 @@ func NewRepoComponent(config *config.Config) (RepoComponent, error) { return nil, newError } c.lfsBucket = config.S3.Bucket - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + c.userSvcClient = rpc.NewUserSvcHttpClient(config) c.runtimeFrameworksStore = database.NewRuntimeFrameworksStore() c.deployTaskStore = database.NewDeployTaskStore() c.deployer = deploy.NewDeployer() diff --git a/component/space_ce.go b/component/space_ce.go index ad4e60ac..6710dedf 100644 --- a/component/space_ce.go +++ b/component/space_ce.go @@ -4,7 +4,6 @@ package component import ( "context" - "fmt" "opencsg.com/csghub-server/builder/deploy" "opencsg.com/csghub-server/builder/git" @@ -36,8 +35,7 @@ func NewSpaceComponent(config *config.Config) (SpaceComponent, error) { c.serverBaseUrl = config.APIServer.PublicDomain c.userLikesStore = database.NewUserLikesStore() c.config = config - c.userSvcClient = rpc.NewUserSvcHttpClient(fmt.Sprintf("%s:%d", config.User.Host, config.User.Port), - rpc.AuthWithApiKey(config.APIToken)) + c.userSvcClient = rpc.NewUserSvcHttpClient(config) c.deployTaskStore = database.NewDeployTaskStore() c.git, err = git.NewGitServer(config) diff --git a/component/user.go b/component/user.go index 0340d1a4..83b408ff 100644 --- a/component/user.go +++ b/component/user.go @@ -39,6 +39,8 @@ type UserComponent interface { GetUserByName(ctx context.Context, userName string) (*database.User, error) Prompts(ctx context.Context, req *types.UserPromptsReq) ([]types.PromptRes, int, error) Evaluations(ctx context.Context, req *types.UserEvaluationReq) ([]types.ArgoWorkFlowRes, int, error) + FindByAccessToken(ctx context.Context, token string) (*database.User, error) + FindByGitAccessToken(ctx context.Context, token string) (*database.User, error) } func NewUserComponent(config *config.Config) (UserComponent, error) { @@ -656,3 +658,11 @@ func (c *userComponentImpl) Evaluations(ctx context.Context, req *types.UserEval } return res.List, res.Total, nil } + +func (c *userComponentImpl) FindByAccessToken(ctx context.Context, token string) (*database.User, error) { + return c.userStore.FindByAccessToken(ctx, token) +} + +func (c *userComponentImpl) FindByGitAccessToken(ctx context.Context, token string) (*database.User, error) { + return c.userStore.FindByGitAccessToken(ctx, token) +} diff --git a/dataviewer/router/api.go b/dataviewer/router/api.go index a80587fe..95ad8f17 100644 --- a/dataviewer/router/api.go +++ b/dataviewer/router/api.go @@ -27,14 +27,15 @@ func NewDataViewerRouter(config *config.Config, tc temporal.Client) (*gin.Engine r.Use(otelgin.Middleware("csghub-dataviewer")) } r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) - needAPIKey := middleware.OnlyAPIKeyAuthenticator(config) + middleware := middleware.NewMiddleware(config) + r.Use(middleware.Log()) + needAPIKey := middleware.NeedAPIKey() //add router for golang pprof debugGroup := r.Group("/debug", needAPIKey) pprof.RouteRegister(debugGroup, "pprof") - r.Use(middleware.Authenticator(config)) + r.Use(middleware.Authenticator()) apiGroup := r.Group("/api/v1") datasetsGrp := apiGroup.Group("/datasets/:namespace/:name") diff --git a/go.mod b/go.mod index 0188b097..82df70cc 100644 --- a/go.mod +++ b/go.mod @@ -314,7 +314,7 @@ require ( golang.org/x/crypto v0.32.0 golang.org/x/net v0.34.0 // indirect golang.org/x/oauth2 v0.24.0 // indirect - golang.org/x/sync v0.10.0 // indirect + golang.org/x/sync v0.10.0 golang.org/x/sys v0.29.0 // indirect golang.org/x/term v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/moderation/router/api.go b/moderation/router/api.go index 15c40716..c5b259e4 100644 --- a/moderation/router/api.go +++ b/moderation/router/api.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" + "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/middleware" "opencsg.com/csghub-server/common/config" @@ -13,7 +14,15 @@ import ( func NewRouter(config *config.Config) (*gin.Engine, error) { r := gin.New() r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) + middleware := middleware.NewMiddleware(config) + r.Use(middleware.Log()) + + needAPIKey := middleware.NeedAPIKey() + + //add router for golang pprof + debugGroup := r.Group("/debug", needAPIKey) + pprof.RouteRegister(debugGroup, "pprof") + // r.Use(middleware.Authenticator(config)) apiV1Group := r.Group("/api/v1") diff --git a/runner/router/api.go b/runner/router/api.go index 95531b59..df4ad054 100644 --- a/runner/router/api.go +++ b/runner/router/api.go @@ -4,6 +4,7 @@ import ( "fmt" "log/slog" + "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/middleware" "opencsg.com/csghub-server/builder/deploy/cluster" @@ -14,7 +15,14 @@ import ( func NewHttpServer(config *config.Config) (*gin.Engine, error) { r := gin.New() r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) + middleware := middleware.NewMiddleware(config) + r.Use(middleware.Log()) + + needAPIKey := middleware.NeedAPIKey() + + //add router for golang pprof + debugGroup := r.Group("/debug", needAPIKey) + pprof.RouteRegister(debugGroup, "pprof") clusterPool, err := cluster.NewClusterPool() if err != nil { diff --git a/user/router/api.go b/user/router/api.go index 2d729799..5e4aeb92 100644 --- a/user/router/api.go +++ b/user/router/api.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" + "github.com/gin-contrib/pprof" "github.com/gin-gonic/gin" "opencsg.com/csghub-server/api/httpbase" "opencsg.com/csghub-server/api/middleware" @@ -15,8 +16,15 @@ import ( func NewRouter(config *config.Config) (*gin.Engine, error) { r := gin.New() r.Use(gin.Recovery()) - r.Use(middleware.Log(config)) - r.Use(middleware.Authenticator(config)) + middleware := middleware.NewMiddleware(config) + r.Use(middleware.Log()) + needAPIKey := middleware.NeedAPIKey() + + //add router for golang pprof + debugGroup := r.Group("/debug", needAPIKey) + pprof.RouteRegister(debugGroup, "pprof") + + r.Use(middleware.Authenticator()) userHandler, err := handler.NewUserHandler(config) if err != nil { @@ -46,7 +54,6 @@ func NewRouter(config *config.Config) (*gin.Engine, error) { userGroup := apiV1Group.Group("/user") tokenGroup := apiV1Group.Group("/token") - needAPIKey := middleware.OnlyAPIKeyAuthenticator(config) jwtHandler, err := handler.NewJWTHandler(config) if err != nil { return nil, fmt.Errorf("error creating jwt handler:%w", err) diff --git a/wire/ce_header b/wire/ce_header new file mode 100644 index 00000000..1bbd5269 --- /dev/null +++ b/wire/ce_header @@ -0,0 +1,4 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:build !wireinject && !ee && !saas +// +build !wireinject,!ee,!saas \ No newline at end of file diff --git a/wire/ee_header b/wire/ee_header new file mode 100644 index 00000000..c3b1fea6 --- /dev/null +++ b/wire/ee_header @@ -0,0 +1,4 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:build !wireinject && ee +// +build !wireinject,ee \ No newline at end of file diff --git a/wire/saas_header b/wire/saas_header new file mode 100644 index 00000000..12f23025 --- /dev/null +++ b/wire/saas_header @@ -0,0 +1,4 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:build !wireinject && saas +// +build !wireinject,saas \ No newline at end of file