Skip to content

Commit e83b273

Browse files
authored
fix(shell-api): implement own map/forEach for cursors MONGOSH-936 (#1062)
Because the driver expects `.map()` and `.forEach()` callbacks to work synchronously, but mongosh users can use pseudo-synchronous async functions, the driver cannot properly handle the functions we pass to it. In order to fix this, if we don’t jump through a lot of extra hoops, it’s easiest to just implement `.map()` and `.forEach()` ourselves.
1 parent bbeffc7 commit e83b273

File tree

11 files changed

+93
-113
lines changed

11 files changed

+93
-113
lines changed

packages/cli-repl/test/e2e.spec.ts

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,59 @@ describe('e2e', function() {
446446
});
447447
});
448448
});
449+
450+
describe('cursor transform operations', () => {
451+
beforeEach(async() => {
452+
await shell.executeLine(`use ${dbName}`);
453+
await shell.executeLine('for (let i = 0; i < 3; i++) db.coll.insertOne({i})');
454+
});
455+
456+
it('works with .map() with immediate .toArray() iteration', async() => {
457+
const result = await shell.executeLine(`const cs = db.coll.find().map((doc) => {
458+
print('mapped');
459+
return db.coll.find({_id:doc._id}).toArray()
460+
}); print('after'); cs.toArray()`);
461+
expect(result).to.include('after');
462+
expect(result).to.include('mapped');
463+
expect(result).to.include('i: 1');
464+
});
465+
466+
it('works with .map() with later .toArray() iteration', async() => {
467+
const before = await shell.executeLine(`const cs = db.coll.find().map((doc) => {
468+
print('mapped');
469+
return db.coll.find({_id:doc._id}).toArray()
470+
}); print('after');`);
471+
expect(before).to.include('after');
472+
expect(before).not.to.include('mapped');
473+
const result = await shell.executeLine('cs.toArray()');
474+
expect(result).to.include('mapped');
475+
expect(result).to.include('i: 1');
476+
});
477+
478+
it('works with .map() with implicit iteration', async() => {
479+
const before = await shell.executeLine(`const cs = db.coll.find().map((doc) => {
480+
print('mapped');
481+
return db.coll.findOne({_id:doc._id});
482+
}); print('after');`);
483+
expect(before).to.include('after');
484+
expect(before).not.to.include('mapped');
485+
const result = await shell.executeLine('cs');
486+
expect(result).to.include('mapped');
487+
expect(result).to.include('i: 1');
488+
});
489+
490+
it('works with .forEach() iteration', async() => {
491+
await shell.executeLine('out = [];');
492+
const before = await shell.executeLine(`db.coll.find().forEach((doc) => {
493+
print('enter forEach');
494+
out.push(db.coll.findOne({_id:doc._id}));
495+
print('leave forEach');
496+
}); print('after');`);
497+
expect(before).to.match(/(enter forEach\r?\nleave forEach\r?\n){3}after/);
498+
const result = await shell.executeLine('out[1]');
499+
expect(result).to.include('i: 1');
500+
});
501+
});
449502
});
450503

451504
describe('with --host', () => {

packages/java-shell/src/main/kotlin/com/mongodb/mongosh/service/BaseMongoIterableHelper.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@ internal abstract class BaseMongoIterableHelper<T : MongoIterable<*>>(val iterab
1515
abstract val converters: Map<String, (T, Any?) -> Either<T>>
1616
abstract val defaultConverter: (T, String, Any?) -> Either<T>
1717

18-
fun map(function: Value): BaseMongoIterableHelper<*> {
19-
return helper(iterable.map { v ->
20-
converter.toJava(function.execute(converter.toJs(v))).value
21-
}, converter)
22-
}
23-
2418
fun itcount(): Int {
2519
return iterable.count()
2620
}
@@ -202,4 +196,4 @@ internal fun helper(iterable: MongoIterable<out Any?>, converter: MongoShellConv
202196
is AggregateIterable -> AggregateIterableHelper(iterable, converter, Document(), null)
203197
else -> MongoIterableHelper(iterable, converter, Document())
204198
}
205-
}
199+
}

packages/java-shell/src/main/kotlin/com/mongodb/mongosh/service/Cursor.kt

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,6 @@ internal class Cursor(private var helper: BaseMongoIterableHelper<*>, private va
103103
return converter.toJs(helper.explain(verbosity))
104104
}
105105

106-
@HostAccess.Export
107-
override fun forEach(func: Value) {
108-
if (!func.canExecute()) {
109-
throw IllegalArgumentException("Expected one argument of type function. Got: $func")
110-
}
111-
getOrCreateIterator().forEach { v ->
112-
func.execute(converter.toJs(v))
113-
}
114-
}
115-
116106
@HostAccess.Export
117107
override fun hasNext(): Boolean = getOrCreateIterator().hasNext()
118108

@@ -148,16 +138,6 @@ internal class Cursor(private var helper: BaseMongoIterableHelper<*>, private va
148138
return this
149139
}
150140

151-
@HostAccess.Export
152-
override fun map(func: Value): Cursor {
153-
checkQueryNotExecuted()
154-
if (!func.canExecute()) {
155-
throw IllegalArgumentException("Expected one argument of type function. Got: $func")
156-
}
157-
helper = helper.map(func)
158-
return this
159-
}
160-
161141
@HostAccess.Export
162142
override fun max(v: Value): Cursor {
163143
checkQueryNotExecuted()

packages/java-shell/src/main/kotlin/com/mongodb/mongosh/service/ServiceProviderCursor.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@ interface ServiceProviderCursor {
1313
fun collation(v: Value): ServiceProviderCursor
1414
fun comment(v: String): ServiceProviderCursor
1515
fun count(): Long
16-
fun forEach(func: Value)
1716
fun hasNext(): Boolean
1817
fun hint(v: Value): ServiceProviderCursor
1918
fun isExhausted(): Boolean
2019
fun itcount(): Int
2120
fun limit(v: Int): ServiceProviderCursor
22-
fun map(func: Value): ServiceProviderCursor
2321
fun max(v: Value): ServiceProviderCursor
2422
fun maxTimeMS(v: Long): ServiceProviderCursor
2523
fun maxAwaitTimeMS(value: Int): ServiceProviderCursor

packages/shell-api/src/abstract-cursor.ts

Lines changed: 29 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -21,39 +21,15 @@ import { iterate, validateExplainableVerbosity, markAsExplainOutput } from './he
2121
export abstract class AbstractCursor<CursorType extends ServiceProviderAggregationCursor | ServiceProviderCursor> extends ShellApiWithMongoClass {
2222
_mongo: Mongo;
2323
_cursor: CursorType;
24+
_transform: ((doc: any) => any) | null;
2425

2526
_currentIterationResult: CursorIterationResult | null = null;
26-
_mapError: Error | null = null;
2727

2828
constructor(mongo: Mongo, cursor: CursorType) {
2929
super();
3030
this._mongo = mongo;
3131
this._cursor = cursor;
32-
}
33-
34-
// Wrap a function with checks before and after that verify whether a .map()
35-
// callback has resulted in an exception. Such an error would otherwise result
36-
// in an uncaught exception, bringing the whole process down.
37-
// The downside to this is that errors will not actually be visible until
38-
// the caller tries to interact with this cursor in a way that triggers
39-
// these checks. Since that is also the behavior for errors coming from the
40-
// database server, it makes sense to match that.
41-
// Ideally, this kind of code could be lifted into the driver (NODE-3231 and
42-
// NODE-3232 are the tickets for that).
43-
async _withCheckMapError<Ret>(fn: () => Ret): Promise<Ret> {
44-
if (this._mapError) {
45-
// If an error has already occurred, we don't want to call the function
46-
// at all.
47-
throw this._mapError;
48-
}
49-
// eslint-disable-next-line @typescript-eslint/await-thenable
50-
const ret = await fn();
51-
if (this._mapError) {
52-
// If an error occurred during the function, we don't want to forward its
53-
// results.
54-
throw this._mapError;
55-
}
56-
return ret;
32+
this._transform = null;
5733
}
5834

5935
/**
@@ -82,32 +58,27 @@ export abstract class AbstractCursor<CursorType extends ServiceProviderAggregati
8258
}
8359

8460
@returnsPromise
85-
async forEach(f: (doc: Document) => void): Promise<void> {
86-
// Work around https://jira.mongodb.org/browse/NODE-3231
87-
let exception;
88-
const wrapped = (doc: Document): boolean | undefined => {
89-
try {
90-
f(doc);
91-
return undefined;
92-
} catch (err) {
93-
exception = err;
94-
return false; // Stop iteration.
61+
async forEach(f: (doc: Document) => void | boolean | Promise<void> | Promise<boolean>): Promise<void> {
62+
// Do not use the driver method because it does not have Promise support.
63+
for await (const doc of this) {
64+
if ((await f(doc)) === false) {
65+
break;
9566
}
96-
};
97-
await this._cursor.forEach(wrapped);
98-
if (exception) {
99-
throw exception;
10067
}
10168
}
10269

10370
@returnsPromise
10471
async hasNext(): Promise<boolean> {
105-
return this._withCheckMapError(() => this._cursor.hasNext());
72+
return this._cursor.hasNext();
10673
}
10774

10875
@returnsPromise
10976
async tryNext(): Promise<Document | null> {
110-
return this._withCheckMapError(() => this._cursor.tryNext());
77+
let result = await this._cursor.tryNext();
78+
if (result !== null && this._transform !== null) {
79+
result = await this._transform(result);
80+
}
81+
return result;
11182
}
11283

11384
async* [Symbol.asyncIterator]() {
@@ -136,7 +107,11 @@ export abstract class AbstractCursor<CursorType extends ServiceProviderAggregati
136107

137108
@returnsPromise
138109
async toArray(): Promise<Document[]> {
139-
return this._withCheckMapError(() => this._cursor.toArray());
110+
const result = [];
111+
for await (const doc of this) {
112+
result.push(doc);
113+
}
114+
return result;
140115
}
141116

142117
@returnType('this')
@@ -146,20 +121,12 @@ export abstract class AbstractCursor<CursorType extends ServiceProviderAggregati
146121

147122
@returnType('this')
148123
map(f: (doc: Document) => Document): this {
149-
// Work around https://jira.mongodb.org/browse/NODE-3232
150-
const wrapped = (doc: Document): Document => {
151-
if (this._mapError) {
152-
// These errors should never become visible to the user.
153-
return { __errored: true };
154-
}
155-
try {
156-
return f(doc);
157-
} catch (err) {
158-
this._mapError = err;
159-
return { __errored: true };
160-
}
161-
};
162-
this._cursor.map(wrapped);
124+
if (this._transform === null) {
125+
this._transform = f;
126+
} else {
127+
const g = this._transform;
128+
this._transform = (doc: any) => f(g(doc));
129+
}
163130
return this;
164131
}
165132

@@ -171,7 +138,11 @@ export abstract class AbstractCursor<CursorType extends ServiceProviderAggregati
171138

172139
@returnsPromise
173140
async next(): Promise<Document | null> {
174-
return this._withCheckMapError(() => this._cursor.next());
141+
let result = await this._cursor.next();
142+
if (result !== null && this._transform !== null) {
143+
result = await this._transform(result);
144+
}
145+
return result;
175146
}
176147

177148
@returnType('this')

packages/shell-api/src/aggregation-cursor.spec.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,6 @@ describe('AggregationCursor', () => {
6565
it('pretty returns the same cursor', () => {
6666
expect(cursor.pretty()).to.equal(cursor);
6767
});
68-
69-
it('calls wrappee.map with arguments', () => {
70-
const arg = {};
71-
cursor.map(arg);
72-
expect(wrappee.map).to.have.callCount(1);
73-
});
7468
});
7569

7670
describe('Cursor Internals', () => {

packages/shell-api/src/collection.spec.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,16 @@ describe('Collection', () => {
180180
});
181181

182182
it('returns an AggregationCursor that wraps the service provider one', async() => {
183-
const toArrayResult = [];
184-
serviceProviderCursor.toArray.resolves(toArrayResult);
183+
const toArrayResult = [{ foo: 'bar' }];
184+
serviceProviderCursor.tryNext.onFirstCall().resolves({ foo: 'bar' });
185+
serviceProviderCursor.tryNext.onSecondCall().resolves(null);
185186
serviceProvider.aggregate.returns(serviceProviderCursor);
186187

187188
const cursor = await collection.aggregate([{
188189
$piplelineStage: {}
189190
}]);
190191

191-
expect(await (cursor as AggregationCursor).toArray()).to.equal(toArrayResult);
192+
expect(await (cursor as AggregationCursor).toArray()).to.deep.equal(toArrayResult);
192193
});
193194

194195
it('throws if serviceProvider.aggregate rejects', async() => {

packages/shell-api/src/cursor.spec.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,6 @@ describe('Cursor', () => {
7070
expect(cursor.pretty()).to.equal(cursor);
7171
});
7272

73-
it('calls wrappee.map with arguments', () => {
74-
const arg = {};
75-
cursor.map(arg);
76-
expect(wrappee.map).to.have.callCount(1);
77-
});
78-
7973
it('has the correct metadata', () => {
8074
expect(cursor.collation.serverVersions).to.deep.equal(['3.4.0', ServerVersions.latest]);
8175
});

packages/shell-api/src/database.spec.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,13 @@ describe('Database', () => {
297297
});
298298

299299
it('returns an AggregationCursor that wraps the service provider one', async() => {
300-
const toArrayResult = [];
301-
serviceProviderCursor.toArray.resolves(toArrayResult);
300+
const toArrayResult = [{ foo: 'bar' }];
301+
serviceProviderCursor.tryNext.onFirstCall().resolves({ foo: 'bar' });
302+
serviceProviderCursor.tryNext.onSecondCall().resolves(null);
302303
serviceProvider.aggregateDb.returns(serviceProviderCursor);
303304

304305
const cursor = await database.aggregate([{ $piplelineStage: {} }]);
305-
expect(await cursor.toArray()).to.equal(toArrayResult);
306+
expect(await cursor.toArray()).to.deep.equal(toArrayResult);
306307
});
307308

308309
it('throws if serviceProvider.aggregateDb rejects', async() => {

packages/shell-api/src/explainable-cursor.spec.ts

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,6 @@ describe('ExplainableCursor', () => {
5555
expect(eCursor.map()).to.equal(eCursor);
5656
});
5757

58-
it('calls wrappee.map with arguments', () => {
59-
const arg = () => {};
60-
eCursor.map(arg);
61-
expect(wrappee.map).to.have.callCount(1);
62-
});
63-
6458
it('has the correct metadata', () => {
6559
expect(eCursor.collation.serverVersions).to.deep.equal(['3.4.0', ServerVersions.latest]);
6660
});

0 commit comments

Comments
 (0)