Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion containers/api-proxy/server.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ const { generateRequestId, sanitizeForLog, logRequest } = require('./logging');
const metrics = require('./metrics');
const rateLimiter = require('./rate-limiter');
let trackTokenUsage;
let trackWebSocketTokenUsage;
let closeLogStream;
try {
({ trackTokenUsage, closeLogStream } = require('./token-tracker'));
({ trackTokenUsage, trackWebSocketTokenUsage, closeLogStream } = require('./token-tracker'));
} catch (err) {
if (err && err.code === 'MODULE_NOT_FOUND') {
trackTokenUsage = () => {};
trackWebSocketTokenUsage = () => {};
closeLogStream = () => {};
} else {
throw err;
Expand Down Expand Up @@ -672,6 +674,15 @@ function proxyWebSocket(req, socket, head, targetHost, injectHeaders, provider,
tlsSocket.pipe(socket);
socket.pipe(tlsSocket);

// Attach WebSocket token usage tracking (non-blocking, sniffs upstream frames)
trackWebSocketTokenUsage(tlsSocket, {
requestId,
provider,
path: sanitizeForLog(req.url),
startTime,
metrics,
});

// Finalize once when either side closes; destroy the other side.
socket.once('close', () => { finalize(false); tlsSocket.destroy(); });
tlsSocket.once('close', () => { finalize(false); socket.destroy(); });
Expand Down
279 changes: 272 additions & 7 deletions containers/api-proxy/token-tracker.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,37 @@
// Token usage log file path (inside the mounted log volume)
const TOKEN_LOG_DIR = process.env.AWF_TOKEN_LOG_DIR || '/var/log/api-proxy';
const TOKEN_LOG_FILE = path.join(TOKEN_LOG_DIR, 'token-usage.jsonl');
// Human-readable diagnostics file, written next to the JSONL artifact.
const DIAG_LOG_FILE = path.join(TOKEN_LOG_DIR, 'token-diag.log');

// Append streams for the two log files above; lazily opened on first write
// and reset to null on close/error.
let logStream = null;
let diagStream = null;

// Log that the module loaded successfully
// (runs once at require() time; synchronous append, best-effort only).
try {
fs.mkdirSync(TOKEN_LOG_DIR, { recursive: true });
fs.writeFileSync(
DIAG_LOG_FILE,
`${new Date().toISOString()} TOKEN_TRACKER_LOADED dir=${TOKEN_LOG_DIR}\n`,
{ flag: 'a' }
);
} catch { /* best-effort — dir may not be writable yet */ }

/**
* Write a diagnostic line to the diagnostics log file.
* This file is captured in the artifact alongside token-usage.jsonl.
*/
/**
 * Append one timestamped diagnostic line to the diagnostics log file.
 * This file is captured in the artifact alongside token-usage.jsonl.
 * Best-effort: any filesystem failure is silently swallowed.
 *
 * @param {string} msg - Event tag (e.g. 'WS_TRACK_START')
 * @param {object} [data] - Optional structured payload, JSON-encoded inline
 */
function diag(msg, data) {
  try {
    if (!diagStream) {
      fs.mkdirSync(TOKEN_LOG_DIR, { recursive: true });
      diagStream = fs.createWriteStream(DIAG_LOG_FILE, { flags: 'a' });
      // Drop the broken stream on error so the next call retries creation.
      diagStream.on('error', () => { diagStream = null; });
    }
    const payload = data ? ` ${JSON.stringify(data)}` : '';
    diagStream.write(`${new Date().toISOString()} ${msg}${payload}\n`);
  } catch { /* best-effort */ }
}

/**
* Get or create the JSONL append stream for token usage logs.
Expand Down Expand Up @@ -253,6 +282,17 @@
function trackTokenUsage(proxyRes, opts) {
const { requestId, provider, path: reqPath, startTime, metrics: metricsRef } = opts;
const streaming = isStreamingResponse(proxyRes.headers);
const contentType = proxyRes.headers['content-type'] || '(none)';

logRequest('debug', 'token_track_start', {
request_id: requestId,
provider,
path: reqPath,
streaming,
content_type: contentType,
status: proxyRes.statusCode,
});
diag('HTTP_TRACK_START', { request_id: requestId, provider, path: reqPath, streaming, content_type: contentType, status: proxyRes.statusCode });

// Accumulate response body for usage extraction
const chunks = [];
Expand Down Expand Up @@ -302,7 +342,15 @@

proxyRes.on('end', () => {
// Only process successful responses (2xx)
if (proxyRes.statusCode < 200 || proxyRes.statusCode >= 300) return;
if (proxyRes.statusCode < 200 || proxyRes.statusCode >= 300) {
logRequest('debug', 'token_track_skip_status', {
request_id: requestId,
provider,
status: proxyRes.statusCode,
});
diag('HTTP_TRACK_SKIP_STATUS', { request_id: requestId, provider, status: proxyRes.statusCode });
return;
}

const duration = Date.now() - startTime;
let usage = null;
Expand Down Expand Up @@ -334,6 +382,18 @@
model = result.model;
}

logRequest('debug', 'token_track_end', {
request_id: requestId,
provider,
streaming,
total_bytes: totalBytes,
overflow,
has_usage: !!usage,
usage_keys: usage ? Object.keys(usage) : [],
model,
});
diag('HTTP_TRACK_END', { request_id: requestId, provider, streaming, total_bytes: totalBytes, overflow, has_usage: !!usage, usage_keys: usage ? Object.keys(usage) : [], model });

const normalized = normalizeUsage(usage);
if (!normalized) return;

Expand Down Expand Up @@ -377,30 +437,235 @@
});
}

/**
* Parse WebSocket frames from a buffer (server→client direction, unmasked).
*
* Returns an object with:
* - messages: Array of decoded text frame payloads (strings)
* - consumed: Number of bytes consumed from the buffer
*
* Only handles non-fragmented text frames (FIN=1, opcode=1).
* Other frame types (binary, ping, pong, close, continuation) are consumed
* but their payloads are not returned.
*
* @param {Buffer} buf - Buffer containing WebSocket frame data
* @returns {{ messages: string[], consumed: number }}
*/
/**
 * Parse WebSocket frames from a buffer (typically server→client direction).
 *
 * Returns an object with:
 *  - messages: Array of decoded text frame payloads (strings)
 *  - consumed: Number of bytes consumed from the buffer
 *
 * Only non-fragmented text frames (FIN=1, opcode=1) yield messages. Other
 * complete frame types (binary, ping, pong, close, continuation) are
 * consumed but their payloads are not returned. Masked frames (the
 * client→server direction per RFC 6455 §5.3) are unmasked before UTF-8
 * decoding. Incomplete trailing frames are left unconsumed so the caller
 * can retry once more data arrives.
 *
 * @param {Buffer} buf - Buffer containing WebSocket frame data
 * @returns {{ messages: string[], consumed: number }}
 */
function parseWebSocketFrames(buf) {
  const messages = [];
  let pos = 0;

  while (pos + 2 <= buf.length) {
    const firstByte = buf[pos];
    const secondByte = buf[pos + 1];
    const fin = (firstByte & 0x80) !== 0;
    const opcode = firstByte & 0x0F;
    const masked = (secondByte & 0x80) !== 0;
    let payloadLength = secondByte & 0x7F;
    let headerSize = 2;

    if (payloadLength === 126) {
      if (pos + 4 > buf.length) break; // need full 16-bit extended length
      payloadLength = buf.readUInt16BE(pos + 2);
      headerSize = 4;
    } else if (payloadLength === 127) {
      if (pos + 10 > buf.length) break; // need full 64-bit extended length
      payloadLength = Number(buf.readBigUInt64BE(pos + 2));
      headerSize = 10;
    }

    // Masked frames carry a 4-byte masking key after the length field.
    let maskingKey = null;
    if (masked) {
      if (pos + headerSize + 4 > buf.length) break; // key not fully buffered
      maskingKey = buf.slice(pos + headerSize, pos + headerSize + 4);
      headerSize += 4;
    }

    const payloadStart = pos + headerSize;
    const frameEnd = payloadStart + payloadLength;
    if (frameEnd > buf.length) break; // incomplete frame; wait for more data

    // Extract text frames (opcode 1) with FIN set
    if (opcode === 1 && fin) {
      let payload = buf.slice(payloadStart, frameEnd);
      if (maskingKey) {
        // Unmask per RFC 6455 §5.3: payload[i] XOR key[i % 4]
        const unmasked = Buffer.allocUnsafe(payloadLength);
        for (let i = 0; i < payloadLength; i++) {
          unmasked[i] = payload[i] ^ maskingKey[i % 4];
        }
        payload = unmasked;
      }
      messages.push(payload.toString('utf8'));
    }

    pos = frameEnd;
  }

  return { messages, consumed: pos };
}

/**
* Attach token usage tracking to a WebSocket upstream connection.
*
* Claude Code CLI uses WebSocket streaming to the Anthropic API. The
* api-proxy relays this as a raw socket pipe (tlsSocket ↔ clientSocket).
* This function adds a non-blocking 'data' listener on the upstream socket
* to parse WebSocket frames and extract token usage from JSON text messages.
*
* The upstream stream starts with an HTTP 101 response header, followed by
* WebSocket frames. This function skips the HTTP header before parsing frames.
*
* @param {import('tls').TLSSocket} upstreamSocket - Upstream TLS socket
* @param {object} opts
* @param {string} opts.requestId - Request ID for correlation
* @param {string} opts.provider - Provider name (anthropic, copilot, etc.)
* @param {string} opts.path - Request path
* @param {number} opts.startTime - Request start time (Date.now())
* @param {object} opts.metrics - Metrics module reference
*/
/**
 * Attach token usage tracking to a WebSocket upstream connection.
 *
 * Claude Code CLI uses WebSocket streaming to the Anthropic API. The
 * api-proxy relays this as a raw socket pipe (tlsSocket ↔ clientSocket).
 * This function adds a non-blocking 'data' listener on the upstream socket
 * to parse WebSocket frames and extract token usage from JSON text messages.
 *
 * The upstream stream starts with an HTTP 101 response header, followed by
 * WebSocket frames. The header is skipped before frame parsing, and its
 * bytes are tracked separately so `response_bytes` counts only post-header
 * WebSocket traffic (consistent with the HTTP path, which counts body only).
 *
 * @param {import('tls').TLSSocket} upstreamSocket - Upstream TLS socket
 * @param {object} opts
 * @param {string} opts.requestId - Request ID for correlation
 * @param {string} opts.provider - Provider name (anthropic, copilot, etc.)
 * @param {string} opts.path - Request path
 * @param {number} opts.startTime - Request start time (Date.now())
 * @param {object} opts.metrics - Metrics module reference
 */
function trackWebSocketTokenUsage(upstreamSocket, opts) {
  const { requestId, provider, path: reqPath, startTime, metrics: metricsRef } = opts;

  logRequest('debug', 'ws_token_track_start', {
    request_id: requestId,
    provider,
    path: reqPath,
  });
  diag('WS_TRACK_START', { request_id: requestId, provider, path: reqPath });

  let httpHeaderParsed = false;
  let buffer = Buffer.alloc(0);
  let totalBytes = 0;   // every upstream byte seen, including the 101 header
  let headerBytes = 0;  // bytes consumed by the HTTP upgrade response header
  let streamingUsage = {};
  let streamingModel = null;
  let finalized = false;
  let frameCount = 0;
  let textMessageCount = 0;

  // Max buffer to prevent unbounded memory growth (1 MB)
  const MAX_WS_BUFFER = 1 * 1024 * 1024;

  upstreamSocket.on('data', (chunk) => {
    totalBytes += chunk.length;
    buffer = Buffer.concat([buffer, chunk]);

    // Safety: drop buffer if it grows too large (malformed frames)
    if (buffer.length > MAX_WS_BUFFER) {
      buffer = Buffer.alloc(0);
      httpHeaderParsed = true; // skip header parsing
      return;
    }

    // Skip the HTTP 101 Switching Protocols response header
    if (!httpHeaderParsed) {
      const headerEnd = buffer.indexOf('\r\n\r\n');
      if (headerEnd === -1) return; // need more data for full header
      headerBytes = headerEnd + 4;
      buffer = buffer.slice(headerBytes);
      httpHeaderParsed = true;
    }

    // Parse any complete WebSocket frames
    const { messages, consumed } = parseWebSocketFrames(buffer);
    if (consumed > 0) {
      buffer = buffer.slice(consumed);
    }
    frameCount += messages.length;

    for (const text of messages) {
      textMessageCount++;
      const { usage, model } = extractUsageFromSseLine(text);
      if (model && !streamingModel) streamingModel = model;
      if (usage) {
        logRequest('debug', 'ws_token_usage_found', {
          request_id: requestId,
          provider,
          usage_keys: Object.keys(usage),
          model,
        });
        // Later usage payloads overwrite earlier values key-by-key.
        for (const [k, v] of Object.entries(usage)) {
          streamingUsage[k] = v;
        }
      }
    }
  });

  // Emit the final usage record exactly once, on 'close' or 'end'.
  function doFinalize() {
    if (finalized) return;
    finalized = true;

    // Exclude the upgrade header so the field matches HTTP-path semantics.
    const responseBytes = totalBytes - headerBytes;

    logRequest('debug', 'ws_token_track_end', {
      request_id: requestId,
      provider,
      total_bytes: totalBytes,
      frame_count: frameCount,
      text_message_count: textMessageCount,
      has_usage: Object.keys(streamingUsage).length > 0,
      usage_keys: Object.keys(streamingUsage),
      model: streamingModel,
    });
    diag('WS_TRACK_END', { request_id: requestId, provider, total_bytes: totalBytes, frame_count: frameCount, text_message_count: textMessageCount, has_usage: Object.keys(streamingUsage).length > 0, usage_keys: Object.keys(streamingUsage), model: streamingModel });

    if (Object.keys(streamingUsage).length === 0) return;

    const duration = Date.now() - startTime;
    const normalized = normalizeUsage(streamingUsage);
    if (!normalized) return;

    if (metricsRef) {
      metricsRef.increment('input_tokens_total', { provider }, normalized.input_tokens);
      metricsRef.increment('output_tokens_total', { provider }, normalized.output_tokens);
    }

    const record = {
      timestamp: new Date().toISOString(),
      request_id: requestId,
      provider,
      model: streamingModel || 'unknown',
      path: reqPath,
      // A WebSocket upgrade completes with HTTP 101, not 200; record the
      // real status so downstream analysis can distinguish transports.
      status: 101,
      streaming: true,
      input_tokens: normalized.input_tokens,
      output_tokens: normalized.output_tokens,
      cache_read_tokens: normalized.cache_read_tokens,
      cache_write_tokens: normalized.cache_write_tokens,
      duration_ms: duration,
      response_bytes: responseBytes,
    };

    writeTokenUsage(record);

    logRequest('info', 'token_usage', {
      request_id: requestId,
      provider,
      model: streamingModel || 'unknown',
      input_tokens: normalized.input_tokens,
      output_tokens: normalized.output_tokens,
      cache_read_tokens: normalized.cache_read_tokens,
      cache_write_tokens: normalized.cache_write_tokens,
      streaming: true,
      transport: 'websocket',
    });
  }

  upstreamSocket.on('close', doFinalize);
  upstreamSocket.on('end', doFinalize);
}

/**
* Close the log stream (for graceful shutdown).
* Returns a Promise that resolves once the stream has flushed.
*/
/**
 * Close the token-usage and diagnostics log streams (for graceful shutdown).
 * Returns a Promise that resolves only after every open stream has flushed
 * and been ended. Safe to call when neither stream was ever opened.
 *
 * @returns {Promise<void>}
 */
function closeLogStream() {
  return new Promise((resolve) => {
    let pending = 0;
    let settled = false;
    // Resolve exactly once, and only when no stream is still flushing.
    const check = () => {
      if (pending === 0 && !settled) {
        settled = true;
        resolve();
      }
    };
    if (logStream) {
      pending++;
      logStream.end(() => { logStream = null; pending--; check(); });
    }
    if (diagStream) {
      pending++;
      diagStream.end(() => { diagStream = null; pending--; check(); });
    }
    check(); // nothing open: resolve immediately
  });
}

module.exports = {
trackTokenUsage,
trackWebSocketTokenUsage,
closeLogStream,
// Exported for testing
extractUsageFromJson,
extractUsageFromSseLine,
parseSseDataLines,
parseWebSocketFrames,
normalizeUsage,
isStreamingResponse,
writeTokenUsage,
Expand Down
Loading
Loading