Skip to content

Commit 544eda6

Browse files
authored
fix: service tokens (#102)
* fix: jwt service token duration * feat: add bulk categorization mcp api
1 parent 220d4e1 commit 544eda6

File tree

6 files changed

+122
-113
lines changed

6 files changed

+122
-113
lines changed

pkg/auth/jwt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ func (j *Service) CreateServiceToken(
164164
func (j *Service) generateTokenInternal(req *GenerateTokenRequest) (*JwtClaims, string, error) {
165165
claims := &JwtClaims{
166166
RegisteredClaims: &jwt2.RegisteredClaims{
167-
ExpiresAt: jwt2.NewNumericDate(time.Now().UTC().Add(j.ttl)),
167+
ExpiresAt: jwt2.NewNumericDate(time.Now().UTC().Add(req.TTL)),
168168
ID: uuid.NewString(),
169169
},
170170
UserID: req.User.ID,

pkg/auth/jwt_test.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ package auth_test
22

33
import (
44
"context"
5+
"testing"
6+
"time"
7+
58
"github.com/ft-t/go-money/pkg/auth"
69
"github.com/ft-t/go-money/pkg/database"
710
"github.com/golang-jwt/jwt/v5"
811
"github.com/stretchr/testify/assert"
9-
"testing"
10-
"time"
1112
)
1213

1314
func TestJwtToken_Success(t *testing.T) {
@@ -166,11 +167,16 @@ func TestCreateServiceToken_Success(t *testing.T) {
166167
keyGen := auth.NewKeyGenerator()
167168
key := keyGen.Generate()
168169

169-
jwtGenerator, err := auth.NewService(string(keyGen.Serialize(key)), 5*time.Minute)
170+
serviceTTL := 5 * time.Minute
171+
requestTTL := 24 * time.Hour
172+
173+
jwtGenerator, err := auth.NewService(string(keyGen.Serialize(key)), serviceTTL)
170174
assert.NoError(t, err)
171175

176+
beforeGeneration := time.Now().UTC()
177+
172178
claims, token, err := jwtGenerator.CreateServiceToken(context.TODO(), &auth.GenerateTokenRequest{
173-
TTL: 24 * time.Hour,
179+
TTL: requestTTL,
174180
User: &database.User{
175181
ID: 123,
176182
Login: "testuser",
@@ -183,6 +189,11 @@ func TestCreateServiceToken_Success(t *testing.T) {
183189
assert.EqualValues(t, 123, claims.UserID)
184190
assert.Equal(t, auth.ServiceTokenType, claims.TokenType)
185191
assert.NotEmpty(t, claims.ID)
192+
193+
expectedExpiresAt := beforeGeneration.Add(requestTTL)
194+
actualExpiresAt := claims.ExpiresAt.Time
195+
196+
assert.WithinDuration(t, expectedExpiresAt, actualExpiresAt, 5*time.Second)
186197
}
187198

188199
func TestCreateServiceToken_Failure(t *testing.T) {

pkg/mcp/server.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,16 @@ func (s *Server) registerTools() {
6060
)
6161
s.mcpServer.AddTool(queryTool, s.handleQuery)
6262

63-
setTransactionCategoryTool := mcp.NewTool(
64-
"set_transaction_category",
65-
mcp.WithDescription("Set or clear the category of a transaction"),
66-
mcp.WithNumber(
67-
"transaction_id",
68-
mcp.Description("The ID of the transaction to update"),
63+
bulkSetTransactionCategoryTool := mcp.NewTool(
64+
"bulk_set_transaction_category",
65+
mcp.WithDescription("Set or clear categories for multiple transactions in a single call"),
66+
mcp.WithArray(
67+
"assignments",
68+
mcp.Description("Array of objects with transaction_id (required) and category_id (optional, null to clear)"),
6969
mcp.Required(),
7070
),
71-
mcp.WithNumber(
72-
"category_id",
73-
mcp.Description("The category ID to set (omit or null to clear category)"),
74-
),
7571
)
76-
s.mcpServer.AddTool(setTransactionCategoryTool, s.handleSetTransactionCategory)
72+
s.mcpServer.AddTool(bulkSetTransactionCategoryTool, s.handleBulkSetTransactionCategory)
7773

7874
createCategoryTool := mcp.NewTool(
7975
"create_category",

pkg/mcp/transaction_category_tool.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,45 +4,48 @@ import (
44
"context"
55
"fmt"
66

7-
"github.com/ft-t/go-money/pkg/database"
87
"github.com/mark3labs/mcp-go/mcp"
98
)
109

11-
func (s *Server) handleSetTransactionCategory(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
10+
func (s *Server) handleBulkSetTransactionCategory(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1211
args := request.GetArguments()
1312

14-
transactionID, ok := args["transaction_id"].(float64)
15-
if !ok {
16-
return mcp.NewToolResultError("transaction_id parameter is required"), nil
17-
}
18-
19-
var categoryID *int32
20-
if catID, exists := args["category_id"]; exists && catID != nil {
21-
catIDFloat, ok := catID.(float64)
22-
if !ok {
23-
return mcp.NewToolResultError("category_id must be a number or null"), nil
24-
}
25-
catIDInt := int32(catIDFloat)
26-
categoryID = &catIDInt
13+
assignmentsRaw, ok := args["assignments"].([]any)
14+
if !ok || len(assignmentsRaw) == 0 {
15+
return mcp.NewToolResultError("assignments parameter is required and must be a non-empty array"), nil
2716
}
2817

2918
queryCtx, cancel := context.WithTimeout(ctx, queryTimeout)
3019
defer cancel()
3120

32-
var tx database.Transaction
33-
if err := s.db.WithContext(queryCtx).First(&tx, int64(transactionID)).Error; err != nil {
34-
return mcp.NewToolResultError(fmt.Sprintf("transaction not found: %v", err)), nil
35-
}
21+
for i, item := range assignmentsRaw {
22+
itemMap, ok := item.(map[string]any)
23+
if !ok {
24+
return mcp.NewToolResultError(fmt.Sprintf("assignment[%d] must be an object", i)), nil
25+
}
3626

37-
tx.CategoryID = categoryID
27+
txID, ok := itemMap["transaction_id"].(float64)
28+
if !ok {
29+
return mcp.NewToolResultError(fmt.Sprintf("assignment[%d].transaction_id is required and must be a number", i)), nil
30+
}
3831

39-
if err := s.db.WithContext(queryCtx).Save(&tx).Error; err != nil {
40-
return mcp.NewToolResultError(fmt.Sprintf("failed to update transaction: %v", err)), nil
41-
}
32+
var categoryID *int32
33+
if catID, exists := itemMap["category_id"]; exists && catID != nil {
34+
catIDFloat, ok := catID.(float64)
35+
if !ok {
36+
return mcp.NewToolResultError(fmt.Sprintf("assignment[%d].category_id must be a number or null", i)), nil
37+
}
38+
catIDInt := int32(catIDFloat)
39+
categoryID = &catIDInt
40+
}
4241

43-
if categoryID == nil {
44-
return mcp.NewToolResultText(fmt.Sprintf("Transaction %d category cleared", int64(transactionID))), nil
42+
if err := s.db.WithContext(queryCtx).
43+
Table("transactions").
44+
Where("id = ?", int64(txID)).
45+
Update("category_id", categoryID).Error; err != nil {
46+
return mcp.NewToolResultError(fmt.Sprintf("failed to update transaction %d: %v", int64(txID), err)), nil
47+
}
4548
}
4649

47-
return mcp.NewToolResultText(fmt.Sprintf("Transaction %d category set to %d", int64(transactionID), *categoryID)), nil
50+
return mcp.NewToolResultText(fmt.Sprintf("Updated %d transactions", len(assignmentsRaw))), nil
4851
}

pkg/mcp/transaction_category_tool_test.go

Lines changed: 69 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package mcp_test
33
import (
44
"context"
55
"testing"
6-
"time"
76

87
"github.com/DATA-DOG/go-sqlmock"
98
"github.com/golang/mock/gomock"
@@ -15,26 +14,38 @@ import (
1514
"github.com/ft-t/go-money/pkg/testingutils"
1615
)
1716

18-
func TestServer_HandleSetTransactionCategory_Success(t *testing.T) {
17+
func TestServer_HandleBulkSetTransactionCategory_Success(t *testing.T) {
1918
type tc struct {
20-
name string
21-
txID float64
22-
categoryID any
23-
expected string
19+
name string
20+
assignments []map[string]any
2421
}
2522

2623
cases := []tc{
2724
{
28-
name: "set category",
29-
txID: 1,
30-
categoryID: float64(5),
31-
expected: "Transaction 1 category set to 5",
25+
name: "set single category",
26+
assignments: []map[string]any{
27+
{"transaction_id": float64(1), "category_id": float64(5)},
28+
},
3229
},
3330
{
34-
name: "clear category with nil",
35-
txID: 2,
36-
categoryID: nil,
37-
expected: "Transaction 2 category cleared",
31+
name: "set multiple categories",
32+
assignments: []map[string]any{
33+
{"transaction_id": float64(1), "category_id": float64(5)},
34+
{"transaction_id": float64(2), "category_id": float64(10)},
35+
},
36+
},
37+
{
38+
name: "clear category",
39+
assignments: []map[string]any{
40+
{"transaction_id": float64(1), "category_id": nil},
41+
},
42+
},
43+
{
44+
name: "mixed set and clear",
45+
assignments: []map[string]any{
46+
{"transaction_id": float64(1), "category_id": float64(5)},
47+
{"transaction_id": float64(2), "category_id": nil},
48+
},
3849
},
3950
}
4051

@@ -44,114 +55,102 @@ func TestServer_HandleSetTransactionCategory_Success(t *testing.T) {
4455
gormDB, mockDB, mock := testingutils.GormMock()
4556
defer func() { _ = mockDB.Close() }()
4657

47-
selectRows := sqlmock.NewRows([]string{
48-
"id", "source_amount", "source_currency", "destination_amount",
49-
"destination_currency", "source_account_id", "destination_account_id",
50-
"category_id", "transaction_type", "created_at", "updated_at",
51-
}).AddRow(
52-
int64(c.txID), nil, "USD", nil, "USD", 1, 2, nil, 3,
53-
time.Now(), time.Now(),
54-
)
55-
mock.ExpectQuery("SELECT \\* FROM \"transactions\"").WillReturnRows(selectRows)
56-
mock.ExpectBegin()
57-
mock.ExpectExec("UPDATE \"transactions\"").WillReturnResult(sqlmock.NewResult(0, 1))
58-
mock.ExpectCommit()
59-
60-
catSvc := NewMockCategoryService(ctrl)
61-
rulesSvc := NewMockRulesService(ctrl)
62-
dryRunSvc := NewMockDryRunService(ctrl)
58+
for range c.assignments {
59+
mock.ExpectBegin()
60+
mock.ExpectExec("UPDATE \"transactions\"").WillReturnResult(sqlmock.NewResult(0, 1))
61+
mock.ExpectCommit()
62+
}
6363

6464
server := gomcp.NewServer(&gomcp.ServerConfig{
6565
DB: gormDB,
6666
Docs: "test docs",
67-
CategorySvc: catSvc,
68-
RulesSvc: rulesSvc,
69-
DryRunSvc: dryRunSvc,
67+
CategorySvc: NewMockCategoryService(ctrl),
68+
RulesSvc: NewMockRulesService(ctrl),
69+
DryRunSvc: NewMockDryRunService(ctrl),
7070
})
7171

72-
mcpServer := server.MCPServer()
73-
tool := mcpServer.GetTool("set_transaction_category")
72+
tool := server.MCPServer().GetTool("bulk_set_transaction_category")
7473
require.NotNil(t, tool)
7574

76-
args := map[string]any{"transaction_id": c.txID}
77-
if c.categoryID != nil {
78-
args["category_id"] = c.categoryID
75+
assignmentsAny := make([]any, len(c.assignments))
76+
for i, a := range c.assignments {
77+
assignmentsAny[i] = a
7978
}
8079

8180
result, err := tool.Handler(context.Background(), mcp.CallToolRequest{
8281
Params: mcp.CallToolParams{
83-
Name: "set_transaction_category",
84-
Arguments: args,
82+
Name: "bulk_set_transaction_category",
83+
Arguments: map[string]any{"assignments": assignmentsAny},
8584
},
8685
})
8786

8887
require.NoError(t, err)
8988
require.NotNil(t, result)
9089
assert.False(t, result.IsError)
91-
assert.Contains(t, result.Content[0].(mcp.TextContent).Text, c.expected)
90+
assert.Contains(t, result.Content[0].(mcp.TextContent).Text, "Updated")
9291
})
9392
}
9493
}
9594

96-
func TestServer_HandleSetTransactionCategory_Failure(t *testing.T) {
95+
func TestServer_HandleBulkSetTransactionCategory_Failure(t *testing.T) {
9796
type tc struct {
9897
name string
9998
args map[string]any
100-
setupMock func(sqlmock.Sqlmock)
10199
expectedError string
102100
}
103101

104102
cases := []tc{
105103
{
106-
name: "missing transaction_id",
104+
name: "missing assignments",
107105
args: map[string]any{},
108-
setupMock: func(m sqlmock.Sqlmock) {},
109-
expectedError: "transaction_id parameter is required",
106+
expectedError: "assignments parameter is required",
110107
},
111108
{
112-
name: "invalid category_id type",
113-
args: map[string]any{"transaction_id": float64(1), "category_id": "invalid"},
114-
setupMock: func(m sqlmock.Sqlmock) {},
115-
expectedError: "category_id must be a number or null",
109+
name: "empty assignments",
110+
args: map[string]any{"assignments": []any{}},
111+
expectedError: "assignments parameter is required and must be a non-empty array",
116112
},
117113
{
118-
name: "transaction not found",
119-
args: map[string]any{"transaction_id": float64(999)},
120-
setupMock: func(m sqlmock.Sqlmock) {
121-
m.ExpectQuery("SELECT \\* FROM \"transactions\"").
122-
WillReturnRows(sqlmock.NewRows([]string{"id"}))
123-
},
124-
expectedError: "transaction not found",
114+
name: "invalid assignment type",
115+
args: map[string]any{"assignments": []any{"not an object"}},
116+
expectedError: "assignment[0] must be an object",
117+
},
118+
{
119+
name: "missing transaction_id",
120+
args: map[string]any{"assignments": []any{
121+
map[string]any{"category_id": float64(5)},
122+
}},
123+
expectedError: "assignment[0].transaction_id is required",
124+
},
125+
{
126+
name: "invalid category_id type",
127+
args: map[string]any{"assignments": []any{
128+
map[string]any{"transaction_id": float64(1), "category_id": "invalid"},
129+
}},
130+
expectedError: "assignment[0].category_id must be a number or null",
125131
},
126132
}
127133

128134
for _, c := range cases {
129135
t.Run(c.name, func(t *testing.T) {
130136
ctrl := gomock.NewController(t)
131-
gormDB, mockDB, mock := testingutils.GormMock()
137+
gormDB, mockDB, _ := testingutils.GormMock()
132138
defer func() { _ = mockDB.Close() }()
133139

134-
c.setupMock(mock)
135-
136-
catSvc := NewMockCategoryService(ctrl)
137-
rulesSvc := NewMockRulesService(ctrl)
138-
dryRunSvc := NewMockDryRunService(ctrl)
139-
140140
server := gomcp.NewServer(&gomcp.ServerConfig{
141141
DB: gormDB,
142142
Docs: "test docs",
143-
CategorySvc: catSvc,
144-
RulesSvc: rulesSvc,
145-
DryRunSvc: dryRunSvc,
143+
CategorySvc: NewMockCategoryService(ctrl),
144+
RulesSvc: NewMockRulesService(ctrl),
145+
DryRunSvc: NewMockDryRunService(ctrl),
146146
})
147147

148-
mcpServer := server.MCPServer()
149-
tool := mcpServer.GetTool("set_transaction_category")
148+
tool := server.MCPServer().GetTool("bulk_set_transaction_category")
150149
require.NotNil(t, tool)
151150

152151
result, err := tool.Handler(context.Background(), mcp.CallToolRequest{
153152
Params: mcp.CallToolParams{
154-
Name: "set_transaction_category",
153+
Name: "bulk_set_transaction_category",
155154
Arguments: c.args,
156155
},
157156
})

pkg/testingutils/gorm.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func FlushAllTables(config boilerplate.DbConfig) error {
4141
}
4242

4343
func ensureThatItsLocal(config boilerplate.DbConfig) {
44-
allowedHosts := []string{"localhost", "127.0.0.1", "postgres"}
44+
allowedHosts := []string{"localhost", "127.0.0.1", "postgres", "tools.lan"}
4545

4646
for _, h := range allowedHosts {
4747
if strings.EqualFold(h, config.Host) {

0 commit comments

Comments
 (0)