Skip to content

Commit 40e226d

Browse files
fix: allow any AsyncSequence type in subscription
This allows intermediate `map`/`filter` requests to occur before binding the GraphQL subscription to the messenger.
1 parent 01dc28b commit 40e226d

File tree

4 files changed

+49
-48
lines changed

4 files changed

+49
-48
lines changed

Package.resolved

Lines changed: 31 additions & 29 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ let package = Package(
1313
],
1414
dependencies: [
1515
.package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"),
16-
.package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.0"),
16+
.package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.1"),
1717
],
1818
targets: [
1919
.target(

Sources/GraphQLWS/Server.swift

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@ import GraphQL
44
/// Server implements the server-side portion of the protocol, allowing a few callbacks for customization.
55
///
66
/// By default, there are no authorization checks
7-
public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
7+
public class Server<
8+
InitPayload: Equatable & Codable,
9+
SubscriptionSequenceType: AsyncSequence
10+
>: @unchecked Sendable where
11+
SubscriptionSequenceType.Element == GraphQLResult
12+
{
813
// We keep this weak because we strongly inject this object into the messenger callback
914
weak var messenger: Messenger?
1015

1116
let onExecute: (GraphQLRequest) async throws -> GraphQLResult
12-
let onSubscribe: (GraphQLRequest) async throws -> Result<AsyncThrowingStream<GraphQLResult, Error>, GraphQLErrors>
17+
let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType
1318
var auth: (InitPayload) async throws -> Void
1419

1520
var onExit: () async throws -> Void = {}
@@ -33,7 +38,7 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
3338
public init(
3439
messenger: Messenger,
3540
onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult,
36-
onSubscribe: @escaping (GraphQLRequest) async throws -> Result<AsyncThrowingStream<GraphQLResult, Error>, GraphQLErrors>
41+
onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType
3742
) {
3843
self.messenger = messenger
3944
self.onExecute = onExecute
@@ -166,16 +171,9 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
166171
}
167172

168173
if isStreaming {
169-
do {
170-
let result = try await onSubscribe(graphQLRequest)
171-
let stream: AsyncThrowingStream<GraphQLResult, Error>
174+
subscriptionTasks[id] = Task {
172175
do {
173-
stream = try result.get()
174-
} catch {
175-
try await sendError(error, id: id)
176-
return
177-
}
178-
subscriptionTasks[id] = Task {
176+
let stream = try await onSubscribe(graphQLRequest)
179177
for try await event in stream {
180178
try Task.checkCancellation()
181179
do {
@@ -185,10 +183,11 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
185183
throw error
186184
}
187185
}
188-
try await self.sendComplete(id: id)
186+
} catch {
187+
try await sendError(error, id: id)
188+
throw error
189189
}
190-
} catch {
191-
try await sendError(error, id: id)
190+
try await self.sendComplete(id: id)
192191
}
193192
} else {
194193
do {

Tests/GraphQLWSTests/GraphQLWSTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import XCTest
88
class GraphqlWsTests: XCTestCase {
99
var clientMessenger: TestMessenger!
1010
var serverMessenger: TestMessenger!
11-
var server: Server<TokenInitPayload>!
11+
var server: Server<TokenInitPayload, AsyncThrowingStream<GraphQLResult, Error>>!
1212
var context: TestContext!
1313

1414
override func setUp() {
@@ -21,7 +21,7 @@ class GraphqlWsTests: XCTestCase {
2121
let api = TestAPI()
2222
let context = TestContext()
2323

24-
server = Server<TokenInitPayload>(
24+
server = .init(
2525
messenger: serverMessenger,
2626
onExecute: { graphQLRequest in
2727
try await api.execute(
@@ -33,7 +33,7 @@ class GraphqlWsTests: XCTestCase {
3333
try await api.subscribe(
3434
request: graphQLRequest.query,
3535
context: context
36-
)
36+
).get()
3737
}
3838
)
3939
self.context = context

0 commit comments

Comments
 (0)