4 changes: 4 additions & 0 deletions backends-velox/pom.xml
@@ -86,6 +86,10 @@
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
@@ -30,6 +30,7 @@ import org.apache.gluten.vectorized.PlanEvaluatorJniWrapper

import org.apache.spark.Partition
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDirectory}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -106,6 +107,10 @@ class VeloxTransformerApi extends TransformerApi with Logging {

override def packPBMessage(message: Message): Any = Any.pack(message, "")

override def invalidateSQLExecutionResource(executionId: String): Unit = {
GlutenDriverEndpoint.invalidateResourceRelation(executionId)
}

override def genWriteParameters(write: WriteFilesExecTransformer): Any = {
write.fileFormat match {
case _ @(_: ParquetFileFormat | _: HiveFileFormat) =>
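A minimal, illustrative sketch (not part of this diff) of how driver-side code could reach the new hook; `BackendsApiManager.getTransformerApiInstance` is assumed to be the usual accessor for the backend's `TransformerApi`, and the execution id is a placeholder:

// Illustrative only: release per-execution resources through the new TransformerApi hook
// once the driver learns that a SQL execution has finished.
val executionId: String = "42" // placeholder id
BackendsApiManager.getTransformerApiInstance.invalidateSQLExecutionResource(executionId)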
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
import org.apache.gluten.vectorized.HashJoinBuilder

import org.apache.spark.SparkEnv
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.ColumnarBuildSideRelation
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.unsafe.UnsafeColumnarBuildSideRelation

import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause, RemovalListener}

import java.util.concurrent.TimeUnit

case class BroadcastHashTable(pointer: Long, relation: BuildSideRelation)

/**
* `VeloxBroadcastBuildSideCache` ensures that the BHJ hash table for a broadcast relation is
* built only once.
*
* The complicated part is exchange reuse, where multiple BHJ hash table IDs can map to the
* same `BuildSideRelation`.
*/
object VeloxBroadcastBuildSideCache
extends Logging
with RemovalListener[String, BroadcastHashTable] {

private lazy val expiredTime = SparkEnv.get.conf.getLong(
VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME,
VeloxBackendSettings.GLUTEN_VELOX_BROADCAST_CACHE_EXPIRED_TIME_DEFAULT
)

// Ensures the BHJ hash table is built only once.
// key: hash table id, value: the native hash table pointer and its BuildSideRelation.
private val buildSideRelationCache: Cache[String, BroadcastHashTable] =
Caffeine.newBuilder
.expireAfterAccess(expiredTime, TimeUnit.SECONDS)
.removalListener(this)
.build[String, BroadcastHashTable]()

def getOrBuildBroadcastHashTable(
broadcast: Broadcast[BuildSideRelation],
broadCastContext: BroadcastHashJoinContext): BroadcastHashTable = synchronized {

buildSideRelationCache
.get(
broadCastContext.buildHashTableId,
(broadcast_id: String) => {
val (pointer, relation) = broadcast.value match {
case columnar: ColumnarBuildSideRelation =>
columnar.buildHashTable(broadCastContext)
case unsafe: UnsafeColumnarBuildSideRelation =>
unsafe.buildHashTable(broadCastContext)
}

logWarning(s"Create bhj $broadcast_id = $pointer")
BroadcastHashTable(pointer, relation)
}
)
}

/** Callback invoked from the C++ backend. */
def get(broadcastHashtableId: String): Long =
synchronized {
Option(buildSideRelationCache.getIfPresent(broadcastHashtableId))
.map(_.pointer)
.getOrElse(0)
}

def invalidateBroadcastHashtable(broadcastHashtableId: String): Unit = synchronized {
// Cleanup operations on the backend are idempotent.
buildSideRelationCache.invalidate(broadcastHashtableId)
}

/** Only used in unit tests. */
def size(): Long = buildSideRelationCache.estimatedSize()

def cleanAll(): Unit = buildSideRelationCache.invalidateAll()

override def onRemoval(key: String, value: BroadcastHashTable, cause: RemovalCause): Unit = {
synchronized {
logWarning(s"Remove bhj $key = ${value.pointer}")
if (value.relation != null) {
value.relation match {
case columnar: ColumnarBuildSideRelation =>
columnar.reset()
case unsafe: UnsafeColumnarBuildSideRelation =>
unsafe.reset()
}
}

HashJoinBuilder.clearHashTable(value.pointer)
}
}
}
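A minimal sketch (not part of this diff) of how an executor-side broadcast hash join could consume the cache; `broadcast` and `ctx` are assumed to be provided by the join operator at build time:

// Illustrative only: the first caller builds the native hash table; later callers,
// e.g. from a reused exchange, get the same pointer back without rebuilding it.
def buildOnce(broadcast: Broadcast[BuildSideRelation], ctx: BroadcastHashJoinContext): Long =
  VeloxBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcast, ctx).pointer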
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.listener

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{GlutenDriverEndpoint, RpcEndpointRef}
import org.apache.spark.rpc.GlutenRpcMessages._
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.ui._

/** Gluten SQL listener. Monitors the whole SQL execution life cycle to create and release resources. */
class VeloxGlutenSQLAppStatusListener(val driverEndpointRef: RpcEndpointRef)
extends SparkListener
with Logging {

/**
* When an executor is removed, the driver endpoint needs to drop that executor's endpoint ref;
* once the execution has ended, the executor ref can no longer be called.
* @param executorRemoved
*   executor removed event
*/
override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
driverEndpointRef.send(GlutenExecutorRemoved(executorRemoved.executorId))
logTrace(s"Execution ${executorRemoved.executorId} Removed.")
}

override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
case e: SparkListenerSQLExecutionStart => onExecutionStart(e)
case e: SparkListenerSQLExecutionEnd => onExecutionEnd(e)
case _ => // Ignore
}

/**
* When an execution starts, notify the Gluten executors so they can do any required preparation.
*
* @param event
* execution start event
*/
private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = {
val executionId = event.executionId.toString
driverEndpointRef.send(GlutenOnExecutionStart(executionId))
logTrace(s"Execution $executionId start.")
}

/**
* When an execution ends, some backends (e.g. ClickHouse) need to clean up the resources
* associated with that execution.
* @param event
* execution end event
*/
private def onExecutionEnd(event: SparkListenerSQLExecutionEnd): Unit = {
val executionId = event.executionId.toString
driverEndpointRef.send(GlutenOnExecutionEnd(executionId))
logTrace(s"Execution $executionId end.")
}
}
object VeloxGlutenSQLAppStatusListener {
def registerListener(sc: SparkContext): Unit = {
sc.listenerBus.addToStatusQueue(
new VeloxGlutenSQLAppStatusListener(GlutenDriverEndpoint.glutenDriverEndpointRef))
}
}
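A minimal sketch (not part of this diff) of how the listener might be wired up when the driver-side part of the plugin initializes; the `onDriverStart` hook name is an assumption for illustration:

import org.apache.spark.SparkContext
import org.apache.spark.listener.VeloxGlutenSQLAppStatusListener

// Illustrative only: add the listener to the status queue so execution start/end and
// executor-removed events are forwarded to GlutenDriverEndpoint.
def onDriverStart(sc: SparkContext): Unit =
  VeloxGlutenSQLAppStatusListener.registerListener(sc)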
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rpc

import org.apache.gluten.config.GlutenConfig

import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.GlutenRpcMessages._

import com.github.benmanes.caffeine.cache.{Cache, Caffeine, RemovalCause, RemovalListener}

import java.util
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

/**
* The gluten driver endpoint is responsible for communicating with the executor. Executor will
* register with the driver when it starts.
*/
class GlutenDriverEndpoint extends IsolatedRpcEndpoint with Logging {
override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv

protected val totalRegisteredExecutors = new AtomicInteger(0)

private val driverEndpoint: RpcEndpointRef =
rpcEnv.setupEndpoint(GlutenRpcConstants.GLUTEN_DRIVER_ENDPOINT_NAME, this)

// TODO(yuan): get thread cnt from spark context
override def threadCount(): Int = 1
override def receive: PartialFunction[Any, Unit] = {
case GlutenOnExecutionStart(executionId) =>
if (executionId == null) {
logWarning(s"Execution Id is null. Resources maybe not clean after execution end.")
}

case GlutenOnExecutionEnd(executionId) =>
logWarning(s"Execution Id is $executionId end.")

GlutenDriverEndpoint.executionResourceRelation.invalidate(executionId)

case GlutenExecutorRemoved(executorId) =>
GlutenDriverEndpoint.executorDataMap.remove(executorId)
totalRegisteredExecutors.addAndGet(-1)
logTrace(s"Executor endpoint ref $executorId is removed.")

case e =>
logError(s"Received unexpected message. $e")
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {

case GlutenRegisterExecutor(executorId, executorRef) =>
if (GlutenDriverEndpoint.executorDataMap.contains(executorId)) {
context.sendFailure(new IllegalStateException(s"Duplicate executor ID: $executorId"))
} else {
// If the executor's rpc env is not listening for incoming connections, `hostPort`
// will be null, and the client connection should be used to contact the executor.
val executorAddress = if (executorRef.address != null) {
executorRef.address
} else {
context.senderAddress
}
logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")

totalRegisteredExecutors.addAndGet(1)
val data = new ExecutorData(executorRef)
// This must be synchronized because variables mutated
// in this block are read when requesting executors
GlutenDriverEndpoint.this.synchronized {
GlutenDriverEndpoint.executorDataMap.put(executorId, data)
}
logTrace(s"Executor size ${GlutenDriverEndpoint.executorDataMap.size()}")
// Note: some tests expect the reply to come after we put the executor in the map
context.reply(true)
}

}

override def onStart(): Unit = {
logInfo(s"Initialized GlutenDriverEndpoint, address: ${driverEndpoint.address.toString()}.")
}
}

object GlutenDriverEndpoint extends Logging with RemovalListener[String, util.Set[String]] {
private lazy val executionResourceExpiredTime = SparkEnv.get.conf.getLong(
GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.key,
GlutenConfig.GLUTEN_RESOURCE_RELATION_EXPIRED_TIME.defaultValue.get
)

var glutenDriverEndpointRef: RpcEndpointRef = _

// Keep executor endpoint refs in memory.
val executorDataMap = new ConcurrentHashMap[String, ExecutorData]

// If spark.scheduler.listenerbus.eventqueue.capacity is set too small,
// the listener may drop events, so cached entries also expire,
// by default after a maximum of one day.
// key: executionId, value: resourceIds
private val executionResourceRelation: Cache[String, util.Set[String]] =
Caffeine.newBuilder
.expireAfterAccess(executionResourceExpiredTime, TimeUnit.SECONDS)
.removalListener(this)
.build[String, util.Set[String]]()

def collectResources(executionId: String, resourceId: String): Unit = {
val resources = executionResourceRelation
.get(executionId, (_: String) => new util.HashSet[String]())
resources.add(resourceId)
}

def invalidateResourceRelation(executionId: String): Unit = {
executionResourceRelation.invalidate(executionId)
}

override def onRemoval(key: String, value: util.Set[String], cause: RemovalCause): Unit = {
executorDataMap.forEach(
(_, executor) => executor.executorEndpointRef.send(GlutenCleanExecutionResource(key, value)))
}
}

class ExecutorData(val executorEndpointRef: RpcEndpointRef) {}
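A minimal sketch (not part of this diff) of how driver-side code could tie a broadcast hash table to an execution so it is cleaned up when the execution ends; the identifiers are placeholders:

// Illustrative only: record that this execution owns the given hash table id.
val executionId = "42"
val buildHashTableId = "BuiltBroadcastHashTable-0"
GlutenDriverEndpoint.collectResources(executionId, buildHashTableId)
// When GlutenOnExecutionEnd for this execution arrives (or the entry expires), onRemoval sends
// GlutenCleanExecutionResource(executionId, resourceIds) to every registered executor.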