Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions Sources/SparkConnect/Documentation.docc/SparkSession.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ let csvDf = spark.read.csv("path/to/file.csv")

### Creating Sessions

- ``builder()``
- ``active()``
- ``builder``
- ``stop()``

### DataFrame Operations

- ``range(_:_:)``
- ``range(_:_:_:)``
- ``sql(_:)``
- ``createDataFrame(_:_:)``

### Data I/O

Expand All @@ -53,3 +51,13 @@ let csvDf = spark.read.csv("path/to/file.csv")
### Catalog Operations

- ``catalog``

### Managing Operations

- ``addTag(_:)``
- ``removeTag(_:)``
- ``getTags()``
- ``clearTags()``
- ``interruptAll()``
- ``interruptTag(_:)``
- ``interruptOperation(_:)``
44 changes: 44 additions & 0 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,50 @@ public actor SparkConnectClient {
tags.removeAll()
}

/// Request the server to interrupt all running operations of this session.
/// - Returns: Operation IDs of the executions which were interrupted.
public func interruptAll() async throws -> [String] {
  try await interrupt(.all)
}

/// Request the server to interrupt all running operations of this session
/// that carry the given job tag.
/// - Parameter tag: The operation tag whose executions should be interrupted.
/// - Returns: Operation IDs of the executions which were interrupted.
public func interruptTag(_ tag: String) async throws -> [String] {
  try await interrupt(.tag, tag: tag)
}

/// Request the server to interrupt a single operation of this session.
/// - Parameter operationId: The ID of the operation to interrupt.
/// - Returns: Operation IDs of the executions which were interrupted.
public func interruptOperation(_ operationId: String) async throws -> [String] {
  try await interrupt(.operationID, operationId: operationId)
}

/// Shared implementation for the three public `interrupt*` entry points:
/// builds an `InterruptRequest` for this session and sends it over gRPC.
/// - Parameters:
///   - type: Which kind of interrupt to request (`.all`, `.tag`, or `.operationID`).
///   - tag: Operation tag, required only when `type` is `.tag`.
///   - operationId: Operation ID, required only when `type` is `.operationID`.
/// - Returns: Operation IDs of the executions which were interrupted.
private func interrupt(
  _ type: Spark_Connect_InterruptRequest.InterruptType,
  tag: String? = nil,
  operationId: String? = nil
) async throws -> [String] {
  var request = Spark_Connect_InterruptRequest()
  request.sessionID = self.sessionID!
  request.userContext = self.userContext
  request.clientType = self.clientType
  request.interruptType = type
  if let tag {
    request.operationTag = tag
  }
  if let operationId {
    request.operationID = operationId
  }
  return try await withGPRC { client in
    let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
    let response = try await service.interrupt(request)
    return response.interruptedIds
  }
}

/// Parse a DDL string to ``Spark_Connect_DataType`` instance.
/// - Parameter ddlString: A string to parse.
/// - Returns: A ``Spark_Connect_DataType`` instance.
Expand Down
21 changes: 21 additions & 0 deletions Sources/SparkConnect/SparkSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,27 @@ public actor SparkSession {
await client.clearTags()
}

/// Ask the server to interrupt every operation currently running in this session.
/// - Returns: Sequence of operation IDs requested to be interrupted.
@discardableResult
public func interruptAll() async throws -> [String] {
  try await client.interruptAll()
}

/// Ask the server to interrupt every running operation of this session that
/// carries the given job tag.
/// - Parameter tag: The job tag identifying the operations to interrupt.
/// - Returns: Sequence of operation IDs requested to be interrupted.
@discardableResult
public func interruptTag(_ tag: String) async throws -> [String] {
  try await client.interruptTag(tag)
}

/// Ask the server to interrupt a single operation of this session, identified
/// by its operation ID.
/// - Parameter operationId: The ID of the operation to interrupt.
/// - Returns: Sequence of operation IDs requested to be interrupted (the
///   underlying protobuf response carries a repeated `interrupted_ids` field,
///   so this mirrors the Apache Spark 4.0.0 `Seq[String]` API).
@discardableResult
public func interruptOperation(_ operationId: String) async throws -> [String] {
  try await client.interruptOperation(operationId)
}

func sameSemantics(_ plan: Plan, _ otherPlan: Plan) async throws -> Bool {
  return try await client.sameSemantics(plan, otherPlan)
}

func sameSemantics(_ plan: Plan, _ otherPlan: Plan) async throws -> Bool {
return try await client.sameSemantics(plan, otherPlan)
}
Expand Down
21 changes: 21 additions & 0 deletions Tests/SparkConnectTests/SparkSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,25 @@ struct SparkSessionTests {
}
await spark.stop()
}

@Test
func interruptAll() async throws {
  // API-invocation test only: the session has no running operations,
  // so the interrupted-ID list is expected to be empty.
  let spark = try await SparkSession.builder.getOrCreate()
  #expect(try await spark.interruptAll() == [])
  await spark.stop()
}

@Test
func interruptTag() async throws {
// API-invocation test only: no operation carries the "etl" tag, so expect [].
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.interruptTag("etl") == [])
await spark.stop()
}

@Test
func interruptOperation() async throws {
// API-invocation test only: "id" matches no running operation, so expect [].
let spark = try await SparkSession.builder.getOrCreate()
#expect(try await spark.interruptOperation("id") == [])
await spark.stop()
}
}
Loading