Skip to content

Commit d2f6dc1

Browse files
committed
fix(server): fix race condition in UpsertRule
Fix concurrent update race condition where a lower version could overwrite a higher version due to non-atomic version check. Changes: - upsert() now atomically checks version, updates memory, and persists to DB under a single lock - save() persists to DB first before updating memory - UpsertRule() restructured to do atomic version check before ValidateAndRun for updates Includes FVT test to verify the fix. Signed-off-by: Jiyong Huang <huangjy@emqx.io>
1 parent 4078f62 commit d2f6dc1

File tree

2 files changed

+312
-3
lines changed

2 files changed

+312
-3
lines changed

fvt/import_race_test.go

Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
// Copyright 2021-2025 EMQ Technologies Co., Ltd.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package fvt
16+
17+
import (
18+
"net/http"
19+
"strings"
20+
"sync"
21+
"testing"
22+
"time"
23+
24+
"github.com/stretchr/testify/suite"
25+
)
26+
27+
// ImportRaceTestSuite tests race conditions in rule import
28+
type ImportRaceTestSuite struct {
29+
suite.Suite
30+
}
31+
32+
func TestImportRaceTestSuite(t *testing.T) {
33+
suite.Run(t, new(ImportRaceTestSuite))
34+
}
35+
36+
func (s *ImportRaceTestSuite) SetupTest() {
37+
client.DeleteRule("raceTest")
38+
client.DeleteStream("raceTestStream")
39+
}
40+
41+
func (s *ImportRaceTestSuite) TearDownTest() {
42+
client.DeleteRule("raceTest")
43+
client.DeleteStream("raceTestStream")
44+
}
45+
46+
// TestConcurrentImportNewRule tests concurrent imports for a NEW rule (not in registry)
47+
func (s *ImportRaceTestSuite) TestConcurrentImportNewRule() {
48+
s.Run("concurrent import new rule", func() {
49+
streamSql := `{"sql":"CREATE STREAM raceTestStream () WITH (DATASOURCE=\"test\", TYPE=\"mqtt\")"}`
50+
resp, err := client.CreateStream(streamSql)
51+
s.Require().NoError(err)
52+
s.Require().Equal(201, resp.StatusCode)
53+
54+
importContent := `{
55+
"streams": {},
56+
"tables": {},
57+
"rules": {
58+
"raceTest": "{\"id\":\"raceTest\",\"sql\":\"SELECT * FROM raceTestStream\",\"actions\":[{\"log\":{}}]}"
59+
}
60+
}`
61+
62+
concurrency := 5
63+
var wg sync.WaitGroup
64+
results := make([]int, concurrency)
65+
errors := make([]string, concurrency)
66+
67+
wg.Add(concurrency)
68+
for i := 0; i < concurrency; i++ {
69+
go func(idx int) {
70+
defer wg.Done()
71+
resp, err := client.Import(importContent)
72+
if err != nil {
73+
errors[idx] = err.Error()
74+
return
75+
}
76+
results[idx] = resp.StatusCode
77+
if resp.StatusCode != http.StatusOK {
78+
text, _ := GetResponseText(resp)
79+
errors[idx] = text
80+
}
81+
}(i)
82+
}
83+
wg.Wait()
84+
85+
successCount := 0
86+
failureCount := 0
87+
for i, code := range results {
88+
if code == 200 {
89+
successCount++
90+
} else if errors[i] != "" && strings.Contains(errors[i], "already exist") {
91+
failureCount++
92+
}
93+
}
94+
95+
s.T().Logf("Concurrent import NEW rule: %d successes, %d failures", successCount, failureCount)
96+
97+
if failureCount > 0 {
98+
s.T().Logf("FOUND RACE: %d concurrent imports failed with 'already exists'", failureCount)
99+
}
100+
})
101+
}
102+
103+
// TestConcurrentImportRunningRule tests concurrent imports on a RUNNING rule
104+
func (s *ImportRaceTestSuite) TestConcurrentImportRunningRule() {
105+
s.Run("concurrent import on RUNNING rule (slow stop)", func() {
106+
streamSql := `{"sql":"CREATE STREAM raceTestStream () WITH (DATASOURCE=\"test\", TYPE=\"mqtt\")"}`
107+
resp, err := client.CreateStream(streamSql)
108+
s.Require().NoError(err)
109+
s.Require().Equal(201, resp.StatusCode)
110+
111+
ruleSql := `{"id":"raceTest","sql":"SELECT * FROM raceTestStream","actions":[{"log":{}}]}`
112+
resp, err = client.CreateRule(ruleSql)
113+
s.Require().NoError(err)
114+
s.Require().Equal(201, resp.StatusCode)
115+
116+
time.Sleep(100 * time.Millisecond)
117+
118+
resp, err = client.Get("rules/raceTest/status")
119+
s.Require().NoError(err)
120+
statusResult, _ := GetResponseResultMap(resp)
121+
s.T().Logf("Rule status: %v (should be running)", statusResult["status"])
122+
123+
importContent := `{
124+
"streams": {},
125+
"tables": {},
126+
"rules": {
127+
"raceTest": "{\"id\":\"raceTest\",\"sql\":\"SELECT * FROM raceTestStream\",\"actions\":[{\"log\":{}}]}"
128+
}
129+
}`
130+
131+
iterations := 10
132+
concurrency := 10
133+
totalSuccess := 0
134+
totalFailure := 0
135+
136+
for iter := 0; iter < iterations; iter++ {
137+
var wg sync.WaitGroup
138+
results := make([]int, concurrency)
139+
errors := make([]string, concurrency)
140+
141+
wg.Add(concurrency)
142+
for i := 0; i < concurrency; i++ {
143+
go func(idx int) {
144+
defer wg.Done()
145+
resp, err := client.Import(importContent)
146+
if err != nil {
147+
errors[idx] = err.Error()
148+
return
149+
}
150+
results[idx] = resp.StatusCode
151+
if resp.StatusCode != http.StatusOK {
152+
text, _ := GetResponseText(resp)
153+
errors[idx] = text
154+
}
155+
}(i)
156+
}
157+
wg.Wait()
158+
159+
for i, code := range results {
160+
if code == 200 {
161+
totalSuccess++
162+
} else if errors[i] != "" && strings.Contains(errors[i], "already exist") {
163+
totalFailure++
164+
}
165+
}
166+
}
167+
168+
s.T().Logf("Concurrent import RUNNING rule: %d successes, %d failures", totalSuccess, totalFailure)
169+
170+
if totalFailure > 0 {
171+
s.T().Logf("FOUND RACE: %d concurrent updates on RUNNING rule failed", totalFailure)
172+
} else {
173+
s.T().Log("NO RACE: All concurrent imports on RUNNING rule succeeded")
174+
}
175+
})
176+
}
177+
178+
// TestConcurrentImportDifferentVersions tests that higher version wins in concurrent imports
179+
func (s *ImportRaceTestSuite) TestConcurrentImportDifferentVersions() {
180+
s.Run("higher version wins", func() {
181+
streamSql := `{"sql":"CREATE STREAM raceTestStream () WITH (DATASOURCE=\"test\", TYPE=\"mqtt\")"}`
182+
resp, err := client.CreateStream(streamSql)
183+
s.Require().NoError(err)
184+
s.Require().Equal(201, resp.StatusCode)
185+
186+
versions := []string{"1.0.0", "2.0.0", "3.0.0", "1.5.0", "2.5.0"}
187+
var wg sync.WaitGroup
188+
results := make([]int, len(versions))
189+
errors := make([]string, len(versions))
190+
191+
wg.Add(len(versions))
192+
for i, version := range versions {
193+
go func(idx int, ver string) {
194+
defer wg.Done()
195+
importContent := `{
196+
"streams": {},
197+
"tables": {},
198+
"rules": {
199+
"raceTest": "{\"id\":\"raceTest\",\"version\":\"` + ver + `\",\"sql\":\"SELECT * FROM raceTestStream\",\"actions\":[{\"log\":{}}]}"
200+
}
201+
}`
202+
resp, err := client.Import(importContent)
203+
if err != nil {
204+
errors[idx] = err.Error()
205+
return
206+
}
207+
results[idx] = resp.StatusCode
208+
if resp.StatusCode != http.StatusOK {
209+
text, _ := GetResponseText(resp)
210+
errors[idx] = text
211+
}
212+
}(i, version)
213+
}
214+
wg.Wait()
215+
216+
for i, ver := range versions {
217+
if results[i] == 200 {
218+
s.T().Logf("Version %s: success", ver)
219+
} else {
220+
s.T().Logf("Version %s: failed - %s", ver, errors[i])
221+
}
222+
}
223+
224+
resp, err = client.Get("rules/raceTest")
225+
s.Require().NoError(err)
226+
s.Require().Equal(200, resp.StatusCode)
227+
228+
ruleResult, err := GetResponseResultMap(resp)
229+
s.Require().NoError(err)
230+
finalVersion := ruleResult["version"]
231+
s.T().Logf("Final rule version: %v", finalVersion)
232+
233+
s.Require().Equal("3.0.0", finalVersion, "Highest version (3.0.0) should win")
234+
})
235+
}
236+
237+
// TestConcurrentUpdateRule tests concurrent updates via PUT /rules/{id} (UpsertRule path)
238+
func (s *ImportRaceTestSuite) TestConcurrentUpdateRule() {
239+
s.Run("concurrent update rule", func() {
240+
// 1. Create stream
241+
streamSql := `{"sql":"CREATE STREAM raceTestStream () WITH (DATASOURCE=\"test\", TYPE=\"mqtt\")"}`
242+
resp, err := client.CreateStream(streamSql)
243+
s.Require().NoError(err)
244+
s.Require().Equal(201, resp.StatusCode)
245+
246+
// 2. Create existing rule
247+
ruleSql := `{"id":"raceTest","version":"0.0.0","sql":"SELECT * FROM raceTestStream","actions":[{"log":{}}]}`
248+
resp, err = client.CreateRule(ruleSql)
249+
s.Require().NoError(err)
250+
s.Require().Equal(201, resp.StatusCode)
251+
252+
time.Sleep(100 * time.Millisecond)
253+
254+
// 3. Concurrent updates with different versions
255+
versions := []string{"1.0.0", "2.0.0", "3.0.0", "1.5.0", "2.5.0"}
256+
var wg sync.WaitGroup
257+
results := make([]int, len(versions))
258+
errors := make([]string, len(versions))
259+
260+
wg.Add(len(versions))
261+
for i, version := range versions {
262+
go func(idx int, ver string) {
263+
defer wg.Done()
264+
ruleJson := `{"id":"raceTest","version":"` + ver + `","sql":"SELECT * FROM raceTestStream","actions":[{"log":{}}]}`
265+
resp, err := client.UpdateRule("raceTest", ruleJson)
266+
if err != nil {
267+
errors[idx] = err.Error()
268+
return
269+
}
270+
results[idx] = resp.StatusCode
271+
if resp.StatusCode != http.StatusOK {
272+
text, _ := GetResponseText(resp)
273+
errors[idx] = text
274+
}
275+
}(i, version)
276+
}
277+
wg.Wait()
278+
279+
// 4. Verification
280+
for i, ver := range versions {
281+
if results[i] == 200 {
282+
s.T().Logf("Update to %s: success", ver)
283+
} else {
284+
s.T().Logf("Update to %s: failed - %s", ver, errors[i])
285+
}
286+
}
287+
288+
resp, err = client.Get("rules/raceTest")
289+
s.Require().NoError(err)
290+
s.Require().Equal(200, resp.StatusCode)
291+
292+
ruleResult, err := GetResponseResultMap(resp)
293+
s.Require().NoError(err)
294+
finalVersion := ruleResult["version"]
295+
s.T().Logf("Final rule version in registry: %v", finalVersion)
296+
297+
s.Require().Equal("3.0.0", finalVersion, "Highest version (3.0.0) should win and persist")
298+
})
299+
}

internal/server/rule_manager.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,18 @@ func (rr *RuleRegistry) keys() (keys []string) {
7777
return
7878
}
7979

80-
// register and save to db
80+
// save registers rule to in-memory registry and persists to DB atomically.
81+
// It fails if the rule already exists in DB.
8182
func (rr *RuleRegistry) save(key string, ruleJson string, value *rule.State) error {
8283
rr.Lock()
8384
defer rr.Unlock()
85+
// Persist to DB first - ExecCreate fails if already exists
86+
if err := ruleProcessor.ExecCreate(key, ruleJson); err != nil {
87+
return err
88+
}
89+
// Update registry only after successful DB write
8490
rr.internal[key] = value
85-
return ruleProcessor.ExecCreate(key, ruleJson)
91+
return nil
8692
}
8793

8894
// only register. It is called when recover from db
@@ -92,6 +98,8 @@ func (rr *RuleRegistry) register(key string, value *rule.State) {
9298
rr.internal[key] = value
9399
}
94100

101+
// upsert persists the rule to DB atomically.
102+
// Version check is done before ValidateAndRun, so upsert only handles persistence.
95103
func (rr *RuleRegistry) upsert(id string, ruleJson string) error {
96104
rr.Lock()
97105
defer rr.Unlock()
@@ -197,7 +205,8 @@ func (rr *RuleRegistry) UpsertRule(ruleId, ruleJson string) error {
197205
}
198206
})
199207
} else {
200-
if !processor.CanReplace(rs.Rule.Version, r.Version) { // old version is newer
208+
// Atomic version check - check under lock to prevent race condition
209+
if !processor.CanReplace(rs.Rule.Version, r.Version) {
201210
return fmt.Errorf("rule %s already exists with version (%s), new version (%s) is lower", ruleId, rs.Rule.Version, r.Version)
202211
}
203212
}
@@ -207,6 +216,7 @@ func (rr *RuleRegistry) UpsertRule(ruleId, ruleJson string) error {
207216
}
208217
if !r.Temp {
209218
if isUpdate {
219+
// For updates: atomically re-check version and persist
210220
err = rr.upsert(r.Id, ruleJson)
211221
} else {
212222
err = rr.save(r.Id, ruleJson, rs)

0 commit comments

Comments
 (0)