Skip to content

Commit 8988ab0

Browse files
committed
Added checks to ensure presence of notebook and user ids
1 parent 6c3e497 commit 8988ab0

File tree

1 file changed

+54
-13
lines changed

1 file changed

+54
-13
lines changed

modules/llm_module.go

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,22 @@ func (m *LlmModule) GenerateNotebook(ctx context.Context, body io.Reader) (*http
3535
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
3636
}
3737

38-
if _, hasSessionID := requestData["session_id"]; !hasSessionID {
39-
notebookID, hasNotebookID := requestData["notebook_id"]
40-
if !hasNotebookID {
41-
return nil, fmt.Errorf("request body must contain 'session_id' or 'notebook_id'")
42-
}
43-
44-
notebookIDStr, isString := notebookID.(string)
45-
if !isString || notebookIDStr == "" {
46-
return nil, fmt.Errorf("'notebook_id' must be a non-empty string")
47-
}
48-
requestData["session_id"] = notebookIDStr
38+
// TODO: User ID should not be passed in the body.
39+
// TODO: It should be extracted from the auth context, which i am not going to do now :)
40+
// making sure user_id and notebook_id are present
41+
if _, hasUserID := requestData["user_id"]; !hasUserID {
42+
return nil, fmt.Errorf("request body must contain 'user_id'")
43+
}
44+
if _, hasNotebookID := requestData["notebook_id"]; !hasNotebookID {
45+
return nil, fmt.Errorf("request body must contain 'notebook_id'")
46+
}
47+
notebookIDStr, isString := requestData["notebook_id"].(string)
48+
if !isString || notebookIDStr == "" {
49+
return nil, fmt.Errorf("'notebook_id' must be a non-empty string")
50+
}
51+
userIDStr, isString := requestData["user_id"].(string)
52+
if !isString || userIDStr == "" {
53+
return nil, fmt.Errorf("'user_id' must be a non-empty string")
4954
}
5055

5156
finalBodyBytes, err := json.Marshal(requestData)
@@ -68,6 +73,24 @@ func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body i
6873
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
6974
}
7075

76+
// TODO: User ID should not be passed in the body.
77+
// TODO: It should be extracted from the auth context, which i am not going to do now :)
78+
// making sure user_id and notebook_id are present
79+
if _, hasUserID := requestData["user_id"]; !hasUserID {
80+
return nil, fmt.Errorf("request body must contain 'user_id'")
81+
}
82+
if _, hasNotebookID := requestData["notebook_id"]; !hasNotebookID {
83+
return nil, fmt.Errorf("request body must contain 'notebook_id'")
84+
}
85+
notebookIDStr, isString := requestData["notebook_id"].(string)
86+
if !isString || notebookIDStr == "" {
87+
return nil, fmt.Errorf("'notebook_id' must be a non-empty string")
88+
}
89+
userIDStr, isString := requestData["user_id"].(string)
90+
if !isString || userIDStr == "" {
91+
return nil, fmt.Errorf("'user_id' must be a non-empty string")
92+
}
93+
7194
if instruction, ok := requestData["instruction"].(string); !ok || instruction == "" {
7295
return nil, fmt.Errorf("request body must contain a non-empty 'instruction' string")
7396
}
@@ -80,7 +103,7 @@ func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body i
80103
return nil, fmt.Errorf("'current_notebook' object must contain a 'cells' array")
81104
}
82105

83-
return m.Repo.ModifyNotebook(ctx, sessionID, bytes.NewBuffer(bodyBytes))
106+
return m.Repo.ModifyNotebook(ctx, bytes.NewBuffer(bodyBytes))
84107
}
85108

86109
// FixNotebook validates and proxies the fix request.
@@ -95,6 +118,24 @@ func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.R
95118
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
96119
}
97120

121+
// TODO: User ID should not be passed in the body.
122+
// TODO: It should be extracted from the auth context, which i am not going to do now :)
123+
// making sure user_id and notebook_id are present
124+
if _, hasUserID := requestData["user_id"]; !hasUserID {
125+
return nil, fmt.Errorf("request body must contain 'user_id'")
126+
}
127+
if _, hasNotebookID := requestData["notebook_id"]; !hasNotebookID {
128+
return nil, fmt.Errorf("request body must contain 'notebook_id'")
129+
}
130+
notebookIDStr, isString := requestData["notebook_id"].(string)
131+
if !isString || notebookIDStr == "" {
132+
return nil, fmt.Errorf("'notebook_id' must be a non-empty string")
133+
}
134+
userIDStr, isString := requestData["user_id"].(string)
135+
if !isString || userIDStr == "" {
136+
return nil, fmt.Errorf("'user_id' must be a non-empty string")
137+
}
138+
98139
if traceback, ok := requestData["traceback"].(string); !ok || traceback == "" {
99140
return nil, fmt.Errorf("request body must contain a non-empty 'traceback' string")
100141
}
@@ -107,5 +148,5 @@ func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.R
107148
return nil, fmt.Errorf("'current_notebook' object must contain a 'cells' array")
108149
}
109150

110-
return m.Repo.FixNotebook(ctx, sessionID, bytes.NewBuffer(bodyBytes))
151+
return m.Repo.FixNotebook(ctx, bytes.NewBuffer(bodyBytes))
111152
}

0 commit comments

Comments
 (0)