Skip to content

Commit 9660831

Browse files
nchammasHyukjinKwon
authored andcommitted
[SPARK-21712][PYSPARK] Clarify type error for Column.substr()
Proposed changes: * Clarify the type error that `Column.substr()` gives. Test plan: * Tested this manually. * Test code: ```python from pyspark.sql.functions import col, lit spark.createDataFrame([['nick']], schema=['name']).select(col('name').substr(0, lit(1))) ``` * Before: ``` TypeError: Can not mix the type ``` * After: ``` TypeError: startPos and length must be the same type. Got <class 'int'> and <class 'pyspark.sql.column.Column'>, respectively. ``` Author: Nicholas Chammas <[email protected]> Closes apache#18926 from nchammas/SPARK-21712-substr-type-error.
1 parent 42b9eda commit 9660831

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

python/pyspark/sql/column.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,14 @@ def substr(self, startPos, length):
406406
[Row(col=u'Ali'), Row(col=u'Bob')]
407407
"""
408408
if type(startPos) != type(length):
409-
raise TypeError("Can not mix the type")
410-
if isinstance(startPos, (int, long)):
409+
raise TypeError(
410+
"startPos and length must be the same type. "
411+
"Got {startPos_t} and {length_t}, respectively."
412+
.format(
413+
startPos_t=type(startPos),
414+
length_t=type(length),
415+
))
416+
if isinstance(startPos, int):
411417
jc = self._jc.substr(startPos, length)
412418
elif isinstance(startPos, Column):
413419
jc = self._jc.substr(startPos._jc, length._jc)

python/pyspark/sql/tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,6 +1220,18 @@ def test_rand_functions(self):
12201220
rndn2 = df.select('key', functions.randn(0)).collect()
12211221
self.assertEqual(sorted(rndn1), sorted(rndn2))
12221222

1223+
def test_string_functions(self):
1224+
from pyspark.sql.functions import col, lit
1225+
df = self.spark.createDataFrame([['nick']], schema=['name'])
1226+
self.assertRaisesRegexp(
1227+
TypeError,
1228+
"must be the same type",
1229+
lambda: df.select(col('name').substr(0, lit(1))))
1230+
if sys.version_info.major == 2:
1231+
self.assertRaises(
1232+
TypeError,
1233+
lambda: df.select(col('name').substr(long(0), long(1))))
1234+
12231235
def test_array_contains_function(self):
12241236
from pyspark.sql.functions import array_contains
12251237

0 commit comments

Comments
 (0)