Skip to content

Commit ae00574

Browse files
ThomasK33NanoBoom
authored andcommitted
fix(server): sync disconnect callbacks (#176)
1 parent 82a08da commit ae00574

File tree

7 files changed

+217
-31
lines changed

7 files changed

+217
-31
lines changed

lua/claudecode/server/init.lua

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@ local M = {}
1212
---@field server table|nil The TCP server instance
1313
---@field port number|nil The port server is running on
1414
---@field auth_token string|nil The authentication token for validating connections
15-
---@field clients table<string, WebSocketClient> A list of connected clients
1615
---@field handlers table Message handlers by method name
1716
---@field ping_timer table|nil Timer for sending pings
1817
M.state = {
1918
server = nil,
2019
port = nil,
2120
auth_token = nil,
22-
clients = {},
2321
handlers = {},
2422
ping_timer = nil,
2523
}
@@ -53,8 +51,6 @@ function M.start(config, auth_token)
5351
M._handle_message(client, message)
5452
end,
5553
on_connect = function(client)
56-
M.state.clients[client.id] = client
57-
5854
-- Log connection with auth status
5955
if M.state.auth_token then
6056
logger.debug("server", "Authenticated WebSocket client connected:", client.id)
@@ -71,7 +67,6 @@ function M.start(config, auth_token)
7167
end
7268
end,
7369
on_disconnect = function(client, code, reason)
74-
M.state.clients[client.id] = nil
7570
logger.debug(
7671
"server",
7772
"WebSocket client disconnected:",
@@ -124,8 +119,6 @@ function M.stop()
124119
M.state.server = nil
125120
M.state.port = nil
126121
M.state.auth_token = nil
127-
M.state.clients = {}
128-
129122
return true
130123
end
131124

@@ -213,8 +206,6 @@ end
213206
local module_instance_id = math.random(10000, 99999)
214207
logger.debug("server", "Server module loaded with instance ID:", module_instance_id)
215208

216-
-- Note: debug_deferred_table function removed as deferred_responses table is no longer used
217-
218209
function M._setup_deferred_response(deferred_info)
219210
local co = deferred_info.coroutine
220211

lua/claudecode/server/mock.lua

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ local tools = require("claudecode.tools.init")
1212
M.state = {
1313
server = nil,
1414
port = nil,
15-
clients = {},
1615
handlers = {},
1716
messages = {}, -- Store messages for testing
1817
}
@@ -74,7 +73,6 @@ function M.stop()
7473
-- Reset state
7574
M.state.server = nil
7675
M.state.port = nil
77-
M.state.clients = {}
7876
M.state.messages = {}
7977

8078
return true
@@ -101,29 +99,36 @@ end
10199
---@param client_id string A unique client identifier
102100
---@return table client The client object
103101
function M.add_client(client_id)
102+
assert(type(client_id) == "string", "Expected client_id to be a string")
104103
if not M.state.server then
105104
error("Server not running")
106105
end
106+
assert(type(M.state.server.clients) == "table", "Expected mock server.clients to be a table")
107107

108108
local client = {
109109
id = client_id,
110110
connected = true,
111111
messages = {},
112112
}
113113

114-
M.state.clients[client_id] = client
114+
M.state.server.clients[client_id] = client
115115
return client
116116
end
117117

118118
---Remove a client from the server
119119
---@param client_id string The client identifier
120120
---@return boolean success Whether removal was successful
121121
function M.remove_client(client_id)
122-
if not M.state.server or not M.state.clients[client_id] then
122+
assert(type(client_id) == "string", "Expected client_id to be a string")
123+
if not M.state.server or type(M.state.server.clients) ~= "table" then
123124
return false
124125
end
125126

126-
M.state.clients[client_id] = nil
127+
if not M.state.server.clients[client_id] then
128+
return false
129+
end
130+
131+
M.state.server.clients[client_id] = nil
127132
return true
128133
end
129134

@@ -136,7 +141,10 @@ function M.send(client, method, params)
136141
local client_obj
137142

138143
if type(client) == "string" then
139-
client_obj = M.state.clients[client]
144+
if not M.state.server or type(M.state.server.clients) ~= "table" then
145+
return false
146+
end
147+
client_obj = M.state.server.clients[client]
140148
else
141149
client_obj = client
142150
end
@@ -172,7 +180,10 @@ function M.send_response(client, id, result, error)
172180
local client_obj
173181

174182
if type(client) == "string" then
175-
client_obj = M.state.clients[client]
183+
if not M.state.server or type(M.state.server.clients) ~= "table" then
184+
return false
185+
end
186+
client_obj = M.state.server.clients[client]
176187
else
177188
client_obj = client
178189
end
@@ -208,9 +219,13 @@ end
208219
---@param params table The parameters to send
209220
---@return boolean success Whether broadcasting was successful
210221
function M.broadcast(method, params)
222+
if not M.state.server or type(M.state.server.clients) ~= "table" then
223+
return false
224+
end
225+
211226
local success = true
212227

213-
for client_id, _ in pairs(M.state.clients) do
228+
for client_id, _ in pairs(M.state.server.clients) do
214229
local send_success = M.send(client_id, method, params)
215230
success = success and send_success
216231
end
@@ -223,7 +238,12 @@ end
223238
---@param message table The message to process
224239
---@return table|nil response The response if any
225240
function M.simulate_message(client_id, message)
226-
local client = M.state.clients[client_id]
241+
assert(type(client_id) == "string", "Expected client_id to be a string")
242+
if not M.state.server or type(M.state.server.clients) ~= "table" then
243+
return nil
244+
end
245+
246+
local client = M.state.server.clients[client_id]
227247

228248
if not client then
229249
return nil
@@ -255,7 +275,11 @@ end
255275
function M.clear_messages()
256276
M.state.messages = {}
257277

258-
for _, client in pairs(M.state.clients) do
278+
if not M.state.server or type(M.state.server.clients) ~= "table" then
279+
return
280+
end
281+
282+
for _, client in pairs(M.state.server.clients) do
259283
client.messages = {}
260284
end
261285
end

lua/claudecode/server/tcp.lua

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,33 +124,68 @@ function M._handle_new_connection(server)
124124
-- Set up data handler
125125
client_tcp:read_start(function(err, data)
126126
if err then
127-
server.on_error("Client read error: " .. err)
128-
M._remove_client(server, client)
127+
local error_msg = "Client read error: " .. err
128+
server.on_error(error_msg)
129+
M._disconnect_client(server, client, 1006, error_msg)
129130
return
130131
end
131132

132133
if not data then
133134
-- EOF - client disconnected
134-
M._remove_client(server, client)
135+
M._disconnect_client(server, client, 1006, "EOF")
135136
return
136137
end
137138

138139
-- Process incoming data
139140
client_manager.process_data(client, data, function(cl, message)
140141
server.on_message(cl, message)
141142
end, function(cl, code, reason)
142-
server.on_disconnect(cl, code, reason)
143-
M._remove_client(server, cl)
143+
M._disconnect_client(server, cl, code, reason)
144144
end, function(cl, error_msg)
145145
server.on_error("Client " .. cl.id .. " error: " .. error_msg)
146-
M._remove_client(server, cl)
146+
M._disconnect_client(server, cl, 1006, "Client error: " .. error_msg)
147147
end, server.auth_token)
148148
end)
149149

150150
-- Notify about new connection
151151
server.on_connect(client)
152152
end
153153

154+
---Disconnect a client and remove it from the server.
155+
---This ensures `server.on_disconnect` is invoked for every disconnect path
156+
---(EOF, read errors, protocol errors, timeouts), and only once per client.
157+
---@param server TCPServer The server object
158+
---@param client WebSocketClient The client to disconnect
159+
---@param code number|nil WebSocket close code
160+
---@param reason string|nil WebSocket close reason
161+
function M._disconnect_client(server, client, code, reason)
162+
assert(type(server) == "table", "Expected server to be a table")
163+
local on_disconnect_type = type(server.on_disconnect)
164+
local on_disconnect_mt = on_disconnect_type == "table" and getmetatable(server.on_disconnect) or nil
165+
assert(
166+
on_disconnect_type == "function" or (on_disconnect_mt ~= nil and type(on_disconnect_mt.__call) == "function"),
167+
"Expected server.on_disconnect to be callable"
168+
)
169+
assert(type(server.clients) == "table", "Expected server.clients to be a table")
170+
assert(type(client) == "table", "Expected client to be a table")
171+
assert(type(client.id) == "string", "Expected client.id to be a string")
172+
if code ~= nil then
173+
assert(type(code) == "number", "Expected code to be a number")
174+
end
175+
if reason ~= nil then
176+
assert(type(reason) == "string", "Expected reason to be a string")
177+
end
178+
179+
-- Idempotency: a client can hit multiple disconnect paths (e.g. CLOSE frame
180+
-- followed by a TCP EOF). Only notify/remove once.
181+
if not server.clients[client.id] then
182+
return
183+
end
184+
185+
server.on_disconnect(client, code, reason)
186+
M._remove_client(server, client)
187+
end
188+
154189
---Remove a client from the server
155190
---@param server TCPServer The server object
156191
---@param client WebSocketClient The client to remove
@@ -293,7 +328,7 @@ function M.start_ping_timer(server, interval)
293328
string.format("Client %s keepalive timeout (%ds idle), closing connection", client.id, time_since_pong)
294329
)
295330
client_manager.close_client(client, 1006, "Connection timeout")
296-
M._remove_client(server, client)
331+
M._disconnect_client(server, client, 1006, "Connection timeout")
297332
end
298333
end
299334
end

tests/mocks/vim.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ local vim = {
881881
return true
882882
end,
883883
read_start = function(self, callback)
884+
self._read_cb = callback
884885
return true
885886
end,
886887
write = function(self, data, callback)

tests/server_test.lua

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ describe("Server module", function()
226226
assert(type(server.state) == "table")
227227
assert(server.state.server == nil)
228228
assert(server.state.port == nil)
229-
assert(type(server.state.clients) == "table")
230229
assert(type(server.state.handlers) == "table")
231230
end)
232231

@@ -259,8 +258,11 @@ describe("Server module", function()
259258
assert(stop_success == true)
260259
assert(server.state.server == nil)
261260
assert(server.state.port == nil)
262-
assert(type(server.state.clients) == "table")
263-
assert(0 == #server.state.clients)
261+
262+
local status = server.get_status()
263+
assert(status.running == false)
264+
assert(status.port == nil)
265+
assert(status.client_count == 0)
264266
end)
265267

266268
it("should not stop the server if not running", function()

0 commit comments

Comments
 (0)