diff --git a/doc/adr/0009-getobject-action-field.md b/doc/adr/0009-getobject-action-field.md new file mode 100644 index 0000000000..fd141e850d --- /dev/null +++ b/doc/adr/0009-getobject-action-field.md @@ -0,0 +1,88 @@ +# 9. GetObject action field for reliable state synchronization + +Date: 2026-02-26 + +## Status + +Accepted (supersedes ADR 8) + +## Context + +The diff-based GetObject protocol tracks state on both sides: the sender maintains a `remoteObjects` map recording what the receiver last successfully received, and uses this as the baseline for computing diffs on subsequent transfers. This tracking is updated optimistically — the sender assumes the receiver consumed the data successfully once streaming completes. + +When the receiver fails mid-deserialization (e.g., `ClassCastException` from an invalid AST node), the two sides go out of sync: the sender thinks the receiver has version N, but the receiver discarded it. + +### The Print problem + +This manifests concretely with `Print`. After a composite recipe runs, Java computes diffs by printing both the `before` and `after` trees. For RPC-based languages (Python, JavaScript), printing works as follows: + +1. Java sends a Print RPC to the remote (Python) +2. Python's `handle_print` calls `get_object_from_java(tree_id)`, sending GetObject back to Java +3. Java's `GetObject.Handler` computes a diff against `remoteObjects[id]` (its belief of what Python has) and streams the result + +If a prior Visit failed in the *reverse* direction (Java requesting a modified tree from Python), the cleanup at `RewriteRpc.getObject()` only removes Java's **requester-side** `remoteObjects` entry. Java's **handler-side** `remoteObjects` (used by `GetObject.Handler` when Python requests from Java) may still reflect a state that Python no longer has. The subsequent Print-triggered GetObject computes a diff against the wrong baseline, producing corrupt data or errors. 
+ +### Fundamental issue: unilateral state updates + +The root cause is that `remoteObjects` is updated unilaterally by the sender without confirmation from the receiver. If the receiver fails to deserialize, the sender has no way to learn this — the stale state persists and affects all subsequent operations in either direction. + +## Decision + +Add an `action` field to the GetObject request. This nullable string field allows the receiver to send corrective actions back to the handler. When null, the request is a normal data-transfer request. + +### The `revert` action + +When the receiver fails to deserialize a transferred object, it sends a GetObject request with `action: "revert"`. The handler: + +1. Restores `remoteObjects[id]` to the pre-transfer value (stored in an `actionBaseline` map at transfer start) +2. Restores `localObjects[id]` to the same pre-transfer value — this ensures the failed modification is discarded rather than retried with the same broken diff +3. Cancels any in-progress batch send for that ID + +This reverts both sides to a consistent, known-good state. The receiver also clears its own `remoteObjects[id]` tracking, so the next transfer starts fresh. + +### Optimistic updates with rollback + +Unlike a deferred-commit (ACK/NACK) approach, `remoteObjects` is updated optimistically when streaming completes — no extra round-trip is needed on the success path. The `actionBaseline` map stores the pre-transfer value so that `revert` can roll it back on the failure path. + +### Extensibility + +The `action` field is designed to support future corrective actions beyond `revert`: + +- `"clear"` / `"remove"` — tell the handler to drop all tracking for this ID (e.g., when the caller knows the object is no longer needed) +- `"abort"` — cancel an in-progress batched transfer mid-stream +- `"reset"` — force a full re-serialization + +### Protocol flow + +**Success path** (no extra round-trip): +1. 
Handler streams batches, optimistically updates `remoteObjects[id] = after` +2. Receiver processes batches, updates its own `remoteObjects` and `localObjects` +3. Done — no confirmation needed + +**Failure path** (one extra round-trip): +1. Handler streams batches, optimistically updates `remoteObjects[id] = after` +2. Receiver fails to deserialize +3. Receiver sends `GetObject(id, sourceFileType, action="revert")` +4. Handler restores `remoteObjects[id]` and `localObjects[id]` from `actionBaseline` +5. Handler returns empty list + +### Relationship to ADR 8 (`reset` flag) + +The `reset` flag from ADR 8 is removed. The `revert` action makes it unnecessary — instead of the receiver hinting "I lost sync" on its *next* request, it explicitly tells the handler to roll back immediately after failure. + +## Consequences + +**Positive:** +- No extra round-trip on the success path (unlike an ACK-based approach) +- On failure, reverts both `remoteObjects` and `localObjects` to a consistent pre-transfer state, preventing cascading errors +- Fixes the Print problem: the handler's `remoteObjects` is rolled back before any Print-triggered GetObject can observe stale state +- Extensible: the `action` field can carry future corrective actions without protocol changes +- Works for all GetObject consumers (Visit, Print, Generate) in both directions + +**Negative:** +- Handler must store pre-transfer baselines (`actionBaseline` map) for potential rollback — one extra object reference per active transfer +- Reverting `localObjects` means the handler discards its local modification on failure, which is a deliberate policy choice: if the receiver can't deserialize it, retrying would just fail again + +**Trade-offs:** +- The `actionBaseline` entries persist until overwritten by the next transfer for the same ID, rather than being cleaned up immediately on success. 
The memory cost is bounded by the number of active object IDs and is comparable to `remoteObjects` itself +- The inline-Visit optimization (bundling tree data with Visit request/response to eliminate GetObject round-trips) remains a complementary performance improvement that could be pursued independently diff --git a/doc/adr/README.md b/doc/adr/README.md index d6e5e0baae..f1fe1294c9 100644 --- a/doc/adr/README.md +++ b/doc/adr/README.md @@ -4,3 +4,7 @@ * [2. Naming recipes](0002-recipe-naming.md) * [3. OSS contributor's guidelines](0003-oss-contributors.md) * [4. Library migration recipe conventions](0004-library-migration-conventions.md) +* [5. Parser and LST conventions](0005-parser-lst-conventions.md) +* [6. Recipe marketplace CSV format](0006-recipe-marketplace-csv-format.md) +* [7. JavaScript templating engine enhancements](0007-javascript-templating-enhancements.md) +* [9. GetObject action field for error recovery](0009-getobject-action-field.md) diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java index 94d3b94701..fdd5cfe3c2 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RewriteRpc.java @@ -469,22 +469,41 @@ public T getObject(String id, @Nullable String sourceFileType) { RpcReceiveQueue q = new RpcReceiveQueue( remoteRefs, - () -> send("GetObject", new GetObject(id, sourceFileType), GetObjectResponse.class), + () -> send("GetObject", new GetObject(id, sourceFileType, null), GetObjectResponse.class), sourceFileType, log.get() ); Object remoteObject; try { remoteObject = q.receive(before, null); + if (q.take().getState() != END_OF_OBJECT) { + throw new IllegalStateException("Expected END_OF_OBJECT"); + } } catch (Exception e) { - // Reset our tracking of the remote state so the next interaction - // forces a full object sync (ADD) instead of a delta (CHANGE). 
- remoteObjects.remove(id); + // Tell the handler to revert both remoteObjects and localObjects + // to the pre-transfer state + try { + send("GetObject", new GetObject(id, sourceFileType, "revert"), GetObjectResponse.class); + } catch (Exception revertError) { + PrintStream logFile = log.get(); + if (logFile != null) { + revertError.printStackTrace(logFile); + } + } + // Revert our tracking to match the handler's reverted state. + // The handler restored remoteObjects[id] to the pre-transfer + // value, so the requester must do the same to stay in sync. + if (before != null) { + remoteObjects.put(id, before); + } else { + remoteObjects.remove(id); + } + // Back out refs registered during this failed receive + for (Integer refId : q.getNewRefIds()) { + remoteRefs.remove(refId); + } throw e; } - if (q.take().getState() != END_OF_OBJECT) { - throw new IllegalStateException("Expected END_OF_OBJECT"); - } //noinspection ConstantValue if (remoteObject != null) { diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java index 4ebc87316b..3c5f6939c1 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcReceiveQueue.java @@ -34,6 +34,7 @@ public class RpcReceiveQueue { private final Supplier> pull; private final @Nullable String sourceFileType; private final @Nullable PrintStream log; + private final List newRefIds = new ArrayList<>(); public RpcReceiveQueue(Map refs, Supplier> pull, @Nullable String sourceFileType, @Nullable PrintStream log) { @@ -44,6 +45,13 @@ public RpcReceiveQueue(Map refs, Supplier> this.pull = pull; } + /** + * @return the ref IDs that were newly registered during this receive. 
+ */ + public List getNewRefIds() { + return newRefIds; + } + public RpcObjectData take() { if (batch.isEmpty()) { List data = pull.get(); @@ -123,6 +131,7 @@ public T receive(@Nullable T before, @Nullable UnaryOperator onChange) { // immutable updates because of its cyclic nature, the before instance will ultimately // be the same as the after instance below. refs.put(ref, before); + newRefIds.add(ref); } } // Intentional fall-through... diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java index f2f085027b..6ee880b729 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/RpcSendQueue.java @@ -33,6 +33,7 @@ public class RpcSendQueue { private final IdentityHashMap refs; private final @Nullable String sourceFileType; private final boolean trace; + private final List newRefObjects = new ArrayList<>(); private @Nullable Object before; @@ -46,6 +47,13 @@ public RpcSendQueue(int batchSize, ThrowingConsumer> drain, this.trace = trace; } + /** + * @return the objects that were newly registered as refs during this send. 
+ */ + public List getNewRefObjects() { + return newRefObjects; + } + public void put(RpcObjectData rpcObjectData) { batch.add(rpcObjectData); if (batch.size() == batchSize) { @@ -176,6 +184,7 @@ private void add(Object after, @Nullable Runnable onChange) { } ref = refs.size() + 1; refs.put(afterVal, ref); + newRefObjects.add(afterVal); } RpcCodec afterCodec = RpcCodec.forInstance(afterVal, sourceFileType); put(new RpcObjectData(ADD, getValueType(afterVal), diff --git a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java index cefcf84176..3eb2ca54f6 100644 --- a/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java +++ b/rewrite-core/src/main/java/org/openrewrite/rpc/request/GetObject.java @@ -23,15 +23,14 @@ import org.openrewrite.rpc.RpcSendQueue; import java.io.PrintStream; -import java.util.ArrayList; -import java.util.IdentityHashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import static java.util.Collections.emptyList; import static org.openrewrite.rpc.RpcObjectData.State.DELETE; import static org.openrewrite.rpc.RpcObjectData.State.END_OF_OBJECT; @@ -42,6 +41,20 @@ public class GetObject implements RpcRequest { @Nullable String sourceFileType; + /** + * An action for the handler to perform instead of a normal data transfer. + * When null, this is a normal data-transfer request. + *
<p>
+ * Supported actions:
+ * <ul>
+ *   <li>"revert" — sent by the receiver after a deserialization failure. + * The handler reverts both {@code remoteObjects} and {@code localObjects} + * for this ID to the pre-transfer state.</li>
+ * </ul>
+ */ + @Nullable + String action; + @RequiredArgsConstructor public static class Handler extends JsonRpcMethod { private static final ExecutorService forkJoin = ForkJoinPool.commonPool(); @@ -59,37 +72,107 @@ public static class Handler extends JsonRpcMethod { private final AtomicReference log; private final Supplier traceGetObject; - private final Map>> inProgressGetRpcObjects = new ConcurrentHashMap<>(); + @RequiredArgsConstructor + private static class InProgressSend { + final BlockingQueue> queue; + final @Nullable Object before; + final AtomicBoolean cancelled; + } + + private final Map inProgressGetRpcObjects = new ConcurrentHashMap<>(); + + /** + * Stores the pre-transfer {@code remoteObjects} value for each in-flight + * or recently completed transfer. Used by the "revert" action to restore + * both {@code remoteObjects} and {@code localObjects} to the state before + * the transfer started. + */ + private final Map actionBaseline = new HashMap<>(); + + /** + * Tracks objects newly registered as refs during each in-flight transfer. + * On revert or sender-side failure, these are removed from {@code localRefs}. + * On success (all data consumed), the entry is cleared. 
+ */ + private final Map> pendingNewRefs = new ConcurrentHashMap<>(); @Override protected List handle(GetObject request) throws Exception { + String action = request.getAction(); + if (action != null) { + if ("revert".equals(action)) { + String id = request.getId(); + InProgressSend stale = inProgressGetRpcObjects.remove(id); + if (stale != null) { + stale.cancelled.set(true); + } + if (actionBaseline.containsKey(id)) { + Object before = actionBaseline.remove(id); + if (before != null) { + remoteObjects.put(id, before); + localObjects.put(id, before); + } else { + remoteObjects.remove(id); + localObjects.remove(id); + } + } + } + return emptyList(); + } + Object after = localObjects.get(request.getId()); if (after == null) { + // Clean up any stale in-progress send for this ID + InProgressSend stale = inProgressGetRpcObjects.remove(request.getId()); + if (stale != null) { + stale.cancelled.set(true); + } + List deleted = new ArrayList<>(2); deleted.add(new RpcObjectData(DELETE, null, null, null, traceGetObject.get())); deleted.add(new RpcObjectData(END_OF_OBJECT, null, null, null, traceGetObject.get())); return deleted; } - BlockingQueue> q = inProgressGetRpcObjects.computeIfAbsent(request.getId(), id -> { + Object currentBefore = remoteObjects.get(request.getId()); + + InProgressSend inProgress = inProgressGetRpcObjects.computeIfAbsent(request.getId(), id -> { + // Save the pre-transfer baseline for potential revert + actionBaseline.put(id, currentBefore); + BlockingQueue> batch = new ArrayBlockingQueue<>(1); - Object before = remoteObjects.get(id); + AtomicBoolean cancelled = new AtomicBoolean(false); RpcSendQueue sendQueue = new RpcSendQueue(batchSize.get(), batch::put, localRefs, request.getSourceFileType(), traceGetObject.get()); forkJoin.submit(() -> { try { - sendQueue.send(after, before, null); + sendQueue.send(after, currentBefore, null); + + // Track newly registered refs for potential revert + if (!sendQueue.getNewRefObjects().isEmpty()) { + 
pendingNewRefs.put(id, sendQueue.getNewRefObjects()); + } - // All the data has been sent, and the remote should have received - // the full tree, so update our understanding of the remote state - // of this tree. - remoteObjects.put(id, after); + // Optimistically update remoteObjects — the receiver is + // expected to send action="revert" if deserialization fails, + // which will roll this back. + if (!cancelled.get()) { + remoteObjects.put(id, after); + } } catch (Throwable t) { // Reset our tracking of the remote state so the next interaction // forces a full object sync (ADD) instead of a delta (CHANGE) // against the stale, partially-sent baseline. remoteObjects.remove(id); + // Remove the baseline so a subsequent "revert" from the + // receiver doesn't restore the entry we just removed. + actionBaseline.remove(id); + // Back out refs registered during this failed send + for (Object obj : sendQueue.getNewRefObjects()) { + localRefs.remove(obj); + } + pendingNewRefs.remove(id); PrintStream logFile = log.get(); //noinspection ConstantValue if (logFile != null) { @@ -101,12 +184,14 @@ protected List handle(GetObject request) throws Exception { } return 0; }); - return batch; + return new InProgressSend(batch, currentBefore, cancelled); }); - List batch = q.take(); + List batch = inProgress.queue.take(); if (batch.get(batch.size() - 1).getState() == END_OF_OBJECT) { inProgressGetRpcObjects.remove(request.getId()); + // Transfer completed successfully — refs are committed + pendingNewRefs.remove(request.getId()); } return batch; diff --git a/rewrite-javascript/rewrite/src/rpc/queue.ts b/rewrite-javascript/rewrite/src/rpc/queue.ts index bc43d667ce..db193dd84c 100644 --- a/rewrite-javascript/rewrite/src/rpc/queue.ts +++ b/rewrite-javascript/rewrite/src/rpc/queue.ts @@ -112,6 +112,7 @@ export class RpcCodecs { export class RpcSendQueue { private q: RpcObjectData[] = []; + private readonly _newRefIds: number[] = []; private before?: any; @@ -120,6 +121,10 @@ export 
class RpcSendQueue { private readonly trace: boolean) { } + get newRefIds(): number[] { + return this._newRefIds; + } + async generate(after: any, before: any): Promise { await this.send(after, before); @@ -233,6 +238,7 @@ export class RpcSendQueue { return; } ref = this.refs.create(after); + this._newRefIds.push(ref); } let afterCodec = onChange ? undefined : RpcCodecs.forInstance(after, this.sourceFileType); this.put({ @@ -275,6 +281,7 @@ export class RpcSendQueue { export class RpcReceiveQueue { private batch: RpcObjectData[] = []; + private readonly _newRefIds: number[] = []; constructor(private readonly refs: Map, private readonly sourceFileType: string | undefined, @@ -283,6 +290,10 @@ export class RpcReceiveQueue { private readonly trace: boolean) { } + get newRefIds(): number[] { + return this._newRefIds; + } + async take(): Promise { if (this.batch.length === 0) { this.batch = await this.pull(); @@ -336,6 +347,7 @@ export class RpcReceiveQueue { // immutable updates because of its cyclic nature, the before instance will ultimately // be the same as the after instance below. this.refs.set(ref, before); + this._newRefIds.push(ref); } } // Intentional fall-through... 
diff --git a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts index cb9ce154df..400e8105fa 100644 --- a/rewrite-javascript/rewrite/src/rpc/request/get-object.ts +++ b/rewrite-javascript/rewrite/src/rpc/request/get-object.ts @@ -20,7 +20,8 @@ import {extractSourcePath, withMetrics} from "./metrics"; export class GetObject { constructor(private readonly id: string, - private readonly sourceFileType?: string) { + private readonly sourceFileType?: string, + private readonly action?: string) { } static handle( @@ -33,6 +34,8 @@ export class GetObject { metricsCsv?: string, ): void { const pendingData = new Map(); + const actionBaseline = new Map(); + const pendingNewRefs = new Map(); connection.onRequest( new rpc.RequestType("GetObject"), @@ -41,6 +44,32 @@ export class GetObject { metricsCsv, (context) => async request => { const objId = request.id; + + // Handle actions from the receiver + if (request.action) { + if (request.action === 'revert') { + const before = actionBaseline.get(objId); + actionBaseline.delete(objId); + pendingData.delete(objId); + if (before !== undefined) { + remoteObjects.set(objId, before); + localObjects.set(objId, before); + } else { + remoteObjects.delete(objId); + localObjects.delete(objId); + } + const newRefs = pendingNewRefs.get(objId); + if (newRefs) { + for (const refId of newRefs) { + localRefs.deleteByRefId(refId); + } + pendingNewRefs.delete(objId); + } + } + context.target = ''; + return []; + } + if (!localObjects.has(objId)) { context.target = ''; return [ @@ -63,10 +92,19 @@ export class GetObject { const after = obj; const before = remoteObjects.get(objId); - allData = await new RpcSendQueue(localRefs, request.sourceFileType, trace()) - .generate(after, before); + // Save baseline for potential revert + actionBaseline.set(objId, before); + + const sendQueue = new RpcSendQueue(localRefs, request.sourceFileType, trace()); + allData = await 
sendQueue.generate(after, before); pendingData.set(objId, allData); + // Track newly registered refs for potential revert + if (sendQueue.newRefIds.length > 0) { + pendingNewRefs.set(objId, sendQueue.newRefIds); + } + + // Optimistic update — receiver sends action="revert" on failure remoteObjects.set(objId, after); } @@ -75,6 +113,8 @@ export class GetObject { // If we've sent all data, remove from pending if (allData.length === 0) { pendingData.delete(objId); + // Transfer completed successfully — refs are committed + pendingNewRefs.delete(objId); } return batch; diff --git a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts index 9f66799046..65c72311e3 100644 --- a/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts +++ b/rewrite-javascript/rewrite/src/rpc/rewrite-rpc.ts @@ -140,19 +140,38 @@ export class RewriteRpc { let remoteObject: P; try { remoteObject = await q.receive
<P>
(before as P); + + const eof = (await q.take()); + if (eof.state !== RpcObjectState.END_OF_OBJECT) { + RpcObjectData.logTrace(eof, this.traceGetObject.receive, this.logger); + throw new Error(`Expected END_OF_OBJECT but got: ${eof.state}`); + } } catch (e) { - // Reset our tracking of the remote state so the next interaction - // forces a full object sync (ADD) instead of a delta (CHANGE). - this.remoteObjects.delete(id); + // Tell the handler to revert both remoteObjects and localObjects + // to the pre-transfer state + try { + await this.connection.sendRequest( + new rpc.RequestType("GetObject"), + new GetObject(id, sourceFileType, 'revert'), + ); + } catch { + // Best-effort revert + } + // Revert our tracking to match the handler's reverted state. + // The handler restored remoteObjects[id] to the pre-transfer + // value, so the requester must do the same to stay in sync. + if (before !== undefined) { + this.remoteObjects.set(id, before); + } else { + this.remoteObjects.delete(id); + } + // Back out refs registered during this failed receive + for (const refId of q.newRefIds) { + this.remoteRefs.delete(refId); + } throw e; } - const eof = (await q.take()); - if (eof.state !== RpcObjectState.END_OF_OBJECT) { - RpcObjectData.logTrace(eof, this.traceGetObject.receive, this.logger); - throw new Error(`Expected END_OF_OBJECT but got: ${eof.state}`); - } - this.remoteObjects.set(id, remoteObject); this.localObjects.set(id, remoteObject); diff --git a/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py b/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py index 946aa8c923..15529d02c1 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/receive_queue.py @@ -72,6 +72,12 @@ def __init__( self._source_file_type = source_file_type self._pull = pull self._trace = trace + self._new_ref_ids: List[int] = [] + + @property + def new_ref_ids(self) -> List[int]: + """Ref IDs that were newly registered during this 
receive.""" + return self._new_ref_ids def take(self) -> RpcObjectData: """Take the next message from the queue, fetching more if needed.""" @@ -161,6 +167,7 @@ def receive( if ref is not None: # Store for future references (handles cyclic graphs) self._refs[ref] = before + self._new_ref_ids.append(ref) # Fall through to CHANGE for field-by-field deserialization return self._do_change(before, on_change, message, ref) diff --git a/rewrite-python/rewrite/src/rewrite/rpc/server.py b/rewrite-python/rewrite/src/rewrite/rpc/server.py index f291c89fb1..2b550e7d49 100644 --- a/rewrite-python/rewrite/src/rewrite/rpc/server.py +++ b/rewrite-python/rewrite/src/rewrite/rpc/server.py @@ -57,6 +57,8 @@ remote_objects: Dict[str, Any] = {} # Remote refs - maps reference IDs to objects for cyclic graph handling remote_refs: Dict[int, Any] = {} +# Action baseline - stores pre-transfer remote_objects values for potential revert +_action_baseline: Dict[str, Any] = {} # Request ID counter for outgoing requests _request_id_counter = 0 @@ -151,6 +153,8 @@ def get_object_from_java(obj_id: str, source_file_type: Optional[str] = None) -> # Track whether we've received the complete object received_end = False + before = remote_objects.get(obj_id) + def pull_batch() -> List[Dict[str, Any]]: """Pull the next batch of RpcObjectData from Java. @@ -161,7 +165,7 @@ def pull_batch() -> List[Dict[str, Any]]: IMPORTANT: We filter out END_OF_OBJECT from the returned batch to prevent it from being accidentally consumed during nested operations (like receive_list expecting positions). Java's RewriteRpc.java explicitly consumes END_OF_OBJECT - after receive() completes (line 474), and we do the same by tracking received_end. + after receive() completes, and we do the same by tracking received_end. 
""" nonlocal received_end @@ -169,10 +173,11 @@ def pull_batch() -> List[Dict[str, Any]]: return [] # Request the next batch from Java - batch = send_request('GetObject', { + request = { 'id': obj_id, 'sourceFileType': source_file_type - }) + } + batch = send_request('GetObject', request) if not batch: received_end = True @@ -190,10 +195,6 @@ def pull_batch() -> List[Dict[str, Any]]: q = RpcReceiveQueue(remote_refs, source_file_type, pull_batch, trace=_trace_rpc) receiver = PythonRpcReceiver() - # Get the "before" state - our understanding of what Java had - # This is used to apply diffs from the GetObject response - before = remote_objects.get(obj_id) - # Receive and deserialize the object (applies diffs to before state) try: obj = receiver.receive(before, q) @@ -207,9 +208,26 @@ def pull_batch() -> List[Dict[str, Any]]: if not received_end: raise RuntimeError(f"Did not receive END_OF_OBJECT marker for object {obj_id}") except Exception: - # Reset our tracking of the remote state so the next interaction - # forces a full object sync (ADD) instead of a delta (CHANGE). - remote_objects.pop(obj_id, None) + # Tell the handler to revert both remoteObjects and localObjects + # to the pre-transfer state + try: + send_request('GetObject', { + 'id': obj_id, + 'sourceFileType': source_file_type, + 'action': 'revert' + }) + except Exception: + pass # Best-effort revert + # Revert our tracking to match the handler's reverted state. + # The handler restored remoteObjects[id] to the pre-transfer + # value, so the requester must do the same to stay in sync. + if before is not None: + remote_objects[obj_id] = before + else: + remote_objects.pop(obj_id, None) + # Back out refs registered during this failed receive + for ref_id in q.new_ref_ids: + remote_refs.pop(ref_id, None) raise if obj is not None: @@ -466,14 +484,28 @@ def handle_get_object(params: dict) -> List[dict]: This serializes an object for RPC transfer as RpcObjectData[]. 
Returns list of RpcObjectData objects that Java can deserialize. - After sending, we update remote_objects to track that the remote (Java) - now has this version of the object. This is essential for the diff-based - RPC protocol to work correctly. + The remoteObjects update is optimistic — on failure, the receiver + sends action="revert" to roll back both remote and local state. """ obj_id = params.get('id') source_file_type = params.get('sourceFileType') + action = params.get('action') + if obj_id is None: return [{'state': 'DELETE'}, {'state': 'END_OF_OBJECT'}] + + # Handle actions from the receiver + if action is not None: + if action == 'revert': + before = _action_baseline.pop(obj_id, None) + if before is not None: + remote_objects[obj_id] = before + local_objects[obj_id] = before + else: + remote_objects.pop(obj_id, None) + local_objects.pop(obj_id, None) + return [] + obj = local_objects.get(obj_id) logger.debug(f"handle_get_object: id={obj_id}, type={type(obj).__name__ if obj else 'None'}") @@ -489,10 +521,13 @@ def handle_get_object(params: dict) -> List[dict]: # Get the "before" state - what we previously sent to Java before = remote_objects.get(obj_id) + # Save baseline for potential revert + _action_baseline[obj_id] = before + q = RpcSendQueue(source_file_type) result = q.generate(obj, before) - # Update remote_objects to track that Java now has this version + # Optimistically update — receiver sends action="revert" on failure remote_objects[obj_id] = obj return result @@ -572,6 +607,7 @@ def handle_reset(params: dict) -> bool: local_objects.clear() remote_objects.clear() remote_refs.clear() + _action_baseline.clear() _prepared_recipes.clear() _execution_contexts.clear() _recipe_accumulators.clear()
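The optimistic-update-with-rollback flow described in the ADR can be illustrated with a small in-memory model. This is only a sketch, not the code from this change: `SketchHandler`, `_NO_BASELINE` and the dictionary-shaped batches are hypothetical names introduced for illustration, and the real handlers stream diffs through `RpcSendQueue` rather than returning whole values. The sentinel is an assumption modelled on the Java handler's `actionBaseline.containsKey(id)` check, so that a `revert` for an ID with no recorded baseline leaves state untouched.

```python
from typing import Any, Dict, List, Optional

# Hypothetical sentinel: distinguishes "no baseline recorded" from a recorded
# baseline of None (None means the receiver held nothing before the transfer).
_NO_BASELINE = object()


class SketchHandler:
    """Simplified, hypothetical stand-in for the sending side of GetObject."""

    def __init__(self) -> None:
        self.local_objects: Dict[str, Any] = {}
        self.remote_objects: Dict[str, Any] = {}
        self.action_baseline: Dict[str, Any] = {}

    def handle(self, obj_id: str, action: Optional[str] = None) -> List[dict]:
        # A non-null action is a corrective request, not a data transfer.
        if action is not None:
            if action == "revert":
                self._revert(obj_id)
            return []

        after = self.local_objects.get(obj_id)
        if after is None:
            return [{"state": "DELETE"}, {"state": "END_OF_OBJECT"}]

        # Save the pre-transfer baseline, then update optimistically.
        before = self.remote_objects.get(obj_id)
        self.action_baseline[obj_id] = before
        self.remote_objects[obj_id] = after

        state = "ADD" if before is None else "CHANGE"
        return [{"state": state, "value": after}, {"state": "END_OF_OBJECT"}]

    def _revert(self, obj_id: str) -> None:
        before = self.action_baseline.pop(obj_id, _NO_BASELINE)
        if before is _NO_BASELINE:
            return  # no transfer recorded for this ID, nothing to roll back
        if before is None:
            # The receiver had nothing before the transfer: drop all tracking.
            self.remote_objects.pop(obj_id, None)
            self.local_objects.pop(obj_id, None)
        else:
            # Restore both maps to the last state the receiver accepted.
            self.remote_objects[obj_id] = before
            self.local_objects[obj_id] = before
```

In this model the success path costs nothing extra: `handle(obj_id)` alone leaves `remote_objects` pointing at the new value. Only a failing receiver issues the additional `handle(obj_id, action="revert")` call, which mirrors the one extra round-trip on the failure path.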