2121from PIL import Image
2222
2323# pyspark
24- from pyspark import Row
2524from pyspark import SparkContext
26- from sparkdl .image .image import ImageSchema
2725from pyspark .sql .functions import udf
2826from pyspark .sql .types import (
2927 BinaryType , IntegerType , StringType , StructField , StructType )
3028
31-
32- # ImageType represents supported OpenCV types
33- # fields:
34- # name - OpenCvMode
35- # ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema).
36- # nChannels - number of channels in the image
37- # dtype - data type of the image's array, sorted as a numpy compatible string.
38- #
39- # NOTE: likely to be migrated to Spark ImageSchema code in the near future.
40- _OcvType = namedtuple ("OcvType" , ["name" , "ord" , "nChannels" , "dtype" ])
41-
42-
43- _supportedOcvTypes = (
44- _OcvType (name = "CV_8UC1" , ord = 0 , nChannels = 1 , dtype = "uint8" ),
45- _OcvType (name = "CV_32FC1" , ord = 5 , nChannels = 1 , dtype = "float32" ),
46- _OcvType (name = "CV_8UC3" , ord = 16 , nChannels = 3 , dtype = "uint8" ),
47- _OcvType (name = "CV_32FC3" , ord = 21 , nChannels = 3 , dtype = "float32" ),
48- _OcvType (name = "CV_8UC4" , ord = 24 , nChannels = 4 , dtype = "uint8" ),
49- _OcvType (name = "CV_32FC4" , ord = 29 , nChannels = 4 , dtype = "float32" ),
50- )
51-
52- # NOTE: likely to be migrated to Spark ImageSchema code in the near future.
53- _ocvTypesByName = {m .name : m for m in _supportedOcvTypes }
54- _ocvTypesByOrdinal = {m .ord : m for m in _supportedOcvTypes }
55-
56-
57- def imageTypeByOrdinal (ord ):
58- if not ord in _ocvTypesByOrdinal :
59- raise KeyError ("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
60- ord , str (_supportedOcvTypes )))
61- return _ocvTypesByOrdinal [ord ]
62-
63-
64- def imageTypeByName (name ):
65- if not name in _ocvTypesByName :
66- raise KeyError ("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
67- name , str (_supportedOcvTypes )))
68- return _ocvTypesByName [name ]
69-
70-
71- def imageArrayToStruct (imgArray , origin = "" ):
72- """
73- Create a row representation of an image from an image array.
74-
75- :param imgArray: ndarray, image data.
76- :return: Row, image as a DataFrame Row with schema==ImageSchema.
77- """
78- # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
79- if len (imgArray .shape ) == 4 :
80- if imgArray .shape [0 ] != 1 :
81- raise ValueError (
82- "The first dimension of a 4-d image array is expected to be 1." )
83- imgArray = imgArray .reshape (imgArray .shape [1 :])
84- imageType = _arrayToOcvMode (imgArray )
85- height , width , nChannels = imgArray .shape
86- data = bytearray (imgArray .tobytes ())
87- return Row (origin = origin , mode = imageType .ord , height = height ,
88- width = width , nChannels = nChannels , data = data )
89-
90-
91- def imageStructToArray (imageRow ):
92- """
93- Convert an image to a numpy array.
94-
95- :param imageRow: Row, must use imageSchema.
96- :return: ndarray, image data.
97- """
98- imType = imageTypeByOrdinal (imageRow .mode )
99- shape = (imageRow .height , imageRow .width , imageRow .nChannels )
100- return np .ndarray (shape , imType .dtype , imageRow .data )
29+ from sparkdl .image .image import ImageSchema
10130
10231
10332def imageStructToPIL (imageRow ):
@@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
10736 :param imageRow: Row, must have ImageSchema
10837 :return PIL image
10938 """
110- imgType = imageTypeByOrdinal (imageRow . mode )
111- if imgType .dtype != ' uint8' :
39+ ary = ImageSchema . toNDArray (imageRow )
40+ if ary .dtype != np . uint8 :
11241 raise ValueError ("Can not convert image of type " +
113- imgType .dtype + " to PIL, can only deal with 8U format" )
114- ary = imageStructToArray ( imageRow )
42+ ary .dtype + " to PIL, can only deal with 8U format" )
43+
11544 # PIL expects RGB order, image schema is BGR
11645 # => we need to flip the order unless there is only one channel
117- if imgType .nChannels != 1 :
46+ if imageRow .nChannels != 1 :
11847 ary = _reverseChannels (ary )
119- if imgType .nChannels == 1 :
48+ if imageRow .nChannels == 1 :
12049 return Image .fromarray (obj = ary , mode = 'L' )
121- elif imgType .nChannels == 3 :
50+ elif imageRow .nChannels == 3 :
12251 return Image .fromarray (obj = ary , mode = 'RGB' )
123- elif imgType .nChannels == 4 :
52+ elif imageRow .nChannels == 4 :
12453 return Image .fromarray (obj = ary , mode = 'RGBA' )
12554 else :
12655 raise ValueError ("don't know how to convert " +
@@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
13261 return _reverseChannels (np .asarray (img ))
13362
13463
135- def _arrayToOcvMode (arr ):
136- assert len (arr .shape ) == 3 , "Array should have 3 dimensions but has shape {}" .format (
137- arr .shape )
138- num_channels = arr .shape [2 ]
139- if arr .dtype == "uint8" :
140- name = "CV_8UC%d" % num_channels
141- elif arr .dtype == "float32" :
142- name = "CV_32FC%d" % num_channels
143- else :
144- raise ValueError ("Unsupported type '%s'" % arr .dtype )
145- return imageTypeByName (name )
146-
147-
14864def fixColorChannelOrdering (currentOrder , imgAry ):
14965 if currentOrder == 'RGB' :
15066 return _reverseChannels (imgAry )
@@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
16076 "Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder )
16177
16278
79+ def _stripBatchSize (imgArray ):
80+ """
81+ Strip batch size (if it's there) from a multi dimensional array.
82+ Assumes batch size is the first coordinate and is equal to 1.
83+ Batch size != 1 will cause an error.
84+
85+ :param imgArray: ndarray, image data.
86+ :return: imgArray without the leading batch size
87+ """
88+ # Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
89+ if len (imgArray .shape ) == 4 :
90+ if imgArray .shape [0 ] != 1 :
91+ raise ValueError (
92+ "The first dimension of a 4-d image array is expected to be 1." )
93+ imgArray = imgArray .reshape (imgArray .shape [1 :])
94+ return imgArray
95+
96+
16397def _reverseChannels (ary ):
16498 return ary [..., ::- 1 ]
16599
@@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow):
183117 return imgAsRow
184118 imgAsPil = imageStructToPIL (imgAsRow ).resize (sz )
185119 # PIL is RGB based while image schema is BGR based => we need to flip the channels
186- imgAsArray = _reverseChannels ( np . asarray ( imgAsPil ) )
187- return imageArrayToStruct (imgAsArray , origin = imgAsRow .origin )
120+ imgAsArray = PIL_to_imageStruct ( imgAsPil )
121+ return ImageSchema . toImage (imgAsArray , origin = imgAsRow .origin )
188122 return udf (_resizeImageAsRow , ImageSchema .imageSchema ['image' ].dataType )
189123
190124
@@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None):
242176def _readImagesWithCustomFn (path , decode_f , numPartition , sc ):
243177 def _decode (path , raw_bytes ):
244178 try :
245- return imageArrayToStruct (decode_f (raw_bytes ), origin = path )
179+ return ImageSchema . toImage (decode_f (raw_bytes ), origin = path )
246180 except BaseException :
247181 return None
248182 decodeImage = udf (_decode , ImageSchema .imageSchema ['image' ].dataType )
0 commit comments