Skip to content

Commit eea6832

Browse files
committed
fix
1 parent 07a73fe commit eea6832

File tree

4 files changed

+284
-97
lines changed

4 files changed

+284
-97
lines changed

service/authclient.go

Lines changed: 107 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ type AuthResponse struct {
2424
}
2525

2626
// AuthClient abstracts the external authentication call so it can be mocked in tests.
27-
// Validate now returns a MappingRequest built from the auth response.
27+
// Validate returns all MappingRequests built from the auth response array.
2828
type AuthClient interface {
29-
Validate(ctx context.Context, extension, secret, homeserverHost string) (*models.MappingRequest, int, error)
29+
Validate(ctx context.Context, extension, secret, homeserverHost string) ([]*models.MappingRequest, bool, error)
3030
}
3131

3232
// HTTPAuthClient is the default AuthClient implementation that calls the external HTTP endpoint.
@@ -51,15 +51,13 @@ func NewHTTPAuthClient(url string, timeout time.Duration, cacheTTL time.Duration
5151
}
5252

5353
type cachedAuth struct {
54-
mapping *models.MappingRequest
55-
expiry time.Time
56-
status int
54+
expiry time.Time
5755
}
5856

5957
// Validate calls the configured external auth endpoint and converts its result
60-
// into a models.MappingRequest. homeserverHost is used to build full Matrix IDs
61-
// when the returned user_name is a localpart.
62-
func (h *HTTPAuthClient) Validate(ctx context.Context, extension, secret, homeserverHost string) (*models.MappingRequest, int, error) {
58+
// into an array of models.MappingRequest. It caches authentication success/failure.
59+
// homeserverHost is used to build full Matrix IDs when the returned user_name is a localpart.
60+
func (h *HTTPAuthClient) Validate(ctx context.Context, extension, secret, homeserverHost string) ([]*models.MappingRequest, bool, error) {
6361
// check cache
6462
key := extension + "|" + secret + "|" + homeserverHost
6563
logger.Debug().Str("key", key).Msg("authclient: validate called")
@@ -69,7 +67,8 @@ func (h *HTTPAuthClient) Validate(ctx context.Context, extension, secret, homese
6967
if time.Now().Before(c.expiry) {
7068
h.mu.RUnlock()
7169
logger.Debug().Str("key", key).Time("expiry", c.expiry).Msg("authclient: cache hit")
72-
return c.mapping, c.status, nil
70+
// Return empty array on cache hit for authenticated requests
71+
return []*models.MappingRequest{}, true, nil
7372
}
7473
}
7574
h.mu.RUnlock()
@@ -80,80 +79,138 @@ func (h *HTTPAuthClient) Validate(ctx context.Context, extension, secret, homese
8079
body, _ := json.Marshal(payload)
8180
req, err := http.NewRequestWithContext(ctx, "POST", h.url, bytes.NewReader(body))
8281
if err != nil {
83-
return nil, 0, err
82+
return []*models.MappingRequest{}, false, err
8483
}
8584
req.Header.Set("Content-Type", "application/json")
8685

8786
logger.Debug().Str("url", h.url).Str("extension", extension).Msg("authclient: sending auth request")
8887

8988
resp, err := h.client.Do(req)
9089
if err != nil {
91-
return nil, 0, err
90+
return []*models.MappingRequest{}, false, err
9291
}
9392
defer resp.Body.Close()
9493

9594
logger.Debug().Int("status", resp.StatusCode).Msg("authclient: received response")
9695
if resp.StatusCode != http.StatusOK {
9796
b, _ := io.ReadAll(resp.Body)
9897
logger.Debug().Int("status", resp.StatusCode).Bytes("body", b).Msg("authclient: non-200 response")
99-
return nil, resp.StatusCode, fmt.Errorf("status %d: %s", resp.StatusCode, string(b))
98+
return []*models.MappingRequest{}, false, fmt.Errorf("status %d: %s", resp.StatusCode, string(b))
10099
}
101100

102-
var ar AuthResponse
103-
if err := json.NewDecoder(resp.Body).Decode(&ar); err != nil {
101+
var responses []AuthResponse
102+
if err := json.NewDecoder(resp.Body).Decode(&responses); err != nil {
104103
logger.Debug().Err(err).Msg("authclient: failed to decode response")
105-
return nil, resp.StatusCode, err
104+
return []*models.MappingRequest{}, false, err
106105
}
107106

108-
logger.Debug().Str("main_extension", ar.MainExtension).Strs("sub_extensions", ar.SubExtensions).Str("user_name", ar.UserName).Msg("authclient: parsed auth response")
107+
logger.Debug().Int("response_count", len(responses)).Msg("authclient: parsed auth response array")
109108

110-
// Parse main_extension
109+
// Find the matching response based on main_extension matching the extension parameter
110+
var ar *AuthResponse
111+
for i := range responses {
112+
if strings.TrimSpace(responses[i].MainExtension) == extension {
113+
ar = &responses[i]
114+
break
115+
}
116+
}
117+
118+
if ar == nil {
119+
logger.Debug().Str("extension", extension).Msg("authclient: no matching extension in response array")
120+
return []*models.MappingRequest{}, false, fmt.Errorf("extension %s not found in auth response", extension)
121+
}
122+
123+
logger.Debug().Str("main_extension", ar.MainExtension).Strs("sub_extensions", ar.SubExtensions).Str("user_name", ar.UserName).Msg("authclient: found matching auth response")
124+
125+
// Validate that the matched extension can be parsed as a number
111126
mainExt := strings.TrimSpace(ar.MainExtension)
112127
if mainExt == "" {
113-
return nil, resp.StatusCode, fmt.Errorf("external auth response missing main_extension")
128+
logger.Warn().Msg("authclient: matched extension is empty")
129+
return []*models.MappingRequest{}, false, fmt.Errorf("matched extension is empty")
114130
}
115-
mainNum, err := strconv.Atoi(mainExt)
116-
if err != nil {
117-
return nil, resp.StatusCode, fmt.Errorf("invalid main_extension from auth: %w", err)
131+
if _, err := strconv.Atoi(mainExt); err != nil {
132+
logger.Warn().Str("main_extension", mainExt).Err(err).Msg("authclient: matched extension cannot be parsed as integer")
133+
return []*models.MappingRequest{}, false, fmt.Errorf("matched extension %s is not a valid number: %w", mainExt, err)
134+
}
135+
136+
// Cache successful authentication (without mapping data)
137+
if h.cacheTTL > 0 {
138+
h.mu.Lock()
139+
h.cache[key] = cachedAuth{expiry: time.Now().Add(h.cacheTTL)}
140+
h.mu.Unlock()
141+
logger.Debug().Str("key", key).Time("expiry", time.Now().Add(h.cacheTTL)).Msg("authclient: cached successful authentication")
118142
}
119143

120-
// Parse sub extensions
121-
subNums := make([]int, 0, len(ar.SubExtensions))
122-
for _, ssub := range ar.SubExtensions {
123-
ssub = strings.TrimSpace(ssub)
124-
if ssub == "" {
144+
// Convert all responses to MappingRequest array
145+
mappings := make([]*models.MappingRequest, 0, len(responses))
146+
var matchedUserMapping *models.MappingRequest
147+
148+
for i := range responses {
149+
authResp := &responses[i]
150+
151+
// Parse main_extension
152+
mainExt := strings.TrimSpace(authResp.MainExtension)
153+
if mainExt == "" {
154+
logger.Warn().Msg("authclient: skipping response with missing main_extension")
125155
continue
126156
}
127-
if v, err := strconv.Atoi(ssub); err == nil {
128-
subNums = append(subNums, v)
157+
mainNum, err := strconv.Atoi(mainExt)
158+
if err != nil {
159+
logger.Warn().Str("main_extension", mainExt).Err(err).Msg("authclient: invalid main_extension, skipping")
160+
continue
129161
}
130-
}
131162

132-
// Build matrix id
133-
matrixID := strings.TrimSpace(ar.UserName)
134-
if matrixID == "" {
135-
matrixID = extension
136-
}
137-
if !strings.HasPrefix(matrixID, "@") {
138-
if homeserverHost == "" {
139-
return nil, resp.StatusCode, fmt.Errorf("cannot build matrix id: homeserver host not configured")
163+
// Parse sub extensions
164+
subNums := make([]int, 0, len(authResp.SubExtensions))
165+
for _, ssub := range authResp.SubExtensions {
166+
ssub = strings.TrimSpace(ssub)
167+
if ssub == "" {
168+
continue
169+
}
170+
if v, err := strconv.Atoi(ssub); err == nil {
171+
subNums = append(subNums, v)
172+
}
173+
}
174+
175+
// Build matrix id
176+
matrixID := strings.TrimSpace(authResp.UserName)
177+
if matrixID == "" {
178+
matrixID = extension
179+
}
180+
if !strings.HasPrefix(matrixID, "@") {
181+
if homeserverHost == "" {
182+
// If this is the matched user, we need homeserver host to build their Matrix ID
183+
if authResp == ar {
184+
logger.Warn().Str("user_name", authResp.UserName).Msg("authclient: cannot build matrix id for authenticated user without homeserver host")
185+
return []*models.MappingRequest{}, false, fmt.Errorf("cannot build matrix id for authenticated user: homeserver host not configured")
186+
}
187+
logger.Warn().Str("user_name", authResp.UserName).Msg("authclient: cannot build matrix id without homeserver host, skipping")
188+
continue
189+
}
190+
local := strings.ToLower(strings.TrimSpace(matrixID))
191+
matrixID = fmt.Sprintf("@%s:%s", local, homeserverHost)
140192
}
141-
local := strings.ToLower(strings.TrimSpace(matrixID))
142-
matrixID = fmt.Sprintf("@%s:%s", local, homeserverHost)
143-
}
144193

145-
mapping := &models.MappingRequest{
146-
Number: mainNum,
147-
MatrixID: matrixID,
148-
SubNumbers: subNums,
194+
mapping := &models.MappingRequest{
195+
Number: mainNum,
196+
MatrixID: matrixID,
197+
SubNumbers: subNums,
198+
}
199+
mappings = append(mappings, mapping)
200+
201+
// Track the authenticated user's mapping
202+
if authResp == ar {
203+
matchedUserMapping = mapping
204+
}
205+
206+
logger.Debug().Int("number", mainNum).Str("matrix_id", matrixID).Ints("sub_numbers", subNums).Msg("authclient: added mapping to response")
149207
}
150208

151-
if h.cacheTTL > 0 {
152-
h.mu.Lock()
153-
h.cache[key] = cachedAuth{mapping: mapping, expiry: time.Now().Add(h.cacheTTL), status: resp.StatusCode}
154-
h.mu.Unlock()
155-
logger.Debug().Str("key", key).Time("expiry", time.Now().Add(h.cacheTTL)).Msg("authclient: stored mapping in cache")
209+
// Ensure the authenticated user's mapping was successfully created
210+
if matchedUserMapping == nil {
211+
logger.Warn().Msg("authclient: authenticated user's mapping was not created")
212+
return []*models.MappingRequest{}, false, fmt.Errorf("failed to create mapping for authenticated user")
156213
}
157214

158-
return mapping, resp.StatusCode, nil
215+
return mappings, true, nil
159216
}

service/authclient_test.go

Lines changed: 125 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"testing"
88
"time"
99

10+
"github.com/nethesis/matrix2acrobits/models"
1011
"github.com/stretchr/testify/require"
1112
)
1213

@@ -18,37 +19,151 @@ func TestHTTPAuthClient_Non200Response(t *testing.T) {
1819
defer ts.Close()
1920

2021
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 0)
21-
_, status, err := c.Validate(context.TODO(), "123", "secret", "example.com")
22+
_, ok, err := c.Validate(context.TODO(), "123", "secret", "example.com")
2223
require.Error(t, err)
23-
require.Equal(t, http.StatusInternalServerError, status)
24+
require.False(t, ok)
2425
}
2526

2627
func TestHTTPAuthClient_InvalidMainExtension(t *testing.T) {
2728
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2829
w.Header().Set("Content-Type", "application/json")
2930
w.WriteHeader(http.StatusOK)
30-
_, _ = w.Write([]byte(`{"main_extension":"not-a-number","sub_extensions":[],"user_name":"alice"}`))
31+
// Return an array with one valid and one invalid entry
32+
_, _ = w.Write([]byte(`[
33+
{"main_extension":"201","sub_extensions":[],"user_name":"giacomo"},
34+
{"main_extension":"not-a-number","sub_extensions":[],"user_name":"alice"}
35+
]`))
3136
}))
3237
defer ts.Close()
3338

3439
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 0)
35-
mapping, status, err := c.Validate(context.TODO(), "123", "secret", "example.com")
40+
// Request the invalid extension - should not find it in the array
41+
mappings, ok, err := c.Validate(context.TODO(), "not-a-number", "secret", "example.com")
3642
require.Error(t, err)
37-
require.Nil(t, mapping)
38-
require.Equal(t, http.StatusOK, status)
43+
require.False(t, ok)
44+
require.Empty(t, mappings)
3945
}
4046

4147
func TestHTTPAuthClient_MissingHomeserverHost(t *testing.T) {
4248
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4349
w.Header().Set("Content-Type", "application/json")
4450
w.WriteHeader(http.StatusOK)
45-
_, _ = w.Write([]byte(`{"main_extension":"1","sub_extensions":[],"user_name":"alice"}`))
51+
// The server returns entries with localpart user_name (no @domain)
52+
// When we have no homeserverHost, we should skip these
53+
_, _ = w.Write([]byte(`[{"main_extension":"1","sub_extensions":[],"user_name":"alice"}]`))
4654
}))
4755
defer ts.Close()
4856

4957
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 0)
50-
mapping, status, err := c.Validate(context.TODO(), "123", "secret", "")
58+
// Request extension 1 with no homeserver host configured
59+
// This should result in the entry being skipped, returning empty array and an error
60+
mappings, ok, err := c.Validate(context.TODO(), "1", "secret", "")
5161
require.Error(t, err)
52-
require.Nil(t, mapping)
53-
require.Equal(t, http.StatusOK, status)
62+
require.False(t, ok)
63+
require.Empty(t, mappings)
64+
}
65+
66+
func TestHTTPAuthClient_MatchingExtensionFromArray(t *testing.T) {
67+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
68+
w.Header().Set("Content-Type", "application/json")
69+
w.WriteHeader(http.StatusOK)
70+
_, _ = w.Write([]byte(`[
71+
{"main_extension":"201","sub_extensions":["91201","92201"],"user_name":"giacomo"},
72+
{"main_extension":"202","sub_extensions":["91202"],"user_name":"mario"}
73+
]`))
74+
}))
75+
defer ts.Close()
76+
77+
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 0)
78+
mappings, ok, err := c.Validate(context.TODO(), "202", "secret", "example.com")
79+
require.NoError(t, err)
80+
require.True(t, ok)
81+
require.Len(t, mappings, 2)
82+
83+
// Find the mario entry (202)
84+
var marioMapping *models.MappingRequest
85+
for _, m := range mappings {
86+
if m.Number == 202 {
87+
marioMapping = m
88+
break
89+
}
90+
}
91+
require.NotNil(t, marioMapping)
92+
require.Equal(t, 202, marioMapping.Number)
93+
require.Equal(t, "@mario:example.com", marioMapping.MatrixID)
94+
require.Equal(t, []int{91202}, marioMapping.SubNumbers)
95+
}
96+
97+
func TestHTTPAuthClient_ExtensionNotFound(t *testing.T) {
98+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
99+
w.Header().Set("Content-Type", "application/json")
100+
w.WriteHeader(http.StatusOK)
101+
_, _ = w.Write([]byte(`[
102+
{"main_extension":"201","sub_extensions":["91201","92201"],"user_name":"giacomo"},
103+
{"main_extension":"202","sub_extensions":["91202"],"user_name":"mario"}
104+
]`))
105+
}))
106+
defer ts.Close()
107+
108+
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 0)
109+
mappings, ok, err := c.Validate(context.TODO(), "999", "secret", "example.com")
110+
require.Error(t, err)
111+
require.False(t, ok)
112+
require.Empty(t, mappings)
113+
}
114+
115+
func TestHTTPAuthClient_CacheHitReturnsEmptyMappings(t *testing.T) {
116+
callCount := 0
117+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118+
callCount++
119+
w.Header().Set("Content-Type", "application/json")
120+
w.WriteHeader(http.StatusOK)
121+
_, _ = w.Write([]byte(`[
122+
{"main_extension":"201","sub_extensions":["91201"],"user_name":"giacomo"},
123+
{"main_extension":"202","sub_extensions":["91202"],"user_name":"mario"}
124+
]`))
125+
}))
126+
defer ts.Close()
127+
128+
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 100*time.Millisecond)
129+
130+
// First call should make request and return all mappings
131+
mappings1, ok1, err1 := c.Validate(context.TODO(), "202", "secret", "example.com")
132+
require.NoError(t, err1)
133+
require.True(t, ok1)
134+
require.Len(t, mappings1, 2)
135+
require.Equal(t, 1, callCount)
136+
137+
// Second call should use cache and return empty mappings
138+
mappings2, ok2, err2 := c.Validate(context.TODO(), "202", "secret", "example.com")
139+
require.NoError(t, err2)
140+
require.True(t, ok2)
141+
require.Len(t, mappings2, 0) // Cache returns empty array
142+
require.Equal(t, 1, callCount) // No additional call
143+
}
144+
145+
func TestHTTPAuthClient_CacheFailedAuth(t *testing.T) {
146+
callCount := 0
147+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
148+
callCount++
149+
w.WriteHeader(http.StatusUnauthorized)
150+
_, _ = w.Write([]byte("unauthorized"))
151+
}))
152+
defer ts.Close()
153+
154+
c := NewHTTPAuthClient(ts.URL, 2*time.Second, 100*time.Millisecond)
155+
156+
// First call should make request and fail
157+
mappings1, ok1, err1 := c.Validate(context.TODO(), "999", "wrongsecret", "example.com")
158+
require.Error(t, err1)
159+
require.False(t, ok1)
160+
require.Empty(t, mappings1)
161+
require.Equal(t, 1, callCount)
162+
163+
// Second call should NOT use cache (failed auth not cached) and make new request
164+
mappings2, ok2, err2 := c.Validate(context.TODO(), "999", "wrongsecret", "example.com")
165+
require.Error(t, err2)
166+
require.False(t, ok2)
167+
require.Empty(t, mappings2)
168+
require.Equal(t, 2, callCount) // Additional call made
54169
}

0 commit comments

Comments
 (0)