Skip to content

Commit 5ec2935

Browse files
fix: allow any AsyncSequence type in subscription
This allows intermediate `map`/`filter` requests to occur before binding the GraphQL subscription to the messenger.
1 parent 1e5738a commit 5ec2935

File tree

4 files changed

+49
-50
lines changed

4 files changed

+49
-50
lines changed

Package.resolved

Lines changed: 30 additions & 31 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ let package = Package(
1313
],
1414
dependencies: [
1515
.package(url: "https://github.com/GraphQLSwift/Graphiti.git", from: "3.0.0"),
16-
.package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.0"),
16+
.package(url: "https://github.com/GraphQLSwift/GraphQL.git", from: "4.0.1"),
1717
],
1818
targets: [
1919
.target(

Sources/GraphQLTransportWS/Server.swift

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,18 @@ import GraphQL
44
/// Server implements the server-side portion of the protocol, allowing a few callbacks for customization.
55
///
66
/// By default, there are no authorization checks
7-
public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
7+
public class Server<
8+
InitPayload: Equatable & Codable,
9+
SubscriptionSequenceType: AsyncSequence
10+
>: @unchecked Sendable where
11+
SubscriptionSequenceType.Element == GraphQLResult
12+
{
13+
814
// We keep this weak because we strongly inject this object into the messenger callback
915
weak var messenger: Messenger?
1016

1117
let onExecute: (GraphQLRequest) async throws -> GraphQLResult
12-
let onSubscribe: (GraphQLRequest) async throws -> Result<AsyncThrowingStream<GraphQLResult, Error>, GraphQLErrors>
18+
let onSubscribe: (GraphQLRequest) async throws -> SubscriptionSequenceType
1319
var auth: (InitPayload) async throws -> Void
1420

1521
var onExit: () async throws -> Void = {}
@@ -33,7 +39,7 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
3339
public init(
3440
messenger: Messenger,
3541
onExecute: @escaping (GraphQLRequest) async throws -> GraphQLResult,
36-
onSubscribe: @escaping (GraphQLRequest) async throws -> Result<AsyncThrowingStream<GraphQLResult, Error>, GraphQLErrors>
42+
onSubscribe: @escaping (GraphQLRequest) async throws -> SubscriptionSequenceType
3743
) {
3844
self.messenger = messenger
3945
self.onExecute = onExecute
@@ -160,16 +166,9 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
160166
}
161167

162168
if isStreaming {
163-
do {
164-
let result = try await onSubscribe(graphQLRequest)
165-
let stream: AsyncThrowingStream<GraphQLResult, Error>
169+
subscriptionTasks[id] = Task {
166170
do {
167-
stream = try result.get()
168-
} catch {
169-
try await sendError(error, id: id)
170-
return
171-
}
172-
subscriptionTasks[id] = Task {
171+
let stream = try await onSubscribe(graphQLRequest)
173172
for try await event in stream {
174173
try Task.checkCancellation()
175174
do {
@@ -179,10 +178,11 @@ public class Server<InitPayload: Equatable & Codable>: @unchecked Sendable {
179178
throw error
180179
}
181180
}
182-
try await self.sendComplete(id: id)
181+
} catch {
182+
try await sendError(error, id: id)
183+
throw error
183184
}
184-
} catch {
185-
try await sendError(error, id: id)
185+
try await self.sendComplete(id: id)
186186
}
187187
} else {
188188
do {

Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import XCTest
88
class GraphqlTransportWSTests: XCTestCase {
99
var clientMessenger: TestMessenger!
1010
var serverMessenger: TestMessenger!
11-
var server: Server<TokenInitPayload>!
11+
var server: Server<TokenInitPayload, AsyncThrowingStream<GraphQLResult, Error>>!
1212
var context: TestContext!
1313

1414
override func setUp() {
@@ -21,7 +21,7 @@ class GraphqlTransportWSTests: XCTestCase {
2121
let api = TestAPI()
2222
let context = TestContext()
2323

24-
server = Server<TokenInitPayload>(
24+
server = .init(
2525
messenger: serverMessenger,
2626
onExecute: { graphQLRequest in
2727
try await api.execute(
@@ -33,7 +33,7 @@ class GraphqlTransportWSTests: XCTestCase {
3333
try await api.subscribe(
3434
request: graphQLRequest.query,
3535
context: context
36-
)
36+
).get()
3737
}
3838
)
3939
self.context = context

0 commit comments

Comments
 (0)