Skip to content

Commit 2ad9678

Browse files
authored
Add stream method to ArrowFlight service (#5)
* add stream method to arrow flight service * further restrict the schema shape if none provided * remove unused import * rename to streamQuery
1 parent 29b90c2 commit 2ad9678

File tree

2 files changed

+309
-84
lines changed

2 files changed

+309
-84
lines changed

packages/amp/src/ArrowFlight.ts

Lines changed: 202 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,23 @@ import {
99
type Transport as ConnectTransport
1010
} from "@connectrpc/connect"
1111
import * as Arr from "effect/Array"
12-
import * as Console from "effect/Console"
12+
import * as Cause from "effect/Cause"
1313
import * as Context from "effect/Context"
1414
import * as Effect from "effect/Effect"
15+
import { identity } from "effect/Function"
1516
import * as Layer from "effect/Layer"
1617
import * as Option from "effect/Option"
18+
import * as Predicate from "effect/Predicate"
1719
import * as Redacted from "effect/Redacted"
1820
import * as Schema from "effect/Schema"
1921
import * as Stream from "effect/Stream"
2022
import { Auth } from "./Auth.ts"
21-
import { decodeDictionaryBatch, decodeRecordBatch, DictionaryRegistry } from "./internal/arrow-flight-ipc/Decoder.ts"
23+
import { decodeRecordBatch, DictionaryRegistry } from "./internal/arrow-flight-ipc/Decoder.ts"
2224
import { recordBatchToJson } from "./internal/arrow-flight-ipc/Json.ts"
23-
import { readColumnValues } from "./internal/arrow-flight-ipc/Readers.ts"
24-
import { parseDictionaryBatch, parseRecordBatch } from "./internal/arrow-flight-ipc/RecordBatch.ts"
25+
import { parseRecordBatch } from "./internal/arrow-flight-ipc/RecordBatch.ts"
2526
import { type ArrowSchema, getMessageType, MessageHeaderType, parseSchema } from "./internal/arrow-flight-ipc/Schema.ts"
26-
import type { AuthInfo } from "./Models.ts"
27+
import type { AuthInfo, BlockRange, RecordBatchMetadata } from "./Models.ts"
28+
import { RecordBatchMetadataFromUint8Array } from "./Models.ts"
2729
import { FlightDescriptor_DescriptorType, FlightDescriptorSchema, FlightService } from "./Protobuf/Flight_pb.ts"
2830
import { CommandStatementQuerySchema } from "./Protobuf/FlightSql_pb.ts"
2931

@@ -201,6 +203,44 @@ export class ParseSchemaError extends Schema.TaggedError<ParseSchemaError>(
201203
cause: Schema.Defect
202204
}) {}
203205

206+
// =============================================================================
207+
// Types
208+
// =============================================================================
209+
210+
/**
211+
* Represents the result received from the `ArrowFlight` service when a query
212+
* is successfully executed.
213+
*/
214+
export interface QueryResult<A> {
215+
readonly data: A
216+
readonly metadata: RecordBatchMetadata
217+
}
218+
219+
/**
220+
* Represents options that can be passed to `ArrowFlight.query` and `ArrowFlight.streamQuery` to control how
221+
* the query is executed.
222+
*/
223+
export interface QueryOptions {
224+
readonly schema?: Schema.Any | undefined
225+
/**
226+
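* Sets the `stream` Amp query setting to `true` by sending the `amp-stream: true` header. NOTE: the header is currently sent whenever this option is defined, even when it is `false`.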
227+
*/
228+
readonly stream?: boolean | undefined
229+
/**
230+
* A set of block ranges which will be converted into a resume watermark
231+
* header and sent with the query. This allows resumption of streaming queries.
232+
*/
233+
readonly resumeWatermark?: ReadonlyArray<BlockRange> | undefined
234+
}
235+
236+
/**
237+
* A utility type to extract the result type for a query.
238+
*/
239+
export type ExtractQueryResult<Options extends QueryOptions> = Options extends {
240+
readonly schema: Schema.Schema<infer _A, infer _I, infer _R>
241+
} ? QueryResult<_A>
242+
: Record<string, unknown>
243+
204244
// =============================================================================
205245
// Arrow Flight Service
206246
// =============================================================================
@@ -209,112 +249,169 @@ export class ParseSchemaError extends Schema.TaggedError<ParseSchemaError>(
209249
/**
210250
* A service which can be used to execute queries against an Arrow Flight API.
211251
*/
212-
export class ArrowFlight extends Context.Tag("@edgeandnode/amp/ArrowFlight")<ArrowFlight, {
252+
export class ArrowFlight extends Context.Tag("Amp/ArrowFlight")<ArrowFlight, {
213253
/**
214254
* The Connect `Client` that will be used to execute Arrow Flight queries.
215255
*/
216256
readonly client: Client<typeof FlightService>
217257

218258
/**
219-
* Executes an Arrow Flight SQL query and returns
259+
* Executes an Arrow Flight SQL query and returns all results as an array.
260+
*/
261+
readonly query: <Options extends QueryOptions>(
262+
sql: string,
263+
options?: Options
264+
) => Effect.Effect<ReadonlyArray<ExtractQueryResult<Options>>, ArrowFlightError>
265+
266+
/**
267+
* Executes an Arrow Flight SQL query and returns a stream of results.
220268
*/
221-
readonly query: (sql: string) => Effect.Effect<unknown, ArrowFlightError>
269+
readonly streamQuery: <Options extends QueryOptions>(
270+
sql: string,
271+
options?: Options
272+
) => Stream.Stream<ExtractQueryResult<Options>, ArrowFlightError>
222273
}>() {}
223274

224275
const make = Effect.gen(function*() {
225276
const auth = yield* Effect.serviceOption(Auth)
226277
const transport = yield* Transport
227278
const client = createClient(FlightService, transport)
228279

280+
const decodeRecordBatchMetadata = Schema.decode(RecordBatchMetadataFromUint8Array)
281+
229282
/**
230283
* Execute a SQL query and return a stream of decoded record batches, each paired with its metadata.
231284
*/
232-
const query = Effect.fn("ArrowFlight.request")(function*(query: string) {
233-
// Setup the query context with authentication information, if available
234-
const contextValues = createContextValues()
235-
const authInfo = Option.isSome(auth)
236-
? yield* auth.value.getCachedAuthInfo
237-
: Option.none<AuthInfo>()
238-
if (Option.isSome(authInfo)) {
239-
contextValues.set(AuthInfoContextKey, authInfo.value)
240-
}
285+
const streamQuery = (query: string, options?: QueryOptions) =>
286+
Effect.gen(function*() {
287+
const contextValues = createContextValues()
288+
const authInfo = Option.isSome(auth)
289+
? yield* auth.value.getCachedAuthInfo
290+
: Option.none<AuthInfo>()
291+
292+
// Setup the query context with authentication information, if available
293+
if (Option.isSome(authInfo)) {
294+
contextValues.set(AuthInfoContextKey, authInfo.value)
295+
}
241296

242-
const cmd = create(CommandStatementQuerySchema, { query })
243-
const any = anyPack(CommandStatementQuerySchema, cmd)
244-
const desc = create(FlightDescriptorSchema, {
245-
type: FlightDescriptor_DescriptorType.CMD,
246-
cmd: toBinary(AnySchema, any)
247-
})
248-
249-
const flightInfo = yield* Effect.tryPromise({
250-
try: (signal) => client.getFlightInfo(desc, { signal, contextValues }),
251-
catch: (cause) => new RpcError({ cause, method: "getFlightInfo" })
252-
})
253-
254-
if (flightInfo.endpoint.length !== 1) {
255-
return yield* flightInfo.endpoint.length <= 0
256-
? new NoEndpointsError({ query })
257-
: new MultipleEndpointsError({ query })
258-
}
297+
const cmd = create(CommandStatementQuerySchema, { query })
298+
const any = anyPack(CommandStatementQuerySchema, cmd)
299+
const desc = create(FlightDescriptorSchema, {
300+
type: FlightDescriptor_DescriptorType.CMD,
301+
cmd: toBinary(AnySchema, any)
302+
})
303+
304+
// Setup the query headers
305+
const headers = new Headers()
306+
if (Predicate.isNotUndefined(options?.stream)) {
307+
headers.set("amp-stream", "true")
308+
}
309+
if (Predicate.isNotUndefined(options?.resumeWatermark)) {
310+
headers.set("amp-resume", blockRangesToResumeWatermark(options.resumeWatermark))
311+
}
259312

260-
const { ticket } = flightInfo.endpoint[0]!
313+
const flightInfo = yield* Effect.tryPromise({
314+
try: (signal) => client.getFlightInfo(desc, { contextValues, headers, signal }),
315+
catch: (cause) => new RpcError({ cause, method: "getFlightInfo" })
316+
})
261317

262-
if (ticket === undefined) {
263-
return yield* new TicketNotFoundError({ query })
264-
}
318+
if (flightInfo.endpoint.length !== 1) {
319+
return yield* flightInfo.endpoint.length <= 0
320+
? new NoEndpointsError({ query })
321+
: new MultipleEndpointsError({ query })
322+
}
265323

266-
const flightDataStream = Stream.unwrapScoped(Effect.gen(function*() {
267-
const controller = yield* Effect.acquireRelease(
268-
Effect.sync(() => new AbortController()),
269-
(controller) => Effect.sync(() => controller.abort())
270-
)
271-
return Stream.fromAsyncIterable(
272-
client.doGet(ticket, { signal: controller.signal, contextValues }),
273-
(cause) => new RpcError({ cause, method: "doGet" })
324+
const { ticket } = flightInfo.endpoint[0]!
325+
326+
if (ticket === undefined) {
327+
return yield* new TicketNotFoundError({ query })
328+
}
329+
330+
const flightDataStream = Stream.unwrapScoped(Effect.gen(function*() {
331+
const controller = yield* Effect.acquireRelease(
332+
Effect.sync(() => new AbortController()),
333+
(controller) => Effect.sync(() => controller.abort())
334+
)
335+
return Stream.fromAsyncIterable(
336+
client.doGet(ticket, { signal: controller.signal, contextValues }),
337+
(cause) => new RpcError({ cause, method: "doGet" })
338+
)
339+
}))
340+
341+
let schema: ArrowSchema | undefined
342+
const dictionaryRegistry = new DictionaryRegistry()
343+
const dataSchema: Schema.Array$<
344+
Schema.Record$<
345+
typeof Schema.String,
346+
typeof Schema.Unknown
347+
>
348+
> = Schema.Array(
349+
options?.schema ?? Schema.Record({
350+
key: Schema.String,
351+
value: Schema.Unknown
352+
}) as any
274353
)
275-
}))
276-
277-
let schema: ArrowSchema | undefined
278-
const dictionaryRegistry = new DictionaryRegistry()
279-
280-
// Convert FlightData stream to a stream of rows
281-
return yield* flightDataStream.pipe(
282-
Stream.runForEach(Effect.fnUntraced(function*(flightData) {
283-
const messageType = yield* Effect.orDie(getMessageType(flightData))
284-
285-
switch (messageType) {
286-
case MessageHeaderType.SCHEMA: {
287-
schema = yield* parseSchema(flightData).pipe(
288-
Effect.mapError((cause) => new ParseSchemaError({ cause }))
289-
)
290-
break
354+
const decodeRecordBatchData = Schema.decode(dataSchema)
355+
356+
// Convert FlightData stream to a stream of rows
357+
return flightDataStream.pipe(
358+
Stream.mapEffect(Effect.fnUntraced(function*(flightData): Effect.fn.Return<
359+
Option.Option<QueryResult<any>>,
360+
ArrowFlightError
361+
> {
362+
const messageType = yield* Effect.orDie(getMessageType(flightData))
363+
364+
switch (messageType) {
365+
case MessageHeaderType.SCHEMA: {
366+
schema = yield* parseSchema(flightData).pipe(
367+
Effect.mapError((cause) => new ParseSchemaError({ cause }))
368+
)
369+
return Option.none<QueryResult<any>>()
370+
}
371+
case MessageHeaderType.DICTIONARY_BATCH: {
372+
// TODO: figure out what to do (if anything) with dictionary batches
373+
// const dictionaryBatch = yield* parseDictionaryBatch(flightData).pipe(
374+
// Effect.mapError((cause) => new ParseDictionaryBatchError({ cause }))
375+
// )
376+
// decodeDictionaryBatch(dictionaryBatch, flightData.dataBody, schema!, dictionaryRegistry, readColumnValues)
377+
return Option.none<QueryResult<any>>()
378+
}
379+
case MessageHeaderType.RECORD_BATCH: {
380+
const metadata = yield* decodeRecordBatchMetadata(flightData.appMetadata).pipe(
381+
Effect.mapError((cause) => new ParseRecordBatchError({ cause }))
382+
)
383+
const recordBatch = yield* parseRecordBatch(flightData).pipe(
384+
Effect.mapError((cause) => new ParseRecordBatchError({ cause }))
385+
)
386+
const decodedRecordBatch = decodeRecordBatch(recordBatch, flightData.dataBody, schema!)
387+
const json = recordBatchToJson(decodedRecordBatch, { dictionaryRegistry })
388+
const data = yield* decodeRecordBatchData(json).pipe(
389+
Effect.mapError((cause) => new ParseRecordBatchError({ cause }))
390+
)
391+
return Option.some({ data, metadata })
392+
}
291393
}
292-
case MessageHeaderType.DICTIONARY_BATCH: {
293-
const dictionaryBatch = yield* parseDictionaryBatch(flightData).pipe(
294-
Effect.mapError((cause) => new ParseDictionaryBatchError({ cause }))
295-
)
296-
decodeDictionaryBatch(dictionaryBatch, flightData.dataBody, schema!, dictionaryRegistry, readColumnValues)
297-
break
298-
}
299-
case MessageHeaderType.RECORD_BATCH: {
300-
const recordBatch = yield* parseRecordBatch(flightData).pipe(
301-
Effect.mapError((cause) => new ParseRecordBatchError({ cause }))
302-
)
303-
const decodedRecordBatch = decodeRecordBatch(recordBatch, flightData.dataBody, schema!)
304-
const json = recordBatchToJson(decodedRecordBatch, { dictionaryRegistry })
305-
yield* Console.dir(json, { depth: null, colors: true })
306-
break
307-
}
308-
}
309394

310-
return yield* Effect.void
311-
}))
312-
)
313-
})
395+
return yield* Effect.die(new Cause.RuntimeException(`Invalid message type received: ${messageType}`))
396+
})),
397+
Stream.filterMap(identity)
398+
)
399+
}).pipe(
400+
Stream.unwrap,
401+
Stream.withSpan("ArrowFlight.stream")
402+
) as any
403+
404+
const query = Effect.fn("ArrowFlight.query")(
405+
function*(query: string, options?: QueryOptions) {
406+
const chunk = yield* Stream.runCollect(streamQuery(query, options))
407+
return Array.from(chunk)
408+
}
409+
) as any
314410

315411
return {
316412
client,
317-
query
413+
query,
414+
streamQuery
318415
} as const
319416
})
320417

@@ -323,3 +420,24 @@ const make = Effect.gen(function*() {
323420
* service and depends upon some implementation of a `Transport`.
324421
*/
325422
export const layer: Layer.Layer<ArrowFlight, ArrowFlightError, Transport> = Layer.effect(ArrowFlight, make)
423+
424+
// =============================================================================
425+
// Internal Utilities
426+
// =============================================================================
427+
428+
/**
429+
* Converts a list of block ranges into a resume watermark string.
430+
*
431+
* @param ranges - The block ranges to convert.
432+
* @returns A resume watermark string.
433+
*/
434+
const blockRangesToResumeWatermark = (ranges: ReadonlyArray<BlockRange>): string => {
435+
const watermarks: Record<string, { number: number; hash: string }> = {}
436+
for (const range of ranges) {
437+
watermarks[range.network] = {
438+
number: range.numbers.end,
439+
hash: range.hash
440+
}
441+
}
442+
return JSON.stringify(watermarks)
443+
}

0 commit comments

Comments
 (0)