Skip to content

Commit 6a6f954

Browse files
committed
split middlewares into seperate files
1 parent e68f117 commit 6a6f954

File tree

6 files changed

+401
-345
lines changed

6 files changed

+401
-345
lines changed
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
package middlewares
2+
3+
import (
4+
"encoding/base64"
5+
"net/http"
6+
"net/url"
7+
"strings"
8+
9+
log "github.com/codeshelldev/secured-signal-api/utils/logger"
10+
)
11+
12+
type AuthMiddleware struct {
13+
Next http.Handler
14+
Token string
15+
}
16+
17+
type authType string
18+
19+
const (
20+
Bearer authType = "Bearer"
21+
Basic authType = "Basic"
22+
Query authType = "Query"
23+
None authType = "None"
24+
)
25+
26+
func getAuthType(str string) authType {
27+
switch str {
28+
case "Bearer":
29+
return Bearer
30+
case "Basic":
31+
return Basic
32+
default:
33+
return None
34+
}
35+
}
36+
37+
func (data AuthMiddleware) Use() http.Handler {
38+
next := data.Next
39+
token := data.Token
40+
41+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
42+
if token == "" {
43+
next.ServeHTTP(w, req)
44+
return
45+
}
46+
47+
log.Info("Request:", req.Method, req.URL.Path)
48+
49+
authHeader := req.Header.Get("Authorization")
50+
51+
authQuery := req.URL.Query().Get("@authorization")
52+
53+
var authType authType = None
54+
55+
success := false
56+
57+
if authHeader != "" {
58+
authBody := strings.Split(authHeader, " ")
59+
60+
authType = getAuthType(authBody[0])
61+
authToken := authBody[1]
62+
63+
switch authType {
64+
case Bearer:
65+
if authToken == token {
66+
success = true
67+
}
68+
69+
case Basic:
70+
basicAuthBody, err := base64.StdEncoding.DecodeString(authToken)
71+
72+
if err != nil {
73+
log.Error("Could not decode Basic Auth Payload: ", err.Error())
74+
}
75+
76+
basicAuth := string(basicAuthBody)
77+
basicAuthParams := strings.Split(basicAuth, ":")
78+
79+
user := "api"
80+
81+
if basicAuthParams[0] == user && basicAuthParams[1] == token {
82+
success = true
83+
}
84+
}
85+
86+
} else if authQuery != "" {
87+
authType = Query
88+
89+
authToken, _ := url.QueryUnescape(authQuery)
90+
91+
if authToken == token {
92+
success = true
93+
94+
modifiedQuery := req.URL.Query()
95+
96+
modifiedQuery.Del("@authorization")
97+
98+
req.URL.RawQuery = modifiedQuery.Encode()
99+
}
100+
}
101+
102+
if !success {
103+
w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"")
104+
105+
log.Warn("User failed ", string(authType), " Auth")
106+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
107+
return
108+
}
109+
110+
next.ServeHTTP(w, req)
111+
})
112+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package middlewares
2+
3+
import (
4+
"net/http"
5+
"slices"
6+
7+
log "github.com/codeshelldev/secured-signal-api/utils/logger"
8+
)
9+
10+
type EndpointsMiddleware struct {
11+
Next http.Handler
12+
BlockedEndpoints []string
13+
}
14+
15+
func (data EndpointsMiddleware) Use() http.Handler {
16+
next := data.Next
17+
BLOCKED_ENDPOINTS := data.BlockedEndpoints
18+
19+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
20+
reqPath := req.URL.Path
21+
22+
if slices.Contains(BLOCKED_ENDPOINTS, reqPath) {
23+
log.Warn("User tried to access blocked endpoint: ", reqPath)
24+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
25+
return
26+
}
27+
28+
next.ServeHTTP(w, req)
29+
})
30+
}
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
package middlewares
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
"net/url"
9+
"regexp"
10+
"strconv"
11+
"strings"
12+
"text/template"
13+
14+
log "github.com/codeshelldev/secured-signal-api/utils/logger"
15+
query "github.com/codeshelldev/secured-signal-api/utils/query"
16+
)
17+
18+
type TemplateMiddleware struct {
19+
Next http.Handler
20+
Variables map[string]interface{}
21+
}
22+
23+
func renderTemplate(name string, tmplStr string, data any) (string, error) {
24+
tmpl, err := template.New(name).Parse(tmplStr)
25+
26+
if err != nil {
27+
return "", err
28+
}
29+
var buf bytes.Buffer
30+
31+
err = tmpl.Execute(&buf, data)
32+
33+
if err != nil {
34+
return "", err
35+
}
36+
return buf.String(), nil
37+
}
38+
39+
func templateJSON(data map[string]interface{}, variables map[string]interface{}) map[string]interface{} {
40+
for k, v := range data {
41+
str, ok := v.(string)
42+
43+
if ok {
44+
re, err := regexp.Compile(`{{\s*\.([A-Za-z_][A-Za-z0-9_]*)\s*}}`)
45+
46+
if err != nil {
47+
log.Error("Encountered Error while Compiling Regex: ", err.Error())
48+
}
49+
50+
matches := re.FindAllStringSubmatch(str, -1)
51+
52+
if len(matches) > 1 {
53+
for i, tmplStr := range matches {
54+
55+
tmplKey := matches[i][1]
56+
57+
variable, err := json.Marshal(variables[tmplKey])
58+
59+
if err != nil {
60+
log.Error("Could not decode JSON: ", err.Error())
61+
break
62+
}
63+
64+
data[k] = strings.ReplaceAll(str, string(variable), tmplStr[0])
65+
}
66+
} else if len(matches) == 1 {
67+
tmplKey := matches[0][1]
68+
69+
data[k] = variables[tmplKey]
70+
}
71+
}
72+
}
73+
74+
return data
75+
}
76+
77+
func (data TemplateMiddleware) Use() http.Handler {
78+
next := data.Next
79+
VARIABLES := data.Variables
80+
81+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
82+
if req.Body != nil {
83+
bodyBytes, err := io.ReadAll(req.Body)
84+
85+
if err != nil {
86+
log.Error("Could not read Body: ", err.Error())
87+
http.Error(w, "Internal Error", http.StatusInternalServerError)
88+
return
89+
}
90+
91+
req.Body.Close()
92+
93+
var modifiedBodyData map[string]interface{}
94+
95+
err = json.Unmarshal(bodyBytes, &modifiedBodyData)
96+
97+
if err != nil {
98+
log.Error("Could not decode Body: ", err.Error())
99+
http.Error(w, "Internal Error", http.StatusInternalServerError)
100+
return
101+
}
102+
103+
modifiedBodyData = templateJSON(modifiedBodyData, VARIABLES)
104+
105+
if req.URL.RawQuery != "" {
106+
decodedQuery, _ := url.QueryUnescape(req.URL.RawQuery)
107+
108+
log.Debug("Decoded Query: ", decodedQuery)
109+
110+
templatedQuery, _ := renderTemplate("query", decodedQuery, VARIABLES)
111+
112+
modifiedQuery := req.URL.Query()
113+
114+
queryData := query.ParseRawQuery(templatedQuery)
115+
116+
for key, value := range queryData {
117+
keyWithoutPrefix, found := strings.CutPrefix(key, "@")
118+
119+
if found {
120+
modifiedBodyData[keyWithoutPrefix] = query.ParseTypedQuery(value)
121+
122+
modifiedQuery.Del(key)
123+
}
124+
}
125+
126+
req.URL.RawQuery = modifiedQuery.Encode()
127+
128+
log.Debug("Applied Query Templating: ", templatedQuery)
129+
}
130+
131+
modifiedBodyBytes, err := json.Marshal(modifiedBodyData)
132+
133+
if err != nil {
134+
log.Error("Could not encode Body: ", err.Error())
135+
http.Error(w, "Internal Error", http.StatusInternalServerError)
136+
return
137+
}
138+
139+
modifiedBody := string(modifiedBodyBytes)
140+
141+
log.Debug("Applied Body Templating: ", modifiedBody)
142+
143+
req.Body = io.NopCloser(bytes.NewReader(modifiedBodyBytes))
144+
145+
req.ContentLength = int64(len(modifiedBody))
146+
req.Header.Set("Content-Length", strconv.Itoa(len(modifiedBody)))
147+
}
148+
149+
reqPath := req.URL.Path
150+
reqPath, _ = url.PathUnescape(reqPath)
151+
152+
modifiedReqPath, _ := renderTemplate("path", reqPath, VARIABLES)
153+
154+
req.URL.Path = modifiedReqPath
155+
156+
next.ServeHTTP(w, req)
157+
})
158+
}

0 commit comments

Comments
 (0)