@@ -75,22 +75,24 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
7575 val skipParams = config.config.getOrElse(" config.page.skip-params" , " false" ).toBoolean
7676 val retryInterval = JavaUtils .timeStringAsMs(config.config.getOrElse(" config.retry.interval" , " 1s" ))
7777 val debug = config.config.getOrElse(" config.debug" , " false" ).toBoolean
78+ val strategy = config.config.getOrElse(" config.page.error.strategy" , " abort" )
7879 val enableRequestCleaner = config.config.getOrElse(configEnableRequestCleaner.name, " false" )
7980 ScriptSQLExec .context().execListener.addEnv(" enableRestDataSourceRequestCleaner" , enableRequestCleaner)
8081
/**
 * Wraps an error message and a status code into a single-row DataFrame.
 *
 * @param errMsg     error message, stored UTF-8 encoded in the `content` column; a null message
 *                   (e.g. taken from an exception that carries no message) is stored as empty content
 * @param statusCode value stored in the `status` column
 * @param session    Spark session used to build the DataFrame
 * @return a one-row DataFrame with two columns, content: BinaryType and status: IntegerType
 */
def _error2DataFrame(errMsg: String, statusCode: Int, session: SparkSession): DataFrame = {
  // Callers pass Throwable.getMessage, which may be null; guard here so that building the
  // error DataFrame never throws a NullPointerException itself.
  val content = Option(errMsg).getOrElse("").getBytes(UTF_8)
  val schema = StructType(fields = Seq(
    StructField("content", BinaryType), StructField("status", IntegerType)
  ))
  session.createDataFrame(session.sparkContext.makeRDD(Seq(Row.fromSeq(Seq(content, statusCode)))), schema)
}
8788
8889 /**
8990 * Calling http rest endpoints with retrying
90- * @param url http url
91- * @param skipParams if true, not adding parameters to http get urls
92- * @maxTries The max number of attempts
93- * @return http status code along with DataFrame
91+ *
92+ * @param url http url
93+ * @param skipParams if true, not adding parameters to http get urls
94+ * @param maxTries The max number of attempts
95+ * @return http status code along with DataFrame
9496 */
9597 def _httpWithRetrying (url : String , skipParams : Boolean , maxTries : Int ): (Int , DataFrame ) = {
9698 executeWithRetrying[(Int , DataFrame )](maxTries)((() => {
@@ -105,18 +107,18 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
105107 } catch {
106108 // According to _http function, it throws MLSQLException if any request parameter is invalid
107109 case me : MLSQLException =>
108- if ( me.getMessage.startsWith(" content-type" ))
110+ if ( me.getMessage.startsWith(" content-type" ))
109111 (415 , _error2DataFrame(me.getMessage, 415 , config.df.get.sparkSession))
110- else if (me.getMessage.startsWith(" HTTP method" ))
111- (405 , _error2DataFrame(me.getMessage,405 , config.df.get.sparkSession))
112+ else if (me.getMessage.startsWith(" HTTP method" ))
113+ (405 , _error2DataFrame(me.getMessage, 405 , config.df.get.sparkSession))
112114 else
113115 (500 , _error2DataFrame(me.getMessage, 500 , config.df.get.sparkSession))
114116 case e : Exception => (0 , _error2DataFrame(e.getMessage, 0 , config.df.get.sparkSession))
115117 }
116118 }) (),
117119 tempResp => {
118120 val succeed = tempResp._1 == 200
119- if (! succeed) {
121+ if (! succeed) {
120122 Thread .sleep(retryInterval)
121123 }
122124 succeed
@@ -147,8 +149,8 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
147149
148150 // If a user runs multiple Load Rest.`` statements in a single thread, we need to save all temp dirs
149151 context.execListener.env().get(classOf [MLSQLRest ].getName) match {
150- case Some (dirs) => context.execListener.addEnv( classOf [MLSQLRest ].getName, s " ${dirs}, ${tmpTablePath}" )
151- case None => context.execListener.addEnv( classOf [MLSQLRest ].getName, tmpTablePath )
152+ case Some (dirs) => context.execListener.addEnv(classOf [MLSQLRest ].getName, s " ${dirs}, ${tmpTablePath}" )
153+ case None => context.execListener.addEnv(classOf [MLSQLRest ].getName, tmpTablePath)
152154 }
153155
154156 var count = 0
@@ -157,28 +159,37 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
157159 var url = config.path
158160 do {
159161 val pageFetchTime = System .currentTimeMillis()
160- val _skipParams = if ( count == 0 ) false else skipParams
162+ val _skipParams = if ( count == 0 ) false else skipParams
161163 val (_, dataFrame) = _httpWithRetrying(url, _skipParams, maxTries)
162164 val row = dataFrame.select(F .col(" content" ).cast(StringType ), F .col(" status" )).head
163165 val content = row.getString(0 )
164166 // Reset status
165167 status = row.getInt(1 )
166168 pageStrategy.nexPage(Option (content))
167- url = pageStrategy.pageUrl( Option (content) )
169+ url = pageStrategy.pageUrl(Option (content))
168170 hasNextPage = pageStrategy.hasNextPage(Option (content))
171+ val exceptionMsg = s " URL: ${url},with response status ${status}, Have retried ${maxTries} times! \n ${content}"
172+ if (status != 200 ) {
173+ // If the strategy is abort, raise an error; if the strategy is skip, just log the exception as a warning
174+ strategy match {
175+ case " abort" => throw new RuntimeException (exceptionMsg)
176+ case " skip" =>
177+ // First page request failed, return immediately, otherwise return saved DataFrame
178+ logWarning(exceptionMsg)
179+ if (count == 0 ) return dataFrame else return context.execListener.sparkSession.read.parquet(tmpTablePath)
180+ }
169181
170- if ( status != 200 ) {
171- // First page request failed, return immediately, otherwise return saved DataFrame
172- if ( count == 0 ) return dataFrame else return context.execListener.sparkSession.read.parquet(tmpTablePath)
173182 }
174183 dataFrame.write.format(" parquet" ).mode(SaveMode .Append ).save(tmpTablePath)
184+ logInfo(s " Data from url: ${url} from start: ${count * maxSize} to end: ${(count + 1 ) * maxSize - 1 } " +
185+ s " is done! and dump to ${tmpTablePath}" )
175186 if (debug) {
176187 logInfo(format(s " Getting Page ${count} ${url} Consume: ${System .currentTimeMillis() - pageFetchTime}ms " ))
177188 }
178- if ( count > 0 ) Thread .sleep( pageInterval)
189+ if ( count > 0 ) Thread .sleep(pageInterval)
179190 count += 1
180191 }
181- while ( count < maxSize && hasNextPage && status == 200 )
192+ while ( count < maxSize && hasNextPage && status == 200 )
182193 context.execListener.sparkSession.read.parquet(tmpTablePath)
183194
184195 case (None , None ) =>
@@ -192,11 +203,12 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
192203
193204 /**
194205 * Send http request and return the result as a DataFrame
195- * @param url http url
196- * @param params http parameters, including content-type http method and others
197- * @param skipParams if true, not adding parameters to http get url
198- * @param session The Spark Session
199- * @return DataFrame , is guaranteed to be not null, with 2 columns content: Array[String] status: Int
206+ *
207+ * @param url http url
208+ * @param params http parameters, including content-type http method and others
209+ * @param skipParams if true, not adding parameters to http get url
210+ * @param session The Spark Session
211+ * @return DataFrame , is guaranteed to be not null, with 2 columns content: Array[String] status: Int
200212 */
201213 private def _http (url : String , params : Map [String , String ], skipParams : Boolean , session : SparkSession ): DataFrame = {
202214 val httpMethod = params.getOrElse(configMethod.name, " get" ).toLowerCase
@@ -354,11 +366,11 @@ class MLSQLRest(override val uid: String) extends MLSQLSource
354366
355367 val filePathBuf = ArrayBuffer [(String , String )]()
356368 HDFSOperatorV2 .isFile(finalPath) match {
357- case true =>
369+ case true =>
358370 filePathBuf.append((finalPath, fileName))
359371 case false if HDFSOperatorV2 .isDir(finalPath) =>
360372 val listFiles = HDFSOperatorV2 .listFiles(finalPath)
361- if (listFiles.filter(_.isDirectory).size > 0 ) throw new MLSQLException (s " Including subdirectories is not supported " )
373+ if (listFiles.filter(_.isDirectory).size > 0 ) throw new MLSQLException (s " Including subdirectories is not supported " )
362374
363375 listFiles.filterNot(_.getPath.getName.equals(" _SUCCESS" ))
364376 .foreach(file => filePathBuf.append((file.getPath.toString, file.getPath.getName)))
0 commit comments