
Commit f84cca2

tgravescs authored and HyukjinKwon committed
[SPARK-28234][CORE][PYTHON] Add python and JavaSparkContext support to get resources
## What changes were proposed in this pull request?

Add Python API support and JavaSparkContext support for resources(). The JavaSparkContext support is needed so the call translates properly into Python through the Py4J bridge.

## How was this patch tested?

Unit tests added and manually tested in local cluster mode and on YARN.

Closes apache#25087 from tgravescs/SPARK-28234-python.

Authored-by: Thomas Graves <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 7858e53 commit f84cca2
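
For context, a minimal usage sketch of the API this change exposes on the PySpark side (an illustration only, assuming a cluster where a `gpu` resource has been configured, for example via the discovery-script confs exercised in the tests below):

    from pyspark import SparkContext, TaskContext

    sc = SparkContext()

    # Driver-side resources: an empty dict unless spark.driver.resource.* is configured.
    for name, info in sc.resources.items():
        print(name, info.name, info.addresses)

    # Task-side resources, read inside a task through the TaskContext singleton.
    task_addrs = sc.parallelize(range(2), 2) \
        .map(lambda _: {k: v.addresses for k, v in TaskContext.get().resources().items()}) \
        .take(1)[0]
    print(task_addrs)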

9 files changed, 161 insertions(+), 2 deletions(-)


core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD}
+import org.apache.spark.resource.ResourceInformation

 /**
  * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
@@ -114,6 +115,8 @@ class JavaSparkContext(val sc: SparkContext) extends Closeable {

   def appName: String = sc.appName

+  def resources: JMap[String, ResourceInformation] = sc.resources.asJava
+
   def jars: util.List[String] = sc.jars.asJava

   def startTime: java.lang.Long = sc.startTime

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 10 additions & 0 deletions
@@ -281,6 +281,16 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
         dataOut.writeLong(context.taskAttemptId())
+        val resources = context.resources()
+        dataOut.writeInt(resources.size)
+        resources.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v.name, dataOut)
+          dataOut.writeInt(v.addresses.size)
+          v.addresses.foreach { case addr =>
+            PythonRDD.writeUTF(addr, dataOut)
+          }
+        }
         val localProps = context.getLocalProperties.asScala
         dataOut.writeInt(localProps.size)
         localProps.foreach { case (k, v) =>
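
For readers tracking the cross-language handoff: the block above frames the resource map as an int count, then per entry a key, a name, an address count, and the addresses, using ints and length-prefixed UTF-8 strings; worker.py further down reads the same framing back with read_int and utf8_deserializer. A stand-alone sketch of that framing in plain Python (illustrative helper names, assuming 4-byte big-endian ints as written by DataOutputStream):

    import struct

    def write_utf(s, out):
        # Length-prefixed UTF-8 string: 4-byte big-endian length, then the bytes
        # (mirrors what PythonRDD.writeUTF produces).
        data = s.encode("utf-8")
        out.write(struct.pack(">i", len(data)))
        out.write(data)

    def write_resources(resources, out):
        # resources: dict mapping resource key -> object with .name and .addresses
        out.write(struct.pack(">i", len(resources)))
        for key, info in resources.items():
            write_utf(key, out)
            write_utf(info.name, out)
            out.write(struct.pack(">i", len(info.addresses)))
            for addr in info.addresses:
                write_utf(addr, out)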

python/pyspark/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -54,6 +54,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.serializers import MarshalSerializer, PickleSerializer
 from pyspark.status import *
 from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo
@@ -118,5 +119,5 @@ def wrapper(self, *args, **kwargs):
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
     "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
-    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo",
+    "RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "ResourceInformation",
 ]

python/pyspark/context.py

Lines changed: 12 additions & 0 deletions
@@ -37,6 +37,7 @@
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
@@ -1107,6 +1108,17 @@ def getConf(self):
         conf.setAll(self._conf.getAll())
         return conf

+    @property
+    def resources(self):
+        resources = {}
+        jresources = self._jsc.resources()
+        for x in jresources:
+            name = jresources[x].name()
+            jaddresses = jresources[x].addresses()
+            addrs = [addr for addr in jaddresses]
+            resources[name] = ResourceInformation(name, addrs)
+        return resources
+

 def _test():
     import atexit

python/pyspark/resourceinformation.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class ResourceInformation(object):
+
+    """
+    .. note:: Evolving
+
+    Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc.
+    The array of addresses are resource specific and its up to the user to interpret the address.
+
+    One example is GPUs, where the addresses would be the indices of the GPUs
+
+    @param name the name of the resource
+    @param addresses an array of strings describing the addresses of the resource
+    """
+
+    def __init__(self, name, addresses):
+        self._name = name
+        self._addresses = addresses
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def addresses(self):
+        return self._addresses

python/pyspark/taskcontext.py

Lines changed: 8 additions & 0 deletions
@@ -38,6 +38,7 @@ class TaskContext(object):
     _stageId = None
     _taskAttemptId = None
     _localProperties = None
+    _resources = None

     def __new__(cls):
         """Even if users construct TaskContext instead of using get, give them the singleton."""
@@ -95,6 +96,13 @@ def getLocalProperty(self, key):
         """
         return self._localProperties.get(key, None)

+    def resources(self):
+        """
+        Resources allocated to the task. The key is the resource name and the value is information
+        about the resource.
+        """
+        return self._resources
+

 BARRIER_FUNCTION = 1

python/pyspark/tests/test_context.py

Lines changed: 34 additions & 1 deletion
@@ -16,13 +16,14 @@
 #
 import os
 import shutil
+import stat
 import tempfile
 import threading
 import time
 import unittest
 from collections import namedtuple

-from pyspark import SparkFiles, SparkContext
+from pyspark import SparkConf, SparkFiles, SparkContext
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME


@@ -256,6 +257,38 @@ def test_forbid_insecure_gateway(self):
             SparkContext(gateway=mock_insecure_gateway)
         self.assertIn("insecure Py4j gateway", str(context.exception))

+    def test_resources(self):
+        """Test the resources are empty by default."""
+        with SparkContext() as sc:
+            resources = sc.resources
+            self.assertEqual(len(resources), 0)
+
+
+class ContextTestsWithResources(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
+        self.tempFile.close()
+        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
+                 stat.S_IROTH | stat.S_IXOTH)
+        conf = SparkConf().set("spark.driver.resource.gpu.amount", "1")
+        conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name)
+        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+
+    def test_resources(self):
+        """Test the resources are available."""
+        resources = self.sc.resources
+        self.assertEqual(len(resources), 1)
+        self.assertTrue('gpu' in resources)
+        self.assertEqual(resources['gpu'].name, 'gpu')
+        self.assertEqual(resources['gpu'].addresses, ['0'])
+
+    def tearDown(self):
+        os.unlink(self.tempFile.name)
+        self.sc.stop()
+

 if __name__ == "__main__":
     from pyspark.tests.test_context import *
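
The setUp above fakes a discovery script by writing a one-line shell command to a temp file; for reference, a hypothetical standalone script doing the same job would simply print one JSON object with the resource name and its addresses (the two-address list here is made up for illustration):

    #!/usr/bin/env python
    # Hypothetical GPU discovery script: Spark invokes it and parses a single JSON
    # object with "name" and "addresses" from stdout.
    import json
    print(json.dumps({"name": "gpu", "addresses": ["0", "1"]}))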

python/pyspark/tests/test_taskcontext.py

Lines changed: 38 additions & 0 deletions
@@ -16,7 +16,9 @@
 #
 import os
 import random
+import stat
 import sys
+import tempfile
 import time
 import unittest

@@ -43,6 +45,15 @@ def test_stage_id(self):
         self.assertEqual(stage1 + 2, stage3)
         self.assertEqual(stage2 + 1, stage3)

+    def test_resources(self):
+        """Test the resources are empty by default."""
+        rdd = self.sc.parallelize(range(10))
+        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
+        # Test using the constructor directly rather than the get()
+        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
+        self.assertEqual(len(resources1), 0)
+        self.assertEqual(len(resources2), 0)
+
     def test_partition_id(self):
         """Test the partition id."""
         rdd1 = self.sc.parallelize(range(10), 1)
@@ -174,6 +185,33 @@ def tearDown(self):
         self.sc.stop()


+class TaskContextTestsWithResources(unittest.TestCase):
+
+    def setUp(self):
+        class_name = self.__class__.__name__
+        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
+        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
+        self.tempFile.close()
+        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
+                 stat.S_IROTH | stat.S_IXOTH)
+        conf = SparkConf().set("spark.task.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.amount", "1")
+        conf = conf.set("spark.executor.resource.gpu.discoveryScript", self.tempFile.name)
+        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
+
+    def test_resources(self):
+        """Test the resources are available."""
+        rdd = self.sc.parallelize(range(10))
+        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
+        self.assertEqual(len(resources), 1)
+        self.assertTrue('gpu' in resources)
+        self.assertEqual(resources['gpu'].name, 'gpu')
+        self.assertEqual(resources['gpu'].addresses, ['0'])
+
+    def tearDown(self):
+        os.unlink(self.tempFile.name)
+        self.sc.stop()
+
 if __name__ == "__main__":
     import unittest
     from pyspark.tests.test_taskcontext import *

python/pyspark/worker.py

Lines changed: 11 additions & 0 deletions
@@ -35,6 +35,7 @@
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
+from pyspark.resourceinformation import ResourceInformation
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
@@ -435,6 +436,16 @@ def main(infile, outfile):
         taskContext._partitionId = read_int(infile)
         taskContext._attemptNumber = read_int(infile)
         taskContext._taskAttemptId = read_long(infile)
+        taskContext._resources = {}
+        for r in range(read_int(infile)):
+            key = utf8_deserializer.loads(infile)
+            name = utf8_deserializer.loads(infile)
+            addresses = []
+            taskContext._resources = {}
+            for a in range(read_int(infile)):
+                addresses.append(utf8_deserializer.loads(infile))
+            taskContext._resources[key] = ResourceInformation(name, addresses)
+
         taskContext._localProperties = dict()
         for i in range(read_int(infile)):
             k = utf8_deserializer.loads(infile)
