From 127dae3038bca0f19f78b9eb70d7fd93980aae65 Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Fri, 27 Feb 2026 07:52:01 +0100 Subject: [PATCH 1/2] RPC: Add action field to GetObject for error recovery Add a nullable `action` field to GetObject requests. When null, the request is a normal data transfer. When set to "revert", the handler restores both remoteObjects and localObjects to the pre-transfer state, and the requester reverts its own remoteObjects tracking to match. This fixes state desynchronization after failed deserialization: the handler stores the pre-transfer baseline in an actionBaseline map at transfer start, optimistically updates remoteObjects when streaming completes, and rolls both maps back on revert. The action field is extensible for future corrective actions (clear, reset, abort) without protocol changes. --- doc/adr/0009-getobject-action-field.md | 88 ++++++++++++++++++ doc/adr/README.md | 4 + .../java/org/openrewrite/rpc/RewriteRpc.java | 32 ++++++- .../openrewrite/rpc/request/GetObject.java | 91 ++++++++++++++++--- .../rewrite/src/rpc/request/get-object.ts | 27 +++++- .../rewrite/src/rpc/rewrite-rpc.ts | 36 ++++++-- .../rewrite/src/rewrite/rpc/server.py | 72 +++++++++++---- 7 files changed, 308 insertions(+), 42 deletions(-) create mode 100644 doc/adr/0009-getobject-action-field.md diff --git a/doc/adr/0009-getobject-action-field.md b/doc/adr/0009-getobject-action-field.md new file mode 100644 index 00000000000..fd141e850db --- /dev/null +++ b/doc/adr/0009-getobject-action-field.md @@ -0,0 +1,88 @@ +# 9. GetObject action field for reliable state synchronization + +Date: 2026-02-26 + +## Status + +Accepted (supersedes ADR 8) + +## Context + +The diff-based GetObject protocol tracks state on both sides: the sender maintains a `remoteObjects` map recording what the receiver last successfully received, and uses this as the baseline for computing diffs on subsequent transfers. This tracking is updated optimistically — the sender assumes the receiver consumed the data successfully once streaming completes. + +When the receiver fails mid-deserialization (e.g., `ClassCastException` from an invalid AST node), the two sides go out of sync: the sender thinks the receiver has version N, but the receiver discarded it. + +### The Print problem + +This manifests concretely with `Print`. After a composite recipe runs, Java computes diffs by printing both the `before` and `after` trees. For RPC-based languages (Python, JavaScript), printing works as follows: + +1. Java sends a Print RPC to the remote (Python) +2. Python's `handle_print` calls `get_object_from_java(tree_id)`, sending GetObject back to Java +3. Java's `GetObject.Handler` computes a diff against `remoteObjects[id]` (its belief of what Python has) and streams the result + +If a prior Visit failed in the *reverse* direction (Java requesting a modified tree from Python), the cleanup at `RewriteRpc.getObject()` only removes Java's **requester-side** `remoteObjects` entry. Java's **handler-side** `remoteObjects` (used by `GetObject.Handler` when Python requests from Java) may still reflect a state that Python no longer has. The subsequent Print-triggered GetObject computes a diff against the wrong baseline, producing corrupt data or errors. + +### Fundamental issue: unilateral state updates + +The root cause is that `remoteObjects` is updated unilaterally by the sender without confirmation from the receiver. If the receiver fails to deserialize, the sender has no way to learn this — the stale state persists and affects all subsequent operations in either direction. + +## Decision + +Add an `action` field to the GetObject request. This nullable string field allows the receiver to send corrective actions back to the handler. When null, the request is a normal data-transfer request. + +### The `revert` action + +When the receiver fails to deserialize a transferred object, it sends a GetObject request with `action: "revert"`. The handler: + +1. Restores `remoteObjects[id]` to the pre-transfer value (stored in an `actionBaseline` map at transfer start) +2. Restores `localObjects[id]` to the same pre-transfer value — this ensures the failed modification is discarded rather than retried with the same broken diff +3. Cancels any in-progress batch send for that ID + +This reverts both sides to a consistent, known-good state. The receiver also clears its own `remoteObjects[id]` tracking, so the next transfer starts fresh. + +### Optimistic updates with rollback + +Unlike a deferred-commit (ACK/NACK) approach, `remoteObjects` is updated optimistically when streaming completes — no extra round-trip is needed on the success path. The `actionBaseline` map stores the pre-transfer value so that `revert` can roll it back on the failure path. + +### Extensibility + +The `action` field is designed to support future corrective actions beyond `revert`: + +- `"clear"` / `"remove"` — tell the handler to drop all tracking for this ID (e.g., when the caller knows the object is no longer needed) +- `"abort"` — cancel an in-progress batched transfer mid-stream +- `"reset"` — force a full re-serialization + +### Protocol flow + +**Success path** (no extra round-trip): +1. Handler streams batches, optimistically updates `remoteObjects[id] = after` +2. Receiver processes batches, updates its own `remoteObjects` and `localObjects` +3. Done — no confirmation needed + +**Failure path** (one extra round-trip): +1. Handler streams batches, optimistically updates `remoteObjects[id] = after` +2. Receiver fails to deserialize +3. Receiver sends `GetObject(id, sourceFileType, action="revert")` +4. Handler restores `remoteObjects[id]` and `localObjects[id]` from `actionBaseline` +5. Handler returns empty list + +### Relationship to ADR 8 (`reset` flag) + +The `reset` flag from ADR 8 is removed. The `revert` action makes it unnecessary — instead of the receiver hinting "I lost sync" on its *next* request, it explicitly tells the handler to roll back immediately after failure. + +## Consequences + +**Positive:** +- No extra round-trip on the success path (unlike an ACK-based approach) +- On failure, reverts both `remoteObjects` and `localObjects` to a consistent pre-transfer state, preventing cascading errors +- Fixes the Print problem: the handler's `remoteObjects` is rolled back before any Print-triggered GetObject can observe stale state +- Extensible: the `action` field can carry future corrective actions without protocol changes +- Works for all GetObject consumers (Visit, Print, Generate) in both directions + +**Negative:** +- Handler must store pre-transfer baselines (`actionBaseline` map) for potential rollback — one extra object reference per active transfer +- Reverting `localObjects` means the handler discards its local modification on failure, which is a deliberate policy choice: if the receiver can't deserialize it, retrying would just fail again + +**Trade-offs:** +- The `actionBaseline` entries persist until overwritten by the next transfer for the same ID, rather than being cleaned up immediately on success. The memory cost is bounded by the number of active object IDs and is comparable to `remoteObjects` itself +- The inline-Visit optimization (bundling tree data with Visit request/response to eliminate GetObject round-trips) remains a complementary performance improvement that could be pursued independently diff --git a/doc/adr/README.md b/doc/adr/README.md index d6e5e0baae1..f1fe1294c97 100644 --- a/doc/adr/README.md +++ b/doc/adr/README.md @@ -4,3 +4,7 @@ * [2. Naming recipes](0002-recipe-naming.md) * [3. OSS contributor's guidelines](0003-oss-contributors.md) * [4. Library migration recipe conventions](0004-library-migration-conventions.md) +* [5. Parser and LST conventions](0005-parser-lst-conventions.md) +* [6. Recipe marketplace CSV format](0006-recipe-marketplace-csv-format.md) +* [7. JavaScript templating engine enhancements](0007-javascript-templating-enhancements.md) +* [9. GetObject action field for error recovery](0009-getobject-action-field.md) diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java index 7556d3bc85d..bbcb756da0f 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java @@ -466,13 +466,37 @@ public T getObject(String id, @Nullable String sourceFileType) { RpcReceiveQueue q = new RpcReceiveQueue( remoteRefs, - () -> send("GetObject", new GetObject(id, sourceFileType), GetObjectResponse.class), + () -> send("GetObject", new GetObject(id, sourceFileType, null), GetObjectResponse.class), sourceFileType, log.get() ); - Object remoteObject = q.receive(localObject, null); - if (q.take().getState() != END_OF_OBJECT) { - throw new IllegalStateException("Expected END_OF_OBJECT"); + Object before = remoteObjects.get(id); + Object remoteObject; + try { + remoteObject = q.receive(localObject, null); + if (q.take().getState() != END_OF_OBJECT) { + throw new IllegalStateException("Expected END_OF_OBJECT"); + } + } catch (Exception e) { + // Tell the handler to revert both remoteObjects and localObjects + // to the pre-transfer state + try { + send("GetObject", new GetObject(id, sourceFileType, "revert"), GetObjectResponse.class); + } catch (Exception revertError) { + PrintStream logFile = log.get(); + if (logFile != null) { + revertError.printStackTrace(logFile); + } + } + // Revert our tracking to match the handler's reverted state. + // The handler restored remoteObjects[id] to the pre-transfer + // value, so the requester must do the same to stay in sync. + if (before != null) { + remoteObjects.put(id, before); + } else { + remoteObjects.remove(id); + } + throw e; } //noinspection ConstantValue diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java index e8fda4192fe..726d1e6a4b3 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java @@ -23,15 +23,14 @@ import org.openrewrite.rpc.RpcSendQueue; import java.io.PrintStream; -import java.util.ArrayList; -import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static java.util.Collections.emptyList; import static org.openrewrite.rpc.RpcObjectData.State.DELETE; import static org.openrewrite.rpc.RpcObjectData.State.END_OF_OBJECT; @@ -42,6 +41,20 @@ public class GetObject implements RpcRequest { @Nullable String sourceFileType; + /** + * An action for the handler to perform instead of a normal data transfer. + * When null, this is a normal data-transfer request. + *

+ * Supported actions: + *

+ */ + @Nullable + String action; + @RequiredArgsConstructor public static class Handler extends JsonRpcMethod { private static final ExecutorService forkJoin = ForkJoinPool.commonPool(); @@ -59,32 +72,82 @@ public static class Handler extends JsonRpcMethod { private final AtomicReference log; private final Supplier traceGetObject; - private final Map>> inProgressGetRpcObjects = new ConcurrentHashMap<>(); + @RequiredArgsConstructor + private static class InProgressSend { + final BlockingQueue> queue; + final @Nullable Object before; + final AtomicBoolean cancelled; + } + + private final Map inProgressGetRpcObjects = new ConcurrentHashMap<>(); + + /** + * Stores the pre-transfer {@code remoteObjects} value for each in-flight + * or recently completed transfer. Used by the "revert" action to restore + * both {@code remoteObjects} and {@code localObjects} to the state before + * the transfer started. + */ + private final Map actionBaseline = new HashMap<>(); @Override protected List handle(GetObject request) throws Exception { + String action = request.getAction(); + if (action != null) { + if ("revert".equals(action)) { + String id = request.getId(); + InProgressSend stale = inProgressGetRpcObjects.remove(id); + if (stale != null) { + stale.cancelled.set(true); + } + if (actionBaseline.containsKey(id)) { + Object before = actionBaseline.remove(id); + if (before != null) { + remoteObjects.put(id, before); + localObjects.put(id, before); + } else { + remoteObjects.remove(id); + localObjects.remove(id); + } + } + } + return emptyList(); + } + Object after = localObjects.get(request.getId()); if (after == null) { + // Clean up any stale in-progress send for this ID + InProgressSend stale = inProgressGetRpcObjects.remove(request.getId()); + if (stale != null) { + stale.cancelled.set(true); + } + List deleted = new ArrayList<>(2); deleted.add(new RpcObjectData(DELETE, null, null, null, traceGetObject.get())); deleted.add(new RpcObjectData(END_OF_OBJECT, null, null, null, traceGetObject.get())); return deleted; } - BlockingQueue> q = inProgressGetRpcObjects.computeIfAbsent(request.getId(), id -> { + Object currentBefore = remoteObjects.get(request.getId()); + + InProgressSend inProgress = inProgressGetRpcObjects.computeIfAbsent(request.getId(), id -> { + // Save the pre-transfer baseline for potential revert + actionBaseline.put(id, currentBefore); + BlockingQueue> batch = new ArrayBlockingQueue<>(1); - Object before = remoteObjects.get(id); + AtomicBoolean cancelled = new AtomicBoolean(false); RpcSendQueue sendQueue = new RpcSendQueue(batchSize.get(), batch::put, localRefs, request.getSourceFileType(), traceGetObject.get()); forkJoin.submit(() -> { try { - sendQueue.send(after, before, null); + sendQueue.send(after, currentBefore, null); - // All the data has been sent, and the remote should have received - // the full tree, so update our understanding of the remote state - // of this tree. - remoteObjects.put(id, after); + // Optimistically update remoteObjects — the receiver is + // expected to send action="revert" if deserialization fails, + // which will roll this back. + if (!cancelled.get()) { + remoteObjects.put(id, after); + } } catch (Throwable t) { PrintStream logFile = log.get(); //noinspection ConstantValue @@ -97,10 +160,10 @@ protected List handle(GetObject request) throws Exception { } return 0; }); - return batch; + return new InProgressSend(batch, currentBefore, cancelled); }); - List batch = q.take(); + List batch = inProgress.queue.take(); if (batch.get(batch.size() - 1).getState() == END_OF_OBJECT) { inProgressGetRpcObjects.remove(request.getId()); } diff --git a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts index cb9ce154dfa..71f4d5ab235 100644 --- a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts +++ b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts @@ -20,7 +20,8 @@ import {extractSourcePath, withMetrics} from "./metrics"; export class GetObject { constructor(private readonly id: string, - private readonly sourceFileType?: string) { + private readonly sourceFileType?: string, + private readonly action?: string) { } static handle( @@ -33,6 +34,7 @@ export class GetObject { metricsCsv?: string, ): void { const pendingData = new Map(); + const actionBaseline = new Map(); connection.onRequest( new rpc.RequestType("GetObject"), @@ -41,6 +43,25 @@ export class GetObject { metricsCsv, (context) => async request => { const objId = request.id; + + // Handle actions from the receiver + if (request.action) { + if (request.action === 'revert') { + const before = actionBaseline.get(objId); + actionBaseline.delete(objId); + pendingData.delete(objId); + if (before !== undefined) { + remoteObjects.set(objId, before); + localObjects.set(objId, before); + } else { + remoteObjects.delete(objId); + localObjects.delete(objId); + } + } + context.target = ''; + return []; + } + if (!localObjects.has(objId)) { context.target = ''; return [ @@ -63,10 +84,14 @@ export class GetObject { const after = obj; const before = remoteObjects.get(objId); + // Save baseline for potential revert + actionBaseline.set(objId, before); + allData = await new RpcSendQueue(localRefs, request.sourceFileType, trace()) .generate(after, before); pendingData.set(objId, allData); + // Optimistic update — receiver sends action="revert" on failure remoteObjects.set(objId, after); } diff --git a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts index d62cc5ec7bf..3f390514ebf 100644 --- a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts +++ b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts @@ -133,12 +133,36 @@ export class RewriteRpc { ); }, this.logger, this.traceGetObject.receive); - const remoteObject = await q.receive

(localObject); - - const eof = (await q.take()); - if (eof.state !== RpcObjectState.END_OF_OBJECT) { - RpcObjectData.logTrace(eof, this.traceGetObject.receive, this.logger); - throw new Error(`Expected END_OF_OBJECT but got: ${eof.state}`); + const before = this.remoteObjects.get(id); + let remoteObject: P; + try { + remoteObject = await q.receive

(localObject); + + const eof = (await q.take()); + if (eof.state !== RpcObjectState.END_OF_OBJECT) { + RpcObjectData.logTrace(eof, this.traceGetObject.receive, this.logger); + throw new Error(`Expected END_OF_OBJECT but got: ${eof.state}`); + } + } catch (e) { + // Tell the handler to revert both remoteObjects and localObjects + // to the pre-transfer state + try { + await this.connection.sendRequest( + new rpc.RequestType("GetObject"), + new GetObject(id, sourceFileType, 'revert'), + ); + } catch { + // Best-effort revert + } + // Revert our tracking to match the handler's reverted state. + // The handler restored remoteObjects[id] to the pre-transfer + // value, so the requester must do the same to stay in sync. + if (before !== undefined) { + this.remoteObjects.set(id, before); + } else { + this.remoteObjects.delete(id); + } + throw e; } this.remoteObjects.set(id, remoteObject); diff --git a/rewrite-python/rewrite/src/rewrite/rpc/server.py b/rewrite-python/rewrite/src/rewrite/rpc/server.py index c151fc0db74..5844985012c 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/server.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/server.py @@ -56,6 +56,8 @@ remote_objects: Dict[str, Any] = {} # Remote refs - maps reference IDs to objects for cyclic graph handling remote_refs: Dict[int, Any] = {} +# Action baseline - stores pre-transfer remote_objects values for potential revert +_action_baseline: Dict[str, Any] = {} # Request ID counter for outgoing requests _request_id_counter = 0 @@ -150,6 +152,8 @@ def get_object_from_java(obj_id: str, source_file_type: Optional[str] = None) -> # Track whether we've received the complete object received_end = False + before = remote_objects.get(obj_id) + def pull_batch() -> List[Dict[str, Any]]: """Pull the next batch of RpcObjectData from Java. @@ -160,7 +164,7 @@ def pull_batch() -> List[Dict[str, Any]]: IMPORTANT: We filter out END_OF_OBJECT from the returned batch to prevent it from being accidentally consumed during nested operations (like receive_list expecting positions). Java's RewriteRpc.java explicitly consumes END_OF_OBJECT - after receive() completes (line 474), and we do the same by tracking received_end. + after receive() completes, and we do the same by tracking received_end. """ nonlocal received_end @@ -168,10 +172,11 @@ def pull_batch() -> List[Dict[str, Any]]: return [] # Request the next batch from Java - batch = send_request('GetObject', { + request = { 'id': obj_id, 'sourceFileType': source_file_type - }) + } + batch = send_request('GetObject', request) if not batch: received_end = True @@ -189,17 +194,32 @@ def pull_batch() -> List[Dict[str, Any]]: q = RpcReceiveQueue(remote_refs, source_file_type, pull_batch, trace=_trace_rpc) receiver = PythonRpcReceiver() - # Get the "before" state - our understanding of what Java had - # This is used to apply diffs from the GetObject response - before = remote_objects.get(obj_id) - # Receive and deserialize the object (applies diffs to before state) - obj = receiver.receive(before, q) - - # Verify we received the complete object (END_OF_OBJECT was in the final batch) - # This matches Java's RewriteRpc.java line 474-475 which explicitly checks for END_OF_OBJECT - if not received_end: - raise RuntimeError(f"Did not receive END_OF_OBJECT marker for object {obj_id}") + try: + obj = receiver.receive(before, q) + + # Verify we received the complete object (END_OF_OBJECT was in the final batch) + if not received_end: + raise RuntimeError(f"Did not receive END_OF_OBJECT marker for object {obj_id}") + except Exception: + # Tell the handler to revert both remoteObjects and localObjects + # to the pre-transfer state + try: + send_request('GetObject', { + 'id': obj_id, + 'sourceFileType': source_file_type, + 'action': 'revert' + }) + except Exception: + pass # Best-effort revert + # Revert our tracking to match the handler's reverted state. + # The handler restored remoteObjects[id] to the pre-transfer + # value, so the requester must do the same to stay in sync. + if before is not None: + remote_objects[obj_id] = before + else: + remote_objects.pop(obj_id, None) + raise if obj is not None: # Update our understanding of what Java has @@ -456,14 +476,28 @@ def handle_get_object(params: dict) -> List[dict]: This serializes an object for RPC transfer as RpcObjectData[]. Returns list of RpcObjectData objects that Java can deserialize. - After sending, we update remote_objects to track that the remote (Java) - now has this version of the object. This is essential for the diff-based - RPC protocol to work correctly. + The remoteObjects update is optimistic — on failure, the receiver + sends action="revert" to roll back both remote and local state. """ obj_id = params.get('id') source_file_type = params.get('sourceFileType') + action = params.get('action') + if obj_id is None: return [{'state': 'DELETE'}, {'state': 'END_OF_OBJECT'}] + + # Handle actions from the receiver + if action is not None: + if action == 'revert': + before = _action_baseline.pop(obj_id, None) + if before is not None: + remote_objects[obj_id] = before + local_objects[obj_id] = before + else: + remote_objects.pop(obj_id, None) + local_objects.pop(obj_id, None) + return [] + obj = local_objects.get(obj_id) logger.debug(f"handle_get_object: id={obj_id}, type={type(obj).__name__ if obj else 'None'}") @@ -479,10 +513,13 @@ def handle_get_object(params: dict) -> List[dict]: # Get the "before" state - what we previously sent to Java before = remote_objects.get(obj_id) + # Save baseline for potential revert + _action_baseline[obj_id] = before + q = RpcSendQueue(source_file_type) result = q.generate(obj, before) - # Update remote_objects to track that Java now has this version + # Optimistically update — receiver sends action="revert" on failure remote_objects[obj_id] = obj return result @@ -562,6 +599,7 @@ def handle_reset(params: dict) -> bool: local_objects.clear() remote_objects.clear() remote_refs.clear() + _action_baseline.clear() _prepared_recipes.clear() _execution_contexts.clear() _recipe_accumulators.clear() From 781297558941c67aa681b5c418bab2b6f324d369 Mon Sep 17 00:00:00 2001 From: Knut Wannheden Date: Thu, 5 Mar 2026 17:39:33 +0100 Subject: [PATCH 2/2] Back out ref registrations on failed RPC transfers When an RPC object transfer fails, refs registered in localRefs (sender) and remoteRefs (receiver) during the failed transfer were not cleaned up. On retry, the sender would send ref-ID-only (no payload) but the receiver would throw "reference to object not previously sent". Track newly registered refs during each transfer and remove them from the shared maps on failure. Also fix sender-side failure cleanup: remove from actionBaseline so a subsequent "revert" from the receiver doesn't restore the remoteObjects entry the sender already cleaned up. --- .../java/org/openrewrite/rpc/RewriteRpc.java | 4 ++++ .../org/openrewrite/rpc/RpcReceiveQueue.java | 9 ++++++++ .../org/openrewrite/rpc/RpcSendQueue.java | 9 ++++++++ .../openrewrite/rpc/request/GetObject.java | 22 +++++++++++++++++++ rewrite-javascript/rewrite/src/rpc/queue.ts | 12 ++++++++++ .../rewrite/src/rpc/request/get-object.ts | 19 ++++++++++++++-- .../rewrite/src/rpc/rewrite-rpc.ts | 4 ++++ .../rewrite/src/rewrite/rpc/receive_queue.py | 7 ++++++ .../rewrite/src/rewrite/rpc/server.py | 3 +++ 9 files changed, 87 insertions(+), 2 deletions(-) diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java index 01407b962a5..fdd5cfe3c29 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java @@ -498,6 +498,10 @@ public T getObject(String id, @Nullable String sourceFileType) { } else { remoteObjects.remove(id); } + // Back out refs registered during this failed receive + for (Integer refId : q.getNewRefIds()) { + remoteRefs.remove(refId); + } throw e; } diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java index 4ebc87316bc..3c5f6939c11 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java @@ -34,6 +34,7 @@ public class RpcReceiveQueue { private final Supplier> pull; private final @Nullable String sourceFileType; private final @Nullable PrintStream log; + private final List newRefIds = new ArrayList<>(); public RpcReceiveQueue(Map refs, Supplier> pull, @Nullable String sourceFileType, @Nullable PrintStream log) { @@ -44,6 +45,13 @@ public RpcReceiveQueue(Map refs, Supplier> this.pull = pull; } + /** + * @return the ref IDs that were newly registered during this receive. + */ + public List getNewRefIds() { + return newRefIds; + } + public RpcObjectData take() { if (batch.isEmpty()) { List data = pull.get(); @@ -123,6 +131,7 @@ public T receive(@Nullable T before, @Nullable UnaryOperator onChange) { // immutable updates because of its cyclic nature, the before instance will ultimately // be the same as the after instance below. refs.put(ref, before); + newRefIds.add(ref); } } // Intentional fall-through... diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java index f2f085027b5..6ee880b7294 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java @@ -33,6 +33,7 @@ public class RpcSendQueue { private final IdentityHashMap refs; private final @Nullable String sourceFileType; private final boolean trace; + private final List newRefObjects = new ArrayList<>(); private @Nullable Object before; @@ -46,6 +47,13 @@ public RpcSendQueue(int batchSize, ThrowingConsumer> drain, this.trace = trace; } + /** + * @return the objects that were newly registered as refs during this send. + */ + public List getNewRefObjects() { + return newRefObjects; + } + public void put(RpcObjectData rpcObjectData) { batch.add(rpcObjectData); if (batch.size() == batchSize) { @@ -176,6 +184,7 @@ private void add(Object after, @Nullable Runnable onChange) { } ref = refs.size() + 1; refs.put(afterVal, ref); + newRefObjects.add(afterVal); } RpcCodec afterCodec = RpcCodec.forInstance(afterVal, sourceFileType); put(new RpcObjectData(ADD, getValueType(afterVal), diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java index 0d2258b56a0..3eb2ca54f62 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java @@ -89,6 +89,13 @@ private static class InProgressSend { */ private final Map actionBaseline = new HashMap<>(); + /** + * Tracks objects newly registered as refs during each in-flight transfer. + * On revert or sender-side failure, these are removed from {@code localRefs}. + * On success (all data consumed), the entry is cleared. + */ + private final Map> pendingNewRefs = new ConcurrentHashMap<>(); + @Override protected List handle(GetObject request) throws Exception { String action = request.getAction(); @@ -142,6 +149,11 @@ protected List handle(GetObject request) throws Exception { try { sendQueue.send(after, currentBefore, null); + // Track newly registered refs for potential revert + if (!sendQueue.getNewRefObjects().isEmpty()) { + pendingNewRefs.put(id, sendQueue.getNewRefObjects()); + } + // Optimistically update remoteObjects — the receiver is // expected to send action="revert" if deserialization fails, // which will roll this back. @@ -153,6 +165,14 @@ protected List handle(GetObject request) throws Exception { // forces a full object sync (ADD) instead of a delta (CHANGE) // against the stale, partially-sent baseline. remoteObjects.remove(id); + // Remove the baseline so a subsequent "revert" from the + // receiver doesn't restore the entry we just removed. + actionBaseline.remove(id); + // Back out refs registered during this failed send + for (Object obj : sendQueue.getNewRefObjects()) { + localRefs.remove(obj); + } + pendingNewRefs.remove(id); PrintStream logFile = log.get(); //noinspection ConstantValue if (logFile != null) { @@ -170,6 +190,8 @@ protected List handle(GetObject request) throws Exception { List batch = inProgress.queue.take(); if (batch.get(batch.size() - 1).getState() == END_OF_OBJECT) { inProgressGetRpcObjects.remove(request.getId()); + // Transfer completed successfully — refs are committed + pendingNewRefs.remove(request.getId()); } return batch; diff --git a/rewrite-javascript/rewrite/src/rpc/queue.ts b/rewrite-javascript/rewrite/src/rpc/queue.ts index bc43d667ce2..db193dd84cf 100644 --- a/rewrite-javascript/rewrite/src/rpc/queue.ts +++ b/rewrite-javascript/rewrite/src/rpc/queue.ts @@ -112,6 +112,7 @@ export class RpcCodecs { export class RpcSendQueue { private q: RpcObjectData[] = []; + private readonly _newRefIds: number[] = []; private before?: any; @@ -120,6 +121,10 @@ export class RpcSendQueue { private readonly trace: boolean) { } + get newRefIds(): number[] { + return this._newRefIds; + } + async generate(after: any, before: any): Promise { await this.send(after, before); @@ -233,6 +238,7 @@ export class RpcSendQueue { return; } ref = this.refs.create(after); + this._newRefIds.push(ref); } let afterCodec = onChange ? undefined : RpcCodecs.forInstance(after, this.sourceFileType); this.put({ @@ -275,6 +281,7 @@ export class RpcSendQueue { export class RpcReceiveQueue { private batch: RpcObjectData[] = []; + private readonly _newRefIds: number[] = []; constructor(private readonly refs: Map, private readonly sourceFileType: string | undefined, @@ -283,6 +290,10 @@ export class RpcReceiveQueue { private readonly trace: boolean) { } + get newRefIds(): number[] { + return this._newRefIds; + } + async take(): Promise { if (this.batch.length === 0) { this.batch = await this.pull(); @@ -336,6 +347,7 @@ export class RpcReceiveQueue { // immutable updates because of its cyclic nature, the before instance will ultimately // be the same as the after instance below. this.refs.set(ref, before); + this._newRefIds.push(ref); } } // Intentional fall-through... diff --git a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts index 71f4d5ab235..400e8105fa9 100644 --- a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts +++ b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts @@ -35,6 +35,7 @@ export class GetObject { ): void { const pendingData = new Map(); const actionBaseline = new Map(); + const pendingNewRefs = new Map(); connection.onRequest( new rpc.RequestType("GetObject"), @@ -57,6 +58,13 @@ export class GetObject { remoteObjects.delete(objId); localObjects.delete(objId); } + const newRefs = pendingNewRefs.get(objId); + if (newRefs) { + for (const refId of newRefs) { + localRefs.deleteByRefId(refId); + } + pendingNewRefs.delete(objId); + } } context.target = ''; return []; @@ -87,10 +95,15 @@ export class GetObject { // Save baseline for potential revert actionBaseline.set(objId, before); - allData = await new RpcSendQueue(localRefs, request.sourceFileType, trace()) - .generate(after, before); + const sendQueue = new RpcSendQueue(localRefs, request.sourceFileType, trace()); + allData = await sendQueue.generate(after, before); pendingData.set(objId, allData); + // Track newly registered refs for potential revert + if (sendQueue.newRefIds.length > 0) { + pendingNewRefs.set(objId, sendQueue.newRefIds); + } + // Optimistic update — receiver sends action="revert" on failure remoteObjects.set(objId, after); } @@ -100,6 +113,8 @@ export class GetObject { // If we've sent all data, remove from pending if (allData.length === 0) { pendingData.delete(objId); + // Transfer completed successfully — refs are committed + pendingNewRefs.delete(objId); } return batch; diff --git a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts index 93e2d4361cc..65c72311e32 100644 --- a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts +++ b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts @@ -165,6 +165,10 @@ export class RewriteRpc { } else { this.remoteObjects.delete(id); } + // Back out refs registered during this failed receive + for (const refId of q.newRefIds) { + this.remoteRefs.delete(refId); + } throw e; } diff --git a/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py b/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py index 946aa8c9238..15529d02c10 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py @@ -72,6 +72,12 @@ def __init__( self._source_file_type = source_file_type self._pull = pull self._trace = trace + self._new_ref_ids: List[int] = [] + + @property + def new_ref_ids(self) -> List[int]: + """Ref IDs that were newly registered during this receive.""" + return self._new_ref_ids def take(self) -> RpcObjectData: """Take the next message from the queue, fetching more if needed.""" @@ -161,6 +167,7 @@ def receive( if ref is not None: # Store for future references (handles cyclic graphs) self._refs[ref] = before + self._new_ref_ids.append(ref) # Fall through to CHANGE for field-by-field deserialization return self._do_change(before, on_change, message, ref) diff --git a/rewrite-python/rewrite/src/rewrite/rpc/server.py b/rewrite-python/rewrite/src/rewrite/rpc/server.py index 6ffe6915595..2b550e7d491 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/server.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/server.py @@ -225,6 +225,9 @@ def pull_batch() -> List[Dict[str, Any]]: remote_objects[obj_id] = before else: remote_objects.pop(obj_id, None) + # Back out refs registered during this failed receive + for ref_id in q.new_ref_ids: + remote_refs.pop(ref_id, None) raise if obj is not None: